# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.util
import math
import re
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

from ..utils import PeftConfig, PeftType, transpose


def is_bnb_available():
    """Return True when the `bitsandbytes` package can be found on the import path."""
    # `importlib.util` is a submodule and is imported explicitly above; a bare
    # `import importlib` does not guarantee that it is loaded.
    return importlib.util.find_spec("bitsandbytes") is not None


def is_gptq_available():
    """Return True when the `quant` module can be found on the import path.

    This flag gates the 4-bit code paths in this file (`Autograd4bitQuantLinear` / `Linear4bitLt`).
    """
    return importlib.util.find_spec("quant") is not None


if is_bnb_available():
    import bitsandbytes as bnb

if is_gptq_available():
    import quant


@dataclass
class LoraConfig(PeftConfig):
    """
    This is the configuration class to store the configuration of a [`~peft.Lora`].

    Args:
        r (`int`): Lora attention dimension
        target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to.
        lora_alpha (`float`): The alpha parameter for Lora scaling.
        lora_dropout (`float`): The dropout probability for Lora layers.
        merge_weights (`bool`):
            Whether to merge the weights of the Lora layers with the base transformer model in `eval` mode.
        fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        enable_lora ( `List[bool]`): Used with `lora.MergedLinear`.
        bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'
        modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable
            and saved in the final checkpoint.
    """

    r: int = field(default=8, metadata={"help": "Lora attention dimension"})
    target_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={
            "help": "List of module names or regex expression of the module names to replace with Lora."
            "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
        },
    )
    lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"})
    lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"})
    merge_weights: bool = field(
        default=False, metadata={"help": "Merge weights of the original model and the Lora model"}
    )
    fan_in_fan_out: bool = field(
        default=False,
        metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
    )
    enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."})
    bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
            "For example, in Sequence Classification or Token Classification tasks, "
            "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
        },
    )

    def __post_init__(self):
        self.peft_type = PeftType.LORA


