John Smith 2023-04-22 17:23:24 +08:00
parent 4e42965c0d
commit 9fe5ab3642
2 changed files with 50 additions and 20 deletions

View File

@@ -209,6 +209,26 @@ class CustomLoraLayerMerged(torch.nn.Module):
         return q, v
+class LoraInjectionWrapper:
+    def __init__(self, module, lora_layer):
+        self.module = module
+        self.lora_layer = lora_layer
+    def apply(self):
+        self.module.forward_before_lora = self.module.forward
+        self.module.forward = self.forward_with_lora
+        self.module.is_lora_injected = True
+    def forward_with_lora(self, x):
+        result = self.module.forward_before_lora(x)
+        q, v = self.lora_layer(x)
+        dim = self.module.out_features // 3
+        result[:, :, :dim] += q
+        result[:, :, -dim:] += v
+        return result
 def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
     print('Device: {}, dtype: {}'.format(device, dtype))
@@ -263,6 +283,7 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
             lora_layers[prefix] = lora_layer
     # Injection
+    wrappers = []
     for n, m in model.named_modules():
         if 'qkv_proj' in n and isinstance(m, Autograd4bitQuantLinear):
             # restoring forward
@@ -270,16 +291,10 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
                 m.forward = m.forward_before_lora
             prefix = re.findall('^model\.layers\.\d+\.', n)[0]
             lora_layer = lora_layers[prefix]
-            m.forward_before_lora = m.forward
-            def forward_with_lora(self, x):
-                result = self.forward_before_lora(x)
-                q, v = lora_layer(x)
-                dim = self.out_features // 3
-                result[:, :, :dim] += q
-                result[:, :, -dim:] += v
-                return result
-            m.forward = types.MethodType(forward_with_lora, m)
-            m.is_lora_injected = True
+            wrapper = LoraInjectionWrapper(m, lora_layer)
+            wrapper.apply()
+            wrappers.append(wrapper)
     print('Lora Injected.')
+    return wrappers
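
For review context, below is a minimal, self-contained sketch of the wrapper pattern this commit introduces, applied to a plain torch.nn.Linear standing in for the fused qkv projection. The ToyLora module is an illustrative stand-in for the repository's CustomLoraLayerMerged (its real q/v computation is not shown in this diff); only LoraInjectionWrapper mirrors the added code.

import torch

class ToyLora(torch.nn.Module):
    # Illustrative stand-in: returns low-rank q and v deltas that are added
    # onto the first and last thirds of the fused qkv projection output.
    def __init__(self, in_features, dim, rank=4):
        super().__init__()
        self.q_a = torch.nn.Linear(in_features, rank, bias=False)
        self.q_b = torch.nn.Linear(rank, dim, bias=False)
        self.v_a = torch.nn.Linear(in_features, rank, bias=False)
        self.v_b = torch.nn.Linear(rank, dim, bias=False)

    def forward(self, x):
        return self.q_b(self.q_a(x)), self.v_b(self.v_a(x))

class LoraInjectionWrapper:
    # Same pattern as the class added in this commit.
    def __init__(self, module, lora_layer):
        self.module = module
        self.lora_layer = lora_layer

    def apply(self):
        # Keep the original forward so it can be restored later.
        self.module.forward_before_lora = self.module.forward
        self.module.forward = self.forward_with_lora
        self.module.is_lora_injected = True

    def forward_with_lora(self, x):
        result = self.module.forward_before_lora(x)
        q, v = self.lora_layer(x)
        dim = self.module.out_features // 3
        result[:, :, :dim] += q   # q occupies the first third of the qkv output
        result[:, :, -dim:] += v  # v occupies the last third
        return result

hidden, head_dim = 16, 8
qkv_proj = torch.nn.Linear(hidden, 3 * head_dim)
wrapper = LoraInjectionWrapper(qkv_proj, ToyLora(hidden, head_dim))
wrapper.apply()
out = qkv_proj(torch.randn(2, 5, hidden))  # [batch, seq, 3 * head_dim]

Compared with the previous closure bound via types.MethodType, the class-based wrapper keeps the module, the LoRA layer, and the saved forward together in one plain object, so the injection can later be inspected or undone.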

View File

@@ -21,6 +21,26 @@ class CustomLoraLayerMerged(torch.nn.Module):
         return q, v
+class LoraInjectionWrapper:
+    def __init__(self, module, lora_layer):
+        self.module = module
+        self.lora_layer = lora_layer
+    def apply(self):
+        self.module.forward_before_lora = self.module.forward
+        self.module.forward = self.forward_with_lora
+        self.module.is_lora_injected = True
+    def forward_with_lora(self, x):
+        result = self.module.forward_before_lora(x)
+        q, v = self.lora_layer(x)
+        dim = self.module.outfeatures // 3
+        result[:, :, :dim] += q
+        result[:, :, -dim:] += v
+        return result
 def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
     print('Device: {}, dtype: {}'.format(device, dtype))
@@ -75,6 +95,7 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
             lora_layers[prefix] = lora_layer
     # Injection
+    wrappers = []
     for n, m in model.named_modules():
         if 'qkv_proj' in n and isinstance(m, QuantLinear):
             # restoring forward
@@ -82,15 +103,9 @@ def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
                 m.forward = m.forward_before_lora
             prefix = re.findall('^model\.layers\.\d+\.', n)[0]
             lora_layer = lora_layers[prefix]
-            m.forward_before_lora = m.forward
-            def forward_with_lora(self, x):
-                result = self.forward_before_lora(x)
-                q, v = lora_layer(x)
-                dim = self.outfeatures // 3
-                result[:, :, :dim] += q
-                result[:, :, -dim:] += v
-                return result
-            m.forward = types.MethodType(forward_with_lora, m)
-            m.is_lora_injected = True
+            wrapper = LoraInjectionWrapper(m, lora_layer)
+            wrapper.apply()
+            wrappers.append(wrapper)
     print('Lora Injected.')
+    return wrappers
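
Because inject_lora_layers() now returns the wrapper list, the patch can be undone without reloading the model. A minimal sketch, assuming a quantized model is already loaded and its qkv_proj modules match the isinstance() check above; the checkpoint path and the remove_lora() helper are hypothetical, not part of this commit:

# 'model' is assumed to be an already-loaded quantized LLaMA model.
wrappers = inject_lora_layers(model, 'path/to/lora.pt')  # placeholder path

def remove_lora(wrappers):
    # Reverse LoraInjectionWrapper.apply(): restore the saved forward.
    for w in wrappers:
        if getattr(w.module, 'is_lora_injected', False):
            w.module.forward = w.module.forward_before_lora
            w.module.is_lora_injected = False

remove_lora(wrappers)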