import json
import os
import re
import types

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, LlamaMLP

from autograd_4bit import Autograd4bitQuantLinear
import matmul_utils_4bit

# Matches the per-layer prefix of a parameter name, e.g. "model.layers.12.".
_LAYER_PREFIX_RE = re.compile(r'^model\.layers\.\d+\.')


class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with a fused QKV projection.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_proj: module whose output's last dimension is split into three ``hidden_size``
            chunks (query, key, value), in that order.
        o_proj: output projection applied to the merged heads.
        rotary_emb: module called as ``rotary_emb(value_states, seq_len=...)`` that returns
            ``(cos, sin)`` for ``apply_rotary_pos_emb``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, qkv_proj, o_proj, rotary_emb):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_emb = rotary_emb

    def _shape(self, tensor, seq_len, bsz):
        """Reshape ``tensor`` to a contiguous ``(bsz, num_heads, seq_len, head_dim)`` tensor."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None,
                output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. ``attn_weights`` is always ``None``
            because ``F.scaled_dot_product_attention`` does not return attention probabilities.
            ``past_key_value`` is the concatenated ``(key_states, value_states)`` when
            ``use_cache`` is true, otherwise ``None``.

        NOTE(review): ``attention_mask`` is not used. Masking comes only from ``is_causal``, so
        padded batches are not masked.
        """
        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)

        # (bsz, q_len, hidden_size) -> (bsz, num_heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Rotary tables must cover cached + new positions.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # Without a cache the query and key spans coincide, so the built-in causal mask applies.
        # With a cache the new queries attend to every key.
        # NOTE(review): with a cache and q_len > 1 this also drops causal masking among the new
        # tokens.
        is_causal = past_key_value is None
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        with torch.backends.cuda.sdp_kernel(enable_math=False):
            attn_output = F.scaled_dot_product_attention(
                query_states, key_states, value_states, is_causal=is_causal
            )
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # SDPA never materializes attention probabilities. Previously this name was left unbound
        # when output_attentions=True, which raised UnboundLocalError.
        attn_weights = None
        return attn_output, attn_weights, past_key_value


def _fuse_quant_linears(layers, is_v1_model):
    """Build one Autograd4bitQuantLinear whose packed tensors concatenate those of ``layers``.

    The tensors are deleted from each source layer as soon as they are copied, and
    ``torch.cuda.empty_cache()`` is called before the fused layer is allocated.

    Concatenation axes:
        v2 (``is_v1_model=False``): qweight, qzeros and scales along dim 1; g_idx and bias
            along dim 0. The groupsize is taken from ``layers[0]``.
        v1: qweight along dim 1; zeros, scales and bias along dim 0. Groupsize is ``-1``.

    Bias is fused only when ``layers[0].bias`` is not None. It is then assumed to be present
    on every layer.
    """
    first = layers[0]
    if not is_v1_model:
        tensor_dims = (('qweight', 1), ('qzeros', 1), ('scales', 1), ('g_idx', 0))
        groupsize = first.groupsize
    else:
        tensor_dims = (('qweight', 1), ('zeros', 0), ('scales', 0))
        # v1 layers are always fused with groupsize=-1. The attention path always did this;
        # the MLP path used to read gate_proj.groupsize, which v1 layers may not define
        # (assumption — verify against autograd_4bit).
        groupsize = -1

    fused = {}
    for attr, dim in tensor_dims:
        fused[attr] = torch.cat([getattr(layer, attr) for layer in layers], dim=dim)
        # Drop the originals right away to cap peak memory.
        for layer in layers:
            delattr(layer, attr)

    has_bias = first.bias is not None
    fused['bias'] = torch.cat([layer.bias for layer in layers], dim=0) if has_bias else None
    if has_bias:
        for layer in layers:
            delattr(layer, 'bias')
    torch.cuda.empty_cache()

    new_layer = Autograd4bitQuantLinear(
        in_features=first.in_features,
        out_features=sum(layer.out_features for layer in layers),
        groupsize=groupsize,
        is_v1_model=is_v1_model,
    )
    for attr, value in fused.items():
        setattr(new_layer, attr, value)
    return new_layer


def _replace_submodule(model, name, new_module):
    """Install ``new_module`` at the dotted path ``name`` inside ``model``."""
    if '.' in name:
        parent_name, child_name = name.rsplit('.', 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)


def make_quant_attn(model, is_v1_model=False):
    """
    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.

    Also sets ``matmul_utils_4bit.cache_buffer = False``. The packed tensors of the original
    q/k/v projections are deleted as they are fused.
    """
    print('Turning off matmul cache ...')
    matmul_utils_4bit.cache_buffer = False

    # Iterate over a snapshot: the loop body rewires submodules of ``model``.
    for name, m in list(model.named_modules()):
        if not isinstance(m, LlamaAttention):
            continue
        qkv_layer = _fuse_quant_linears([m.q_proj, m.k_proj, m.v_proj], is_v1_model)
        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
        _replace_submodule(model, name, attn)


class QuantLlamaMLP(nn.Module):
    """Llama MLP with ``gate_proj`` and ``up_proj`` fused into a single 4-bit matmul.

    Computes ``down_proj(act_fn(gate) * up)``, where ``gate`` and ``up`` are the first
    ``intermediate_size`` and the remaining columns of ``gate_up_proj(x)``. The packed tensors
    of ``old_module.gate_proj`` / ``up_proj`` are deleted during construction.
    """

    def __init__(self, old_module, is_v1_model=False):
        super().__init__()
        gate_proj = old_module.gate_proj
        up_proj = old_module.up_proj

        self.gate_up_proj = _fuse_quant_linears([gate_proj, up_proj], is_v1_model)
        self.down_proj = old_module.down_proj
        self.act_fn = old_module.act_fn
        self.intermediate_size = gate_proj.out_features

    def forward(self, x):
        intermediate = self.gate_up_proj(x)
        gate, up = torch.split(intermediate, self.intermediate_size, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)


def make_fused_mlp(m, parent_name='', is_v1_model=False):
    """
    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.

    Returns the replacement when ``m`` itself is a LlamaMLP; otherwise rewrites ``m``'s children
    recursively and returns ``m``. ``parent_name`` is only threaded through the recursion.
    """
    if isinstance(m, LlamaMLP):
        return QuantLlamaMLP(m, is_v1_model=is_v1_model)

    for name, child in m.named_children():
        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}", is_v1_model=is_v1_model)
        if isinstance(child, QuantLlamaMLP):
            setattr(m, name, child)

    return m


class CustomLoraLayerMerged(torch.nn.Module):
    """Apply several stacked LoRA adapters to the same input in one einsum.

    Per the einsum ``'bjm,ndm,nkd->nbjk'``: ``x`` is ``(b, j, m)``, ``lora_A`` is ``(n, d, m)``,
    ``lora_B`` is ``(n, k, d)``, and the output is ``(n, b, j, k)`` — one ``x @ A.T @ B.T`` per
    adapter ``n``. Both weights are frozen (``requires_grad=False``).
    """

    def __init__(self, lora_A, lora_B):
        super().__init__()
        self.lora_A = torch.nn.Parameter(lora_A, requires_grad=False)
        self.lora_B = torch.nn.Parameter(lora_B, requires_grad=False)

    def forward(self, x):
        out = torch.einsum('bjm,ndm,nkd->nbjk', x, self.lora_A, self.lora_B)
        return out


class LoraInjectionWrapper:
    """Monkey-patch a fused QKV layer's ``forward`` to add q and v LoRA deltas.

    ``apply()`` saves the current forward as ``module.forward_before_lora`` and installs
    ``forward_with_lora``. That method adds ``lora_out[0]`` to the first ``out_features // 3``
    output columns and ``lora_out[1]`` to the last ``out_features // 3``.
    """

    def __init__(self, module, lora_layer):
        self.module = module
        self.lora_layer = lora_layer

    def apply(self):
        self.module.forward_before_lora = self.module.forward
        self.module.forward = self.forward_with_lora
        self.module.is_lora_injected = True

    def forward_with_lora(self, x):
        result = self.module.forward_before_lora(x)
        lora_out = self.lora_layer(x)
        q, v = lora_out[0], lora_out[1]
        # The fused output is [q | k | v]; only the q and v slices get adapters.
        dim = self.module.out_features // 3
        result[:, :, :dim] += q
        result[:, :, -dim:] += v
        return result


def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
    """Load a PEFT-style q/v LoRA adapter from ``lora_path`` and patch every fused ``qkv_proj``.

    Reads ``adapter_config.json`` (``lora_alpha`` and ``r``) and ``adapter_model.bin`` from
    ``lora_path``. ``lora_A`` is pre-multiplied by ``lora_alpha / r``. Any previous injection on
    a layer is undone first, so calling this again replaces rather than stacks adapters.

    Returns:
        The list of ``LoraInjectionWrapper`` objects that were applied.

    Raises:
        IndexError: if a weight key or module name has no ``model.layers.<i>.`` prefix.
        KeyError: if a layer lacks any of the q/v lora_A/lora_B weights, or a ``qkv_proj``
            layer has no adapter weights.
    """
    print('Device: {}, dtype: {}'.format(device, dtype))

    with open(os.path.join(lora_path, 'adapter_config.json'), 'r') as file:
        lora_config = json.load(file)
    scaling = lora_config['lora_alpha'] / lora_config['r']

    # NOTE: torch.load unpickles the file — only load adapters from trusted sources.
    # Loading to CPU first avoids depending on the device the weights were saved from;
    # every tensor is moved to ``device`` below.
    dic = torch.load(os.path.join(lora_path, 'adapter_model.bin'), map_location='cpu')

    # Group weights by layer prefix: {"model.layers.<i>.": {"self_attn.q_proj.lora_A.weight": t, ...}}
    lora_weight_dic = {}
    for k, v in dic.items():
        k_new = k.replace('base_model.model.', '')
        prefix = _LAYER_PREFIX_RE.findall(k_new)[0]
        k_new = k_new.replace(prefix, '')
        lora_weight_dic.setdefault(prefix, {})[k_new] = v

    lora_layers = {}
    for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
        k1 = 'self_attn.q_proj.lora_A.weight'
        k2 = 'self_attn.q_proj.lora_B.weight'
        k3 = 'self_attn.v_proj.lora_A.weight'
        k4 = 'self_attn.v_proj.lora_B.weight'
        lora_A_q = lora_weight_dic_tmp[k1].to(device=device, dtype=dtype)
        lora_B_q = lora_weight_dic_tmp[k2].to(device=device, dtype=dtype)
        lora_A_v = lora_weight_dic_tmp[k3].to(device=device, dtype=dtype)
        lora_B_v = lora_weight_dic_tmp[k4].to(device=device, dtype=dtype)

        # Stack q (index 0) and v (index 1) so CustomLoraLayerMerged handles both in one einsum.
        loraA_weight = torch.cat([lora_A_q.unsqueeze(0), lora_A_v.unsqueeze(0)], dim=0)
        loraB_weight = torch.cat([lora_B_q.unsqueeze(0), lora_B_v.unsqueeze(0)], dim=0)
        loraA_weight *= scaling

        lora_layer = CustomLoraLayerMerged(loraA_weight, loraB_weight)
        lora_layer = lora_layer.to(device=device, dtype=dtype)
        lora_layers[prefix] = lora_layer

    # Injection
    wrappers = []
    for n, m in model.named_modules():
        if 'qkv_proj' in n and isinstance(m, Autograd4bitQuantLinear):
            # Restore the un-patched forward; otherwise apply() would wrap an already-wrapped
            # forward and stack adapters.
            if getattr(m, 'is_lora_injected', False):
                m.forward = m.forward_before_lora
            prefix = _LAYER_PREFIX_RE.findall(n)[0]
            lora_layer = lora_layers[prefix]
            wrapper = LoraInjectionWrapper(m, lora_layer)
            wrapper.apply()
            wrappers.append(wrapper)

    print('Lora Injected.')
    return wrappers