class LoraModel(torch.nn.Module):
    """
    Creates Low Rank Adapter (Lora) model from a pretrained transformers model.

    Args:
        model ([`transformers.PreTrainedModel`]): The model to be adapted.
        config ([`LoraConfig`]): The configuration of the Lora model.

    Returns:
        `torch.nn.Module`: The Lora model.

    Example::

        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import LoraModel, LoraConfig
        >>> config = LoraConfig(
                peft_type="LORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
                lora_dropout=0.01, )
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> lora_model = LoraModel(config, model)

    **Attributes**:
        - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
        - **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
    """

    def __init__(self, config, model):
        super().__init__()
        self.peft_config = config
        self.model = model
        self._find_and_replace()
        mark_only_lora_as_trainable(self.model, self.peft_config.bias)
        self.forward = self.model.forward

    def _find_and_replace(self):
        """Swap every module matching `peft_config.target_modules` for its Lora counterpart.

        Raises:
            ImportError: the model is flagged `is_loaded_in_8bit` but `bitsandbytes` is not installed.
            ValueError: no module matched `target_modules`, or a matched module has a type for
                which no Lora replacement exists under the current `enable_lora` setting.
        """
        loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
        if loaded_in_8bit and not is_bnb_available():
            raise ImportError(
                "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
                "You can install it with `pip install bitsandbytes`."
            )
        # `Autograd4bitQuantLinear` / `Linear4bitLt` only exist as module-level names when the
        # 4-bit backend is importable, so they must not be referenced otherwise.
        gptq_available = is_gptq_available()
        is_target_modules_in_base_model = False
        is_hf_device_map_available = hasattr(self.model, "hf_device_map")
        enable_lora = self.peft_config.enable_lora
        target_modules = self.peft_config.target_modules
        base_kwargs = {
            "r": self.peft_config.r,
            "lora_alpha": self.peft_config.lora_alpha,
            "lora_dropout": self.peft_config.lora_dropout,
            "fan_in_fan_out": self.peft_config.fan_in_fan_out,
            "merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
            and not is_hf_device_map_available,
        }
        # Snapshot the names first: modules are replaced while we iterate.
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
            if isinstance(target_modules, str):
                target_module_found = re.fullmatch(target_modules, key)
            else:
                target_module_found = any(key.endswith(target_key) for target_key in target_modules)
            if not target_module_found:
                continue

            is_target_modules_in_base_model = True
            parent, target, target_name = self._get_submodules(key)
            bias = target.bias is not None
            # Work on a per-module copy so that options added for one module type (8-bit state,
            # `enable_lora`) never leak into the constructor call of a later, different module type.
            kwargs = dict(base_kwargs)
            if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
                kwargs.update(
                    {
                        "has_fp16_weights": target.state.has_fp16_weights,
                        "memory_efficient_backward": target.state.memory_efficient_backward,
                        "threshold": target.state.threshold,
                        "index": target.index,
                    }
                )
                if enable_lora is None:
                    new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
                else:
                    kwargs["enable_lora"] = enable_lora
                    new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
            elif isinstance(target, torch.nn.Linear) and enable_lora is None:
                new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
            elif gptq_available and isinstance(target, Autograd4bitQuantLinear) and enable_lora is None:
                new_module = Linear4bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
            elif enable_lora is not None:
                kwargs["enable_lora"] = enable_lora
                if isinstance(target, Conv1D):
                    in_features, out_features = (
                        target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
                    )
                else:
                    in_features, out_features = target.in_features, target.out_features
                    if kwargs["fan_in_fan_out"]:
                        warnings.warn(
                            "fan_in_fan_out is set to True but the target module is not a Conv1D. "
                            "Setting fan_in_fan_out to False."
                        )
                        # The override is sticky: it also applies to the config and to every
                        # module processed after this one.
                        kwargs["fan_in_fan_out"] = False
                        base_kwargs["fan_in_fan_out"] = False
                        self.peft_config.fan_in_fan_out = False
                new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs)
            else:
                # Without this guard `new_module` would be unbound, or worse, the module built
                # for the previous target would be installed a second time.
                raise ValueError(
                    f"Target module `{key}` of type {type(target).__name__} is not supported "
                    f"with enable_lora={enable_lora}."
                )
            self._replace_module(parent, target_name, new_module, target)
        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {self.peft_config.target_modules} not found in the base model. "
                f"Please check the target modules and try again."
            )

    def _get_submodules(self, key):
        """Return `(parent module, target module, target attribute name)` for the dotted name `key`."""
        parent_key, _, target_name = key.rpartition(".")
        parent = self.model.get_submodule(parent_key)
        target = self.model.get_submodule(key)
        return parent, target, target_name

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        """Install `new_module` under `parent_module.child_name` and hand over the old module's tensors.

        The pre-trained tensors (and the `state` attribute, when the old module has one) are moved
        by reference, and the `lora_*` submodules are sent to the device of the old weights.
        """
        setattr(parent_module, child_name, new_module)
        # Only touch the 4-bit class names when the backend that defines them is available;
        # otherwise they are undefined and referencing them raises NameError.
        is_4bit = (
            is_gptq_available()
            and isinstance(old_module, Autograd4bitQuantLinear)
            and isinstance(new_module, Linear4bitLt)
        )
        if is_4bit:
            new_module.qweight = old_module.qweight
            new_module.scales = old_module.scales
            new_module.zeros = old_module.zeros
            new_module.bias = old_module.bias
            device = old_module.qweight.device
        else:
            new_module.weight = old_module.weight
            if old_module.bias is not None:
                new_module.bias = old_module.bias
            device = old_module.weight.device

        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state
            new_module.to(device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "lora_" in name:
                module.to(device)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)

    @property
    def modules_to_save(self):
        return None

    def get_peft_config_as_dict(self, inference: bool = False):
        """Return the config as a plain dict, with Enum values replaced by their `.value`.

        When `inference` is True the returned dict has `inference_mode` forced to True.
        """
        config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
        if inference:
            config["inference_mode"] = True
        return config

    def _set_adapter_layers(self, enabled=True):
        """Set `disable_adapters` on every `LoraLayer` in the wrapped model."""
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.disable_adapters = not enabled

    def enable_adapter_layers(self):
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
        self._set_adapter_layers(enabled=False)


# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP


#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------


# had to adapt it for `lora_only` to work
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze every parameter whose name does not contain `lora_`, then re-enable biases per `bias`.

    Args:
        model: module whose parameters are updated in place.
        bias: `"none"` keeps all biases frozen, `"all"` unfreezes every parameter whose name
            contains `"bias"`, `"lora_only"` unfreezes only the `bias` of `LoraLayer` modules.

    Raises:
        NotImplementedError: `bias` is not one of the three values above.
    """
    for n, p in model.named_parameters():
        if "lora_" not in n:
            p.requires_grad = False
    if bias == "none":
        return
    elif bias == "all":
        for n, p in model.named_parameters():
            if "bias" in n:
                p.requires_grad = True
    elif bias == "lora_only":
        for m in model.modules():
            if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError


class LoraLayer:
    """Mixin holding the Lora hyper-parameters and merge/disable flags shared by all Lora layers."""

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout. `None` (the `LoraConfig` default) is treated like 0.0 instead of
        # failing on the comparison.
        if lora_dropout is not None and lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights
        self.disable_adapters = False


class Linear(nn.Linear, LoraLayer):
    # Lora implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Linear(in_features, r, bias=False)
            self.lora_B = nn.Linear(r, out_features, bias=False)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        # `nn.Linear.__init__` calls this before the Lora matrices exist, hence the guard.
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def _lora_delta_weight(self):
        """Return the scaled low-rank update, laid out like `self.weight`."""
        return transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling

    def train(self, mode: bool = True):
        """Switch train/eval mode and merge (eval) or un-merge (train) the Lora update.

        Merging only happens when `merge_weights` is set. Returns `self`, like `nn.Module.train`.
        """
        nn.Linear.train(self, mode)
        # The Lora matrices only exist when r > 0.
        if hasattr(self, "lora_A"):
            self.lora_A.train(mode)
            self.lora_B.train(mode)
        if not mode and self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += self._lora_delta_weight()
            self.merged = True
        elif mode and self.merge_weights and self.merged:
            # Make sure that the weights are not merged while training. The `mode` check keeps
            # a repeated `eval()` call from undoing the merge.
            if self.r > 0:
                self.weight.data -= self._lora_delta_weight()
            self.merged = False
        return self

    def eval(self):
        """Put the layer in evaluation mode; returns `self`."""
        return self.train(False)

    def forward(self, x: torch.Tensor):
        if self.disable_adapters:
            # Adapters are off: make sure the base weight no longer contains the Lora update.
            if self.r > 0 and self.merged:
                self.weight.data -= self._lora_delta_weight()
                self.merged = False
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.r > 0 and not self.merged:
            result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
        return result


class MergedLinear(nn.Linear, LoraLayer):
    # Lora implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        if out_features % len(enable_lora) != 0:
            raise ValueError("The length of enable_lora must divide out_features")
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
            # NOTE(review): `groups` is hard-coded to 2 here while the merge path uses
            # `groups=sum(enable_lora)`; the two only agree when exactly two groups are enabled.
            # Left unchanged because altering it changes the parameter shape of `lora_B`.
            self.lora_B = nn.Conv1d(
                r * sum(enable_lora),
                out_features // len(enable_lora) * sum(enable_lora),
                kernel_size=1,
                groups=2,
                bias=False,
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the indices
            self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        # `nn.Linear.__init__` calls this before the Lora matrices exist, hence the guard.
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def zero_pad(self, x):
        """Scatter `x` into a zero tensor whose last dim is `out_features`, at the `lora_ind` positions."""
        result = x.new_zeros((*x.shape[:-1], self.out_features))
        result = result.view(-1, self.out_features)
        result[:, self.lora_ind] = x.reshape(-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora))
        return result.view((*x.shape[:-1], self.out_features))

    def _has_lora(self):
        """True when the Lora matrices were created, i.e. r > 0 and at least one group is enabled."""
        return self.r > 0 and any(self.enable_lora)

    def _lora_delta_weight(self):
        """Return the scaled, zero-padded low-rank update, laid out like `self.weight`."""
        delta_w = (
            F.conv1d(
                self.lora_A.weight.data.unsqueeze(0),
                self.lora_B.weight.data,
                groups=sum(self.enable_lora),
            )
            .squeeze(0)
            .transpose(-2, -1)
        )
        return transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)

    def train(self, mode: bool = True):
        """Switch train/eval mode and merge (eval) or un-merge (train) the Lora update.

        Merging only happens when `merge_weights` is set. Returns `self`, like `nn.Module.train`.
        """
        nn.Linear.train(self, mode)
        # The Lora matrices only exist when r > 0 and some group is enabled.
        if hasattr(self, "lora_A"):
            self.lora_A.train(mode)
            self.lora_B.train(mode)
        if not mode and self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self._has_lora():
                self.weight.data += self._lora_delta_weight()
            self.merged = True
        elif mode and self.merge_weights and self.merged:
            # Make sure that the weights are not merged while training. The `mode` check keeps
            # a repeated `eval()` call from undoing the merge.
            if self._has_lora():
                self.weight.data -= self._lora_delta_weight()
            self.merged = False
        return self

    def eval(self):
        """Put the layer in evaluation mode; returns `self`."""
        return self.train(False)

    def forward(self, x: torch.Tensor):
        if self.disable_adapters:
            # Adapters are off: make sure the base weight no longer contains the Lora update.
            if self.merged and self._has_lora():
                self.weight.data -= self._lora_delta_weight()
                self.merged = False
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if not self.merged and self._has_lora():
            after_A = self.lora_A(self.lora_dropout(x))
            after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
            result += self.zero_pad(after_B) * self.scaling
        return result


if is_bnb_available():

    class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
        # Lora implemented in a dense layer
        def __init__(
            self,
            in_features,
            out_features,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            **kwargs,
        ):
            # Only the options the 8-bit base class is given here are forwarded; any other
            # entry in `kwargs` is ignored.
            bnb.nn.Linear8bitLt.__init__(
                self,
                in_features,
                out_features,
                bias=kwargs.get("bias", True),
                has_fp16_weights=kwargs.get("has_fp16_weights", True),
                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
                threshold=kwargs.get("threshold", 0.0),
                index=kwargs.get("index", None),
            )
            LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
            # Actual trainable parameters
            if r > 0:
                self.lora_A = nn.Linear(in_features, r, bias=False)
                self.lora_B = nn.Linear(r, out_features, bias=False)
                self.scaling = self.lora_alpha / self.r
                # Freezing the pre-trained weight matrix
                self.weight.requires_grad = False
            self.reset_parameters()

        def reset_parameters(self):
            if hasattr(self, "lora_A"):
                # initialize A the same way as the default for nn.Linear and B to zero
                nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
                nn.init.zeros_(self.lora_B.weight)

        def forward(self, x: torch.Tensor):
            result = super().forward(x)

            if self.disable_adapters:
                return result
            elif self.r > 0:
                if not torch.is_autocast_enabled():
                    # Outside autocast the Lora path runs on a float32 input and its output is
                    # cast back to the dtype of the base layer's result.
                    expected_dtype = result.dtype

                    if x.dtype != torch.float32:
                        x = x.float()
                    output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
                    result += output
                else:
                    output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
                    result += output
            return result

    class MergedLinear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
        # Lora implemented in a dense layer
        def __init__(
            self,
            in_features: int,
            out_features: int,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            enable_lora: List[bool] = [False],
            **kwargs,
        ):
            # Only the options the 8-bit base class is given here are forwarded; any other
            # entry in `kwargs` is ignored.
            bnb.nn.Linear8bitLt.__init__(
                self,
                in_features,
                out_features,
                bias=kwargs.get("bias", True),
                has_fp16_weights=kwargs.get("has_fp16_weights", True),
                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
                threshold=kwargs.get("threshold", 0.0),
                index=kwargs.get("index", None),
            )
            LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
            if out_features % len(enable_lora) != 0:
                raise ValueError("The length of enable_lora must divide out_features")
            self.enable_lora = enable_lora
            # Actual trainable parameters
            if r > 0 and any(enable_lora):
                self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
                # NOTE(review): `groups` is hard-coded to 2 regardless of how many groups are
                # enabled. Left unchanged because altering it changes the shape of `lora_B`.
                self.lora_B = nn.Conv1d(
                    r * sum(enable_lora),
                    out_features // len(enable_lora) * sum(enable_lora),
                    kernel_size=1,
                    groups=2,
                    bias=False,
                )
                self.scaling = self.lora_alpha / self.r
                # Freezing the pre-trained weight matrix
                self.weight.requires_grad = False
                # Compute the indices
                self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
                self.lora_ind[enable_lora, :] = True
                self.lora_ind = self.lora_ind.view(-1)
            self.reset_parameters()

        def reset_parameters(self):
            if hasattr(self, "lora_A"):
                # initialize A the same way as the default for nn.Linear and B to zero
                nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
                nn.init.zeros_(self.lora_B.weight)

        def zero_pad(self, x):
            """Scatter `x` into a zero tensor whose last dim is `out_features`, at the `lora_ind` positions."""
            result = x.new_zeros((*x.shape[:-1], self.out_features))
            result = result.view(-1, self.out_features)
            result[:, self.lora_ind] = x.reshape(
                -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
            )
            return result.view((*x.shape[:-1], self.out_features))

        def forward(self, x: torch.Tensor):
            result = super().forward(x)
            if self.disable_adapters:
                return result
            # The Lora matrices only exist when r > 0 and at least one group is enabled.
            elif self.r > 0 and any(self.enable_lora):
                if not torch.is_autocast_enabled():
                    # Outside autocast the Lora path runs on a float32 input and its output is
                    # cast back to the dtype of the base layer's result.
                    expected_dtype = result.dtype
                    if x.dtype != torch.float32:
                        x = x.float()
                    after_A = self.lora_A(self.lora_dropout(x))
                    after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
                    output = self.zero_pad(after_B).to(expected_dtype) * self.scaling
                    result += output
                else:
                    after_A = self.lora_A(self.lora_dropout(x))
                    after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
                    output = self.zero_pad(after_B) * self.scaling
                    result += output
            return result


if is_gptq_available():
    from autograd_4bit import Autograd4bitQuantLinear

    class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer):
        # Lora implemented in a dense layer
        def __init__(
            self,
            in_features,
            out_features,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            **kwargs,
        ):
            # `kwargs` is accepted for call-site compatibility but not forwarded to the base class.
            Autograd4bitQuantLinear.__init__(
                self,
                in_features,
                out_features
            )
            LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
            # Actual trainable parameters
            if r > 0:
                self.lora_A = nn.Linear(in_features, r, bias=False)
                self.lora_B = nn.Linear(r, out_features, bias=False)
                self.scaling = self.lora_alpha / self.r
                # Freezing the pre-trained weight matrix
                # NOTE(review): assumes the base class always defines a non-None `bias` - confirm.
                self.qweight.requires_grad = False
                self.scales.requires_grad = False
                self.zeros.requires_grad = False
                self.bias.requires_grad = False
            self.reset_parameters()

        def reset_parameters(self):
            if hasattr(self, "lora_A"):
                # initialize A the same way as the default for nn.Linear and B to zero
                nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
                nn.init.zeros_(self.lora_B.weight)

        def forward(self, x: torch.Tensor):
            result = super().forward(x)

            if self.disable_adapters:
                return result
            elif self.r > 0:
                if not torch.is_autocast_enabled():
                    # Outside autocast the Lora path runs on a float32 input and its output is
                    # cast back to the dtype of the base layer's result.
                    expected_dtype = result.dtype

                    if x.dtype != torch.float32:
                        x = x.float()
                    output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
                    result += output
                else:
                    output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
                    result += output
            return result