This commit is contained in:
John Smith 2023-04-13 11:34:53 +08:00
parent 76d7963dff
commit 9c3058c1de
1 changed file with 49 additions and 37 deletions


@@ -14,9 +14,11 @@ from autograd_4bit import Autograd4bitQuantLinear
 class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer):
     # Lora implemented in a dense layer
     def __init__(
         self,
+        adapter_name,
         in_features,
         out_features,
         groupsize: int = -1,
@@ -25,20 +27,16 @@ class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer):
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
         **kwargs,
     ):
         Autograd4bitQuantLinear.__init__(
             self,
             in_features,
             out_features,
             groupsize,
             is_v1_model
         )
-        LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
-        # Actual trainable parameters
-        if r > 0:
-            self.lora_A = nn.Linear(in_features, r, bias=False)
-            self.lora_B = nn.Linear(r, out_features, bias=False)
-            self.scaling = self.lora_alpha / self.r
+        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+
         # Freezing the pre-trained weight matrix
         self.qweight.requires_grad = False
         self.scales.requires_grad = False
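Note on the hunk above: the old code created a single lora_A/lora_B pair per layer inside __init__, while the new LoraLayer API keeps one entry per adapter name and populates them later via update_layer(). A rough sketch of that bookkeeping difference (not the peft source; the sizes and the "default" adapter name are assumptions for illustration):

import torch.nn as nn

in_features, out_features, r, lora_alpha = 4096, 4096, 8, 16

# Old, single-adapter layout: one pair of LoRA matrices hangs directly off the layer.
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)
scaling = lora_alpha / r

# New, multi-adapter layout: modules and scaling factors keyed by adapter name,
# filled in by LoraLayer.update_layer() for each registered adapter.
lora_A_per_adapter = nn.ModuleDict({"default": nn.Linear(in_features, r, bias=False)})
lora_B_per_adapter = nn.ModuleDict({"default": nn.Linear(r, out_features, bias=False)})
scaling_per_adapter = {"default": lora_alpha / r}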
@@ -48,31 +46,43 @@ class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer):
         self.qzeros.requires_grad = False
         self.g_idx.requires_grad = False
         self.bias.requires_grad = False
-        self.reset_parameters()
 
-    def reset_parameters(self):
-        if hasattr(self, "lora_A"):
-            # initialize A the same way as the default for nn.Linear and B to zero
-            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
-            nn.init.zeros_(self.lora_B.weight)
+        init_lora_weights = kwargs.pop("init_lora_weights", True)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+        self.active_adapter = adapter_name
 
     def forward(self, x: torch.Tensor):
         result = super().forward(x)
 
-        if self.disable_adapters:
+        if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
             return result
-        elif self.r > 0:
+        elif self.r[self.active_adapter] > 0:
             if not torch.is_autocast_enabled():
                 expected_dtype = result.dtype
 
                 if x.dtype != torch.float32:
                     x = x.float()
-                output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
-                result += output
+                output = (
+                    self.lora_B[self.active_adapter](
+                        self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    ).to(expected_dtype)
+                    * self.scaling[self.active_adapter]
+                )
             else:
-                output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
-                result += output
+                output = (
+                    self.lora_B[self.active_adapter](
+                        self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    )
+                    * self.scaling[self.active_adapter]
+                )
+            result += output
         return result
 
+    @property
+    def weight(self):
+        class WeightDeviceClass:
+            device = self.qweight.device
+        return WeightDeviceClass()
+
 
 class GPTQLoraModel(lora.LoraModel):
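The rewritten forward() keeps the same LoRA math as before, just looked up per active adapter. A minimal sketch of that computation, assuming a plain nn.Linear stands in for the frozen 4-bit Autograd4bitQuantLinear path (sizes and the dropout rate are illustrative, not from the diff):

import torch
import torch.nn as nn

torch.manual_seed(0)
in_features, out_features, r, lora_alpha = 64, 64, 8, 16

base = nn.Linear(in_features, out_features)      # stand-in for the quantized matmul
lora_A = nn.Linear(in_features, r, bias=False)   # low-rank down-projection for the active adapter
lora_B = nn.Linear(r, out_features, bias=False)  # low-rank up-projection for the active adapter
nn.init.zeros_(lora_B.weight)                    # B starts at zero, so the adapter is a no-op until trained
dropout = nn.Dropout(p=0.05)
scaling = lora_alpha / r

x = torch.randn(2, in_features)
result = base(x)                                          # frozen base output
result = result + lora_B(lora_A(dropout(x))) * scaling    # scaled LoRA update

The float32 cast and the .to(expected_dtype) in the diff only come into play when autocast is disabled and the input arrives in half precision. The new weight property appears to exist because the quantized layer stores qweight rather than weight, so callers that only inspect module.weight.device still get a valid device from the small stand-in object.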
@@ -124,6 +134,8 @@ class GPTQLoraModel(lora.LoraModel):
                 new_module = Linear8bitLt(
                     adapter_name, target.in_features, target.out_features, bias=bias, **kwargs
                 )
+            elif isinstance(target, Autograd4bitQuantLinear):
+                new_module = Linear4bitLt(adapter_name, target.in_features, target.out_features, target.groupsize, target.is_v1_model, bias=bias, **kwargs)
             else:
                 if isinstance(target, torch.nn.Linear):
                     in_features, out_features = target.in_features, target.out_features
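The last hunk teaches the module-replacement logic in GPTQLoraModel to recognize 4-bit quantized targets alongside the existing 8-bit branch. A minimal sketch of that new branch, pulled out as a standalone helper (the helper name is made up; target, adapter_name, bias, and kwargs are the same values the surrounding replacement loop passes in the diff):

from autograd_4bit import Autograd4bitQuantLinear

# Hypothetical helper mirroring the elif branch added in this commit.
def wrap_quantized_linear(target, adapter_name, bias, **kwargs):
    if not isinstance(target, Autograd4bitQuantLinear):
        raise TypeError(f"expected Autograd4bitQuantLinear, got {type(target).__name__}")
    # Carry the quantization settings over so the wrapper rebuilds the 4-bit layer correctly.
    return Linear4bitLt(
        adapter_name,
        target.in_features,
        target.out_features,
        target.groupsize,
        target.is_v1_model,
        bias=bias,
        **kwargs,
    )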