From b5af5c00e1cfb7d54130781ae067ba6bad47a6d3 Mon Sep 17 00:00:00 2001
From: John Smith
Date: Sun, 23 Apr 2023 20:00:28 +0800
Subject: [PATCH] optimize lora compute

---
 model_attn_mlp_patch.py | 54 ++++++++++++++---------------------------
 1 file changed, 18 insertions(+), 36 deletions(-)

diff --git a/model_attn_mlp_patch.py b/model_attn_mlp_patch.py
index 1b20465..4320f31 100644
--- a/model_attn_mlp_patch.py
+++ b/model_attn_mlp_patch.py
@@ -195,19 +195,15 @@ def make_fused_mlp(m, parent_name=''):
 
 
 class CustomLoraLayerMerged(torch.nn.Module):
-    def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v):
+    def __init__(self, lora_A, lora_B):
         super().__init__()
-        self.lora_A_q = lora_A_q
-        self.lora_B_q = lora_B_q
-        self.lora_A_v = lora_A_v
-        self.lora_B_v = lora_B_v
-        self.scaling = scaling
+        self.lora_A = torch.nn.Parameter(lora_A, requires_grad=False)
+        self.lora_B = torch.nn.Parameter(lora_B, requires_grad=False)
 
     def forward(self, x):
-        q = self.lora_B_q(self.lora_A_q(x)) * self.scaling
-        v = self.lora_B_v(self.lora_A_v(x)) * self.scaling
-        return q, v
-
+        out = torch.einsum('bjm,ndm,nkd->nbjk', x, self.lora_A, self.lora_B)
+        return out
+
 
 class LoraInjectionWrapper:
 
@@ -222,7 +218,8 @@ class LoraInjectionWrapper:
 
     def forward_with_lora(self, x):
         result = self.module.forward_before_lora(x)
-        q, v = self.lora_layer(x)
+        lora_out = self.lora_layer(x)
+        q, v = lora_out[0], lora_out[1]
         dim = self.module.out_features // 3
         result[:, :, :dim] += q
         result[:, :, -dim:] += v
@@ -246,7 +243,7 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
             if prefix not in lora_weight_dic.keys():
                 lora_weight_dic[prefix] = {}
             lora_weight_dic[prefix][k_new] = v
-    
+
     lora_layers = {}
     for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
         k1 = 'self_attn.q_proj.lora_A.weight'
@@ -254,31 +251,16 @@
         k3 = 'self_attn.v_proj.lora_A.weight'
         k4 = 'self_attn.v_proj.lora_B.weight'
 
-        weight = lora_weight_dic_tmp[k1]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False)
+        lora_A_q = lora_weight_dic_tmp[k1].to(device=device, dtype=dtype)
+        lora_B_q = lora_weight_dic_tmp[k2].to(device=device, dtype=dtype)
+        lora_A_v = lora_weight_dic_tmp[k3].to(device=device, dtype=dtype)
+        lora_B_v = lora_weight_dic_tmp[k4].to(device=device, dtype=dtype)
+
+        loraA_weight = torch.concat([lora_A_q.unsqueeze(0), lora_A_v.unsqueeze(0)], dim=0)
+        loraB_weight = torch.concat([lora_B_q.unsqueeze(0), lora_B_v.unsqueeze(0)], dim=0)
+        loraA_weight *= scaling
 
-        weight = lora_weight_dic_tmp[k2]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-        weight = lora_weight_dic_tmp[k3]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-        weight = lora_weight_dic_tmp[k4]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-        lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v)
+        lora_layer = CustomLoraLayerMerged(loraA_weight, loraB_weight)
         lora_layer = lora_layer.to(device=device, dtype=dtype)
         lora_layers[prefix] = lora_layer
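
Reviewer note, not part of the patch itself: the sketch below is a minimal equivalence check for the merged forward pass. It assumes hypothetical shapes (batch=2, seq=5, in_features=16, rank=4, out_features=16) and verifies that the single einsum 'bjm,ndm,nkd->nbjk', with the scaling folded into lora_A as the patched inject_lora_layers does, reproduces the original two-Linear q/v path.

import torch

# Hypothetical dimensions, for illustration only.
b, j, m, r, k = 2, 5, 16, 4, 16      # batch, seq, in_features, rank, out_features
scaling = 0.5
x = torch.randn(b, j, m)
lora_A_q, lora_B_q = torch.randn(r, m), torch.randn(k, r)
lora_A_v, lora_B_v = torch.randn(r, m), torch.randn(k, r)

# Original path: q = B_q(A_q(x)) * scaling, v = B_v(A_v(x)) * scaling,
# written here as explicit matmuls against the Linear weight matrices.
q_ref = (x @ lora_A_q.T @ lora_B_q.T) * scaling
v_ref = (x @ lora_A_v.T @ lora_B_v.T) * scaling

# Merged path: stack the q/v adapters along a new leading dim (n=2),
# fold the scaling into lora_A, and contract everything in one einsum.
lora_A = torch.stack([lora_A_q, lora_A_v]) * scaling   # (2, r, m)
lora_B = torch.stack([lora_B_q, lora_B_v])             # (2, k, r)
out = torch.einsum('bjm,ndm,nkd->nbjk', x, lora_A, lora_B)  # (2, b, j, k)

assert torch.allclose(out[0], q_ref, atol=1e-5)   # q slice
assert torch.allclose(out[1], v_ref, atol=1e-5)   # v slice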