optimize lora compute
parent 82bbea2729
commit b5af5c00e1
@@ -195,19 +195,15 @@ def make_fused_mlp(m, parent_name=''):


 class CustomLoraLayerMerged(torch.nn.Module):
-    def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v):
+    def __init__(self, lora_A, lora_B):
         super().__init__()
-        self.lora_A_q = lora_A_q
-        self.lora_B_q = lora_B_q
-        self.lora_A_v = lora_A_v
-        self.lora_B_v = lora_B_v
-        self.scaling = scaling
+        self.lora_A = torch.nn.Parameter(lora_A, requires_grad=False)
+        self.lora_B = torch.nn.Parameter(lora_B, requires_grad=False)

     def forward(self, x):
-        q = self.lora_B_q(self.lora_A_q(x)) * self.scaling
-        v = self.lora_B_v(self.lora_A_v(x)) * self.scaling
-        return q, v
+        out = torch.einsum('bjm,ndm,nkd->nbjk', x, self.lora_A, self.lora_B)
+        return out


 class LoraInjectionWrapper:
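Note: the new forward computes both LoRA branches in one batched contraction instead of four small nn.Linear calls. A minimal equivalence sketch with made-up shapes (b = batch, j = sequence, m = input dim, d = LoRA rank, k = output dim, matching the einsum subscripts); the scaling factor is folded into lora_A at injection time, see the last hunk:

import torch

b, j, m, d, k = 2, 5, 64, 8, 64            # hypothetical sizes for illustration
x = torch.randn(b, j, m)
lora_A = torch.randn(2, d, m)              # stacked A weights for (q, v)
lora_B = torch.randn(2, k, d)              # stacked B weights for (q, v)

# Merged path: one einsum covers both projections at once.
merged = torch.einsum('bjm,ndm,nkd->nbjk', x, lora_A, lora_B)

# Reference path: apply each LoRA pair separately, as the old code did.
q_ref = x @ lora_A[0].T @ lora_B[0].T
v_ref = x @ lora_A[1].T @ lora_B[1].T

assert torch.allclose(merged[0], q_ref, atol=1e-4)
assert torch.allclose(merged[1], v_ref, atol=1e-4)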
@@ -222,7 +218,8 @@ class LoraInjectionWrapper:

     def forward_with_lora(self, x):
         result = self.module.forward_before_lora(x)
-        q, v = self.lora_layer(x)
+        lora_out = self.lora_layer(x)
+        q, v = lora_out[0], lora_out[1]
         dim = self.module.out_features // 3
         result[:, :, :dim] += q
         result[:, :, -dim:] += v
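Note: the merged layer now returns one stacked tensor of shape (2, batch, seq, dim) rather than a (q, v) tuple, so the wrapper indexes it. A small sketch of the fused-QKV layout the wrapper assumes (out_features = 3 * dim, ordered q | k | v), with hypothetical sizes:

import torch

b, j, dim = 2, 5, 64                      # hypothetical sizes
result = torch.zeros(b, j, 3 * dim)       # stand-in for forward_before_lora(x)
lora_out = torch.randn(2, b, j, dim)      # stacked (q, v) deltas from the merged layer

q, v = lora_out[0], lora_out[1]
result[:, :, :dim] += q                   # q slice: first third
result[:, :, -dim:] += v                  # v slice: last third; k stays untouched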
@@ -246,7 +243,7 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
         if prefix not in lora_weight_dic.keys():
             lora_weight_dic[prefix] = {}
         lora_weight_dic[prefix][k_new] = v

     lora_layers = {}
     for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
         k1 = 'self_attn.q_proj.lora_A.weight'
@@ -254,31 +251,16 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
         k3 = 'self_attn.v_proj.lora_A.weight'
         k4 = 'self_attn.v_proj.lora_B.weight'

-        weight = lora_weight_dic_tmp[k1]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-        weight = lora_weight_dic_tmp[k2]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-        weight = lora_weight_dic_tmp[k3]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-        weight = lora_weight_dic_tmp[k4]
-        l_dim = weight.shape[0]
-        r_dim = weight.shape[1]
-        lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
-        lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-        lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v)
+        lora_A_q = lora_weight_dic_tmp[k1].to(device=device, dtype=dtype)
+        lora_B_q = lora_weight_dic_tmp[k2].to(device=device, dtype=dtype)
+        lora_A_v = lora_weight_dic_tmp[k3].to(device=device, dtype=dtype)
+        lora_B_v = lora_weight_dic_tmp[k4].to(device=device, dtype=dtype)
+
+        loraA_weight = torch.concat([lora_A_q.unsqueeze(0), lora_A_v.unsqueeze(0)], dim=0)
+        loraB_weight = torch.concat([lora_B_q.unsqueeze(0), lora_B_v.unsqueeze(0)], dim=0)
+        loraA_weight *= scaling
+
+        lora_layer = CustomLoraLayerMerged(loraA_weight, loraB_weight)
         lora_layer = lora_layer.to(device=device, dtype=dtype)
         lora_layers[prefix] = lora_layer
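Note: stacking the q/v weights and multiplying scaling into loraA_weight once keeps the result identical to the old per-projection B(A(x)) * scaling path. A minimal check with made-up shapes:

import torch

in_dim, rank, out_dim, scaling = 64, 8, 64, 0.5   # hypothetical values
x = torch.randn(2, 5, in_dim)
A_q, B_q = torch.randn(rank, in_dim), torch.randn(out_dim, rank)
A_v, B_v = torch.randn(rank, in_dim), torch.randn(out_dim, rank)

# Old path: two separate low-rank products with an explicit scaling factor.
q_old = (x @ A_q.T @ B_q.T) * scaling
v_old = (x @ A_v.T @ B_v.T) * scaling

# New path: stack the (q, v) weights and fold scaling into A once.
loraA_weight = torch.concat([A_q.unsqueeze(0), A_v.unsqueeze(0)], dim=0) * scaling
loraB_weight = torch.concat([B_q.unsqueeze(0), B_v.unsqueeze(0)], dim=0)
out = torch.einsum('bjm,ndm,nkd->nbjk', x, loraA_weight, loraB_weight)

assert torch.allclose(out[0], q_old, atol=1e-4)
assert torch.allclose(out[1], v_old, atol=1e-4)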