optimize lora compute

John Smith 2023-04-23 20:00:28 +08:00
parent 82bbea2729
commit b5af5c00e1
1 changed file with 18 additions and 36 deletions

@@ -195,19 +195,15 @@ def make_fused_mlp(m, parent_name=''):
 class CustomLoraLayerMerged(torch.nn.Module):
-    def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v):
+    def __init__(self, lora_A, lora_B):
         super().__init__()
-        self.lora_A_q = lora_A_q
-        self.lora_B_q = lora_B_q
-        self.lora_A_v = lora_A_v
-        self.lora_B_v = lora_B_v
-        self.scaling = scaling
+        self.lora_A = torch.nn.Parameter(lora_A, requires_grad=False)
+        self.lora_B = torch.nn.Parameter(lora_B, requires_grad=False)
 
     def forward(self, x):
-        q = self.lora_B_q(self.lora_A_q(x)) * self.scaling
-        v = self.lora_B_v(self.lora_A_v(x)) * self.scaling
-        return q, v
+        out = torch.einsum('bjm,ndm,nkd->nbjk', x, self.lora_A, self.lora_B)
+        return out
 
 class LoraInjectionWrapper:
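
As a reference point (not part of the diff), here is a quick sanity check of what the merged einsum computes. The shapes and the scaling value below are assumed toy values: x is (batch, seq, in_dim), lora_A stacks the q/v A-matrices into (2, rank, in_dim) with the scaling already folded in, lora_B stacks the B-matrices into (2, out_dim, rank), and out[0] / out[1] reproduce the two separate low-rank projections the old forward computed.

import torch

# Sanity-check sketch with assumed toy shapes; the real shapes come from the model.
b, j, m, d, k = 2, 5, 64, 8, 64             # batch, seq, in_dim, rank, out_dim
scaling = 0.5                               # assumed value; the real one comes from the adapter config
x = torch.randn(b, j, m)
A_q, B_q = torch.randn(d, m), torch.randn(k, d)
A_v, B_v = torch.randn(d, m), torch.randn(k, d)

lora_A = torch.stack([A_q, A_v]) * scaling  # (2, d, m), scaling folded into A as in the new loader
lora_B = torch.stack([B_q, B_v])            # (2, k, d)

out = torch.einsum('bjm,ndm,nkd->nbjk', x, lora_A, lora_B)  # (2, b, j, k)

# out[0] / out[1] match the old per-projection path B(A(x)) * scaling.
assert torch.allclose(out[0], (x @ A_q.T @ B_q.T) * scaling, atol=1e-4)
assert torch.allclose(out[1], (x @ A_v.T @ B_v.T) * scaling, atol=1e-4)
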
@@ -222,7 +218,8 @@ class LoraInjectionWrapper:
     def forward_with_lora(self, x):
         result = self.module.forward_before_lora(x)
-        q, v = self.lora_layer(x)
+        lora_out = self.lora_layer(x)
+        q, v = lora_out[0], lora_out[1]
         dim = self.module.out_features // 3
         result[:, :, :dim] += q
         result[:, :, -dim:] += v
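
The slice arithmetic above assumes self.module is a fused QKV projection whose output width is 3 * dim: the q delta is added to the first third and the v delta to the last third, leaving the middle k slice untouched. A minimal sketch with assumed toy shapes:

import torch

# Minimal sketch with assumed toy shapes: add q/v LoRA deltas into a fused qkv output.
b, j, dim = 2, 5, 64
result = torch.zeros(b, j, 3 * dim)         # stand-in for self.module.forward_before_lora(x)
q = torch.ones(b, j, dim)
v = torch.full((b, j, dim), 2.0)

result[:, :, :dim] += q                     # q occupies columns [0, dim)
result[:, :, -dim:] += v                    # v occupies columns [2*dim, 3*dim)
assert result[:, :, dim:2 * dim].abs().sum().item() == 0   # k slice is left unchanged
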
@@ -246,7 +243,7 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
         if prefix not in lora_weight_dic.keys():
             lora_weight_dic[prefix] = {}
         lora_weight_dic[prefix][k_new] = v
 
     lora_layers = {}
     for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
         k1 = 'self_attn.q_proj.lora_A.weight'
@@ -254,31 +251,16 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
         k3 = 'self_attn.v_proj.lora_A.weight'
         k4 = 'self_attn.v_proj.lora_B.weight'
-        weight = lora_weight_dic_tmp[k1]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False)
+        lora_A_q = lora_weight_dic_tmp[k1].to(device=device, dtype=dtype)
+        lora_B_q = lora_weight_dic_tmp[k2].to(device=device, dtype=dtype)
+        lora_A_v = lora_weight_dic_tmp[k3].to(device=device, dtype=dtype)
+        lora_B_v = lora_weight_dic_tmp[k4].to(device=device, dtype=dtype)
+        loraA_weight = torch.concat([lora_A_q.unsqueeze(0), lora_A_v.unsqueeze(0)], dim=0)
+        loraB_weight = torch.concat([lora_B_q.unsqueeze(0), lora_B_v.unsqueeze(0)], dim=0)
+        loraA_weight *= scaling
-        weight = lora_weight_dic_tmp[k2]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False)
-        weight = lora_weight_dic_tmp[k3]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False)
-        weight = lora_weight_dic_tmp[k4]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False)
-        lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v)
+        lora_layer = CustomLoraLayerMerged(loraA_weight, loraB_weight)
         lora_layer = lora_layer.to(device=device, dtype=dtype)
         lora_layers[prefix] = lora_layer
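
One note on the loraA_weight *= scaling line: because the LoRA update is linear in A, multiplying A by the scaling factor once at load time is equivalent to multiplying the low-rank output by scaling on every forward pass, which is what the old code did. A minimal sketch with toy tensors (the scaling value is made up; in PEFT-style adapters it is typically lora_alpha / r):

import torch

# Toy check: folding scaling into A at load time == scaling the output each forward.
scaling = 0.25                              # assumed value
x = torch.randn(5, 64)
A = torch.randn(8, 64)                      # (rank, in_features), like lora_A.weight
B = torch.randn(64, 8)                      # (out_features, rank), like lora_B.weight

per_forward = (x @ A.T @ B.T) * scaling     # old path: scale applied on every forward
folded = x @ (A * scaling).T @ B.T          # new path: scale baked into A once
assert torch.allclose(per_forward, folded, atol=1e-4)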