optimize lora compute
This commit is contained in:
parent
82bbea2729
commit
b5af5c00e1
|
|
@ -195,19 +195,15 @@ def make_fused_mlp(m, parent_name=''):
|
|||
|
||||
class CustomLoraLayerMerged(torch.nn.Module):
|
||||
|
||||
def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v):
|
||||
def __init__(self, lora_A, lora_B):
|
||||
super().__init__()
|
||||
self.lora_A_q = lora_A_q
|
||||
self.lora_B_q = lora_B_q
|
||||
self.lora_A_v = lora_A_v
|
||||
self.lora_B_v = lora_B_v
|
||||
self.scaling = scaling
|
||||
self.lora_A = torch.nn.Parameter(lora_A, requires_grad=False)
|
||||
self.lora_B = torch.nn.Parameter(lora_B, requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
q = self.lora_B_q(self.lora_A_q(x)) * self.scaling
|
||||
v = self.lora_B_v(self.lora_A_v(x)) * self.scaling
|
||||
return q, v
|
||||
|
||||
out = torch.einsum('bjm,ndm,nkd->nbjk', x, self.lora_A, self.lora_B)
|
||||
return out
|
||||
|
||||
|
||||
class LoraInjectionWrapper:
|
||||
|
||||
|
|
@ -222,7 +218,8 @@ class LoraInjectionWrapper:
|
|||
|
||||
def forward_with_lora(self, x):
|
||||
result = self.module.forward_before_lora(x)
|
||||
q, v = self.lora_layer(x)
|
||||
lora_out = self.lora_layer(x)
|
||||
q, v = lora_out[0], lora_out[1]
|
||||
dim = self.module.out_features // 3
|
||||
result[:, :, :dim] += q
|
||||
result[:, :, -dim:] += v
|
||||
|
|
@ -246,7 +243,7 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
|
|||
if prefix not in lora_weight_dic.keys():
|
||||
lora_weight_dic[prefix] = {}
|
||||
lora_weight_dic[prefix][k_new] = v
|
||||
|
||||
|
||||
lora_layers = {}
|
||||
for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
|
||||
k1 = 'self_attn.q_proj.lora_A.weight'
|
||||
|
|
@ -254,31 +251,16 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
|
|||
k3 = 'self_attn.v_proj.lora_A.weight'
|
||||
k4 = 'self_attn.v_proj.lora_B.weight'
|
||||
|
||||
weight = lora_weight_dic_tmp[k1]
|
||||
l_dim = weight.shape[0]
|
||||
r_dim = weight.shape[1]
|
||||
lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
|
||||
lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
lora_A_q = lora_weight_dic_tmp[k1].to(device=device, dtype=dtype)
|
||||
lora_B_q = lora_weight_dic_tmp[k2].to(device=device, dtype=dtype)
|
||||
lora_A_v = lora_weight_dic_tmp[k3].to(device=device, dtype=dtype)
|
||||
lora_B_v = lora_weight_dic_tmp[k4].to(device=device, dtype=dtype)
|
||||
|
||||
loraA_weight = torch.concat([lora_A_q.unsqueeze(0), lora_A_v.unsqueeze(0)], dim=0)
|
||||
loraB_weight = torch.concat([lora_B_q.unsqueeze(0), lora_B_v.unsqueeze(0)], dim=0)
|
||||
loraA_weight *= scaling
|
||||
|
||||
weight = lora_weight_dic_tmp[k2]
|
||||
l_dim = weight.shape[0]
|
||||
r_dim = weight.shape[1]
|
||||
lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
|
||||
lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
weight = lora_weight_dic_tmp[k3]
|
||||
l_dim = weight.shape[0]
|
||||
r_dim = weight.shape[1]
|
||||
lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
|
||||
lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
weight = lora_weight_dic_tmp[k4]
|
||||
l_dim = weight.shape[0]
|
||||
r_dim = weight.shape[1]
|
||||
lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
|
||||
lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v)
|
||||
lora_layer = CustomLoraLayerMerged(loraA_weight, loraB_weight)
|
||||
lora_layer = lora_layer.to(device=device, dtype=dtype)
|
||||
lora_layers[prefix] = lora_layer
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue