diff --git a/monkeypatch/peft_tuners_lora_monkey_patch.py b/monkeypatch/peft_tuners_lora_monkey_patch.py
index a946013..4e6b677 100644
--- a/monkeypatch/peft_tuners_lora_monkey_patch.py
+++ b/monkeypatch/peft_tuners_lora_monkey_patch.py
@@ -14,9 +14,11 @@ from autograd_4bit import Autograd4bitQuantLinear
 
 
 class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer):
-    # Lora implemented in a dense layer
-    def __init__(
+
+    # Lora implemented in a dense layer
+    def __init__(
         self,
+        adapter_name,
         in_features,
         out_features,
         groupsize: int = -1,
@@ -25,20 +27,16 @@ class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer):
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
         **kwargs,
-    ):
-        Autograd4bitQuantLinear.__init__(
-            self,
-            in_features,
-            out_features,
-            groupsize,
-            is_v1_model
-        )
-        LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
-        # Actual trainable parameters
-        if r > 0:
-            self.lora_A = nn.Linear(in_features, r, bias=False)
-            self.lora_B = nn.Linear(r, out_features, bias=False)
-            self.scaling = self.lora_alpha / self.r
+    ):
+        Autograd4bitQuantLinear.__init__(
+            self,
+            in_features,
+            out_features,
+            groupsize,
+            is_v1_model
+        )
+        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+
         # Freezing the pre-trained weight matrix
         self.qweight.requires_grad = False
         self.scales.requires_grad = False
@@ -48,31 +46,43 @@ class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer):
             self.qzeros.requires_grad = False
             self.g_idx.requires_grad = False
         self.bias.requires_grad = False
-        self.reset_parameters()
 
-    def reset_parameters(self):
-        if hasattr(self, "lora_A"):
-            # initialize A the same way as the default for nn.Linear and B to zero
-            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
-            nn.init.zeros_(self.lora_B.weight)
+        init_lora_weights = kwargs.pop("init_lora_weights", True)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+        self.active_adapter = adapter_name
 
-    def forward(self, x: torch.Tensor):
-        result = super().forward(x)
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
 
-        if self.disable_adapters:
+        if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
+            return result
+        elif self.r[self.active_adapter] > 0:
+            if not torch.is_autocast_enabled():
+                expected_dtype = result.dtype
+
+                if x.dtype != torch.float32:
+                    x = x.float()
+                output = (
+                    self.lora_B[self.active_adapter](
+                        self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    ).to(expected_dtype)
+                    * self.scaling[self.active_adapter]
+                )
+            else:
+                output = (
+                    self.lora_B[self.active_adapter](
+                        self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    )
+                    * self.scaling[self.active_adapter]
+                )
+            result += output
             return result
-        elif self.r > 0:
-            if not torch.is_autocast_enabled():
-                expected_dtype = result.dtype
-
-                if x.dtype != torch.float32:
-                    x = x.float()
-                output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
-                result += output
-            else:
-                output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
-                result += output
-        return result
+
+    @property
+    def weight(self):
+        class WeightDeviceClass:
+            device = self.qweight.device
+        return WeightDeviceClass()
 
 
 class GPTQLoraModel(lora.LoraModel):
@@ -124,6 +134,8 @@ class GPTQLoraModel(lora.LoraModel):
                 new_module = Linear8bitLt(
                     adapter_name, target.in_features, target.out_features, bias=bias, **kwargs
                 )
+            elif isinstance(target, Autograd4bitQuantLinear):
+                new_module = Linear4bitLt(adapter_name, target.in_features, target.out_features, target.groupsize, target.is_v1_model, bias=bias, **kwargs)
             else:
                 if isinstance(target, torch.nn.Linear):
                     in_features, out_features = target.in_features, target.out_features
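
For context, below is a minimal, self-contained sketch of the per-adapter LoRA forward path that the patched Linear4bitLt.forward follows above: lora_A, lora_B, lora_dropout, and scaling are dict-like containers indexed by active_adapter, matching the newer multi-adapter PEFT LoraLayer API. This is an illustration only, not the patched class itself: a plain nn.Linear stands in for the 4-bit Autograd4bitQuantLinear base layer, the names LoraSketch and "default" are made up for the example, and the autocast/dtype handling from the patch is omitted.

import torch
import torch.nn as nn


class LoraSketch(nn.Module):
    """Toy stand-in: frozen base layer plus one named LoRA adapter."""

    def __init__(self, in_features, out_features, r=8, lora_alpha=16, adapter_name="default"):
        super().__init__()
        # Stand-in for the quantized base layer; its weight stays frozen.
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad = False
        # Per-adapter containers, mirroring the multi-adapter LoraLayer layout.
        self.lora_A = nn.ModuleDict({adapter_name: nn.Linear(in_features, r, bias=False)})
        self.lora_B = nn.ModuleDict({adapter_name: nn.Linear(r, out_features, bias=False)})
        self.lora_dropout = nn.ModuleDict({adapter_name: nn.Identity()})
        self.scaling = {adapter_name: lora_alpha / r}
        self.r = {adapter_name: r}
        self.active_adapter = adapter_name
        # B starts at zero so the adapter is a no-op before training.
        nn.init.zeros_(self.lora_B[adapter_name].weight)

    def forward(self, x):
        result = self.base(x)
        # Skip the adapter if it is missing or has rank 0, as in the patch.
        if self.active_adapter not in self.lora_A or self.r[self.active_adapter] == 0:
            return result
        a = self.lora_A[self.active_adapter]
        b = self.lora_B[self.active_adapter]
        drop = self.lora_dropout[self.active_adapter]
        return result + b(a(drop(x))) * self.scaling[self.active_adapter]


layer = LoraSketch(16, 32)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 32])

The added weight property in the patch, by contrast, exists only to expose a .device attribute, presumably so that code which looks at module.weight.device (as PEFT does when moving replacement modules) keeps working even though the quantized layer stores qweight rather than a dense weight.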