optimize lora compute

This commit is contained in:
John Smith 2023-04-23 20:00:28 +08:00
parent 82bbea2729
commit b5af5c00e1
1 changed files with 18 additions and 36 deletions

View File

@ -195,19 +195,15 @@ def make_fused_mlp(m, parent_name=''):
class CustomLoraLayerMerged(torch.nn.Module): class CustomLoraLayerMerged(torch.nn.Module):
def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v): def __init__(self, lora_A, lora_B):
super().__init__() super().__init__()
self.lora_A_q = lora_A_q self.lora_A = torch.nn.Parameter(lora_A, requires_grad=False)
self.lora_B_q = lora_B_q self.lora_B = torch.nn.Parameter(lora_B, requires_grad=False)
self.lora_A_v = lora_A_v
self.lora_B_v = lora_B_v
self.scaling = scaling
def forward(self, x): def forward(self, x):
q = self.lora_B_q(self.lora_A_q(x)) * self.scaling out = torch.einsum('bjm,ndm,nkd->nbjk', x, self.lora_A, self.lora_B)
v = self.lora_B_v(self.lora_A_v(x)) * self.scaling return out
return q, v
class LoraInjectionWrapper: class LoraInjectionWrapper:
@ -222,7 +218,8 @@ class LoraInjectionWrapper:
def forward_with_lora(self, x): def forward_with_lora(self, x):
result = self.module.forward_before_lora(x) result = self.module.forward_before_lora(x)
q, v = self.lora_layer(x) lora_out = self.lora_layer(x)
q, v = lora_out[0], lora_out[1]
dim = self.module.out_features // 3 dim = self.module.out_features // 3
result[:, :, :dim] += q result[:, :, :dim] += q
result[:, :, -dim:] += v result[:, :, -dim:] += v
@ -246,7 +243,7 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
if prefix not in lora_weight_dic.keys(): if prefix not in lora_weight_dic.keys():
lora_weight_dic[prefix] = {} lora_weight_dic[prefix] = {}
lora_weight_dic[prefix][k_new] = v lora_weight_dic[prefix][k_new] = v
lora_layers = {} lora_layers = {}
for prefix, lora_weight_dic_tmp in lora_weight_dic.items(): for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
k1 = 'self_attn.q_proj.lora_A.weight' k1 = 'self_attn.q_proj.lora_A.weight'
@ -254,31 +251,16 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
k3 = 'self_attn.v_proj.lora_A.weight' k3 = 'self_attn.v_proj.lora_A.weight'
k4 = 'self_attn.v_proj.lora_B.weight' k4 = 'self_attn.v_proj.lora_B.weight'
weight = lora_weight_dic_tmp[k1] lora_A_q = lora_weight_dic_tmp[k1].to(device=device, dtype=dtype)
l_dim = weight.shape[0] lora_B_q = lora_weight_dic_tmp[k2].to(device=device, dtype=dtype)
r_dim = weight.shape[1] lora_A_v = lora_weight_dic_tmp[k3].to(device=device, dtype=dtype)
lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False) lora_B_v = lora_weight_dic_tmp[k4].to(device=device, dtype=dtype)
lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False)
loraA_weight = torch.concat([lora_A_q.unsqueeze(0), lora_A_v.unsqueeze(0)], dim=0)
loraB_weight = torch.concat([lora_B_q.unsqueeze(0), lora_B_v.unsqueeze(0)], dim=0)
loraA_weight *= scaling
weight = lora_weight_dic_tmp[k2] lora_layer = CustomLoraLayerMerged(loraA_weight, loraB_weight)
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False)
weight = lora_weight_dic_tmp[k3]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False)
weight = lora_weight_dic_tmp[k4]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False)
lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v)
lora_layer = lora_layer.to(device=device, dtype=dtype) lora_layer = lora_layer.to(device=device, dtype=dtype)
lora_layers[prefix] = lora_layer lora_layers[prefix] = lora_layer