import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, LlamaMLP
from autograd_4bit import Autograd4bitQuantLinear
import matmul_utils_4bit
import re
import json
import types


class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper,
    with q, k and v computed by a single fused 4-bit projection."""

    def __init__(self, hidden_size, num_heads, qkv_proj, o_proj, rotary_emb):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_emb = rotary_emb

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None,
                output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""
        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        is_causal = past_key_value is None

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # Note: attention_mask is not applied here; the fused path relies on SDPA's
        # built-in causal masking instead of an explicit mask.
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # F.scaled_dot_product_attention does not expose attention weights, so
        # None is returned even when output_attentions is requested.
        attn_weights = None

        return attn_output, attn_weights, past_key_value
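

# Illustrative sketch only (not used anywhere in this module): the dense analogue
# of the QKV fusion that QuantLlamaAttention.forward assumes. With weights stored
# as (out_features, in_features), concatenating the q/k/v matrices along the output
# dimension gives one matmul whose result splits back into q, k, v in that order.
# The quantized layers below concatenate their packed qweight/qzeros/scales instead,
# but the split semantics of the output are the same. The function name and the
# toy sizes are arbitrary.
def _qkv_fusion_sketch(hidden_size=8, seq_len=4):
    w_q = torch.randn(hidden_size, hidden_size)
    w_k = torch.randn(hidden_size, hidden_size)
    w_v = torch.randn(hidden_size, hidden_size)
    w_qkv = torch.cat([w_q, w_k, w_v], dim=0)  # fused weight, shape (3 * hidden, hidden)

    x = torch.randn(1, seq_len, hidden_size)
    q, k, v = torch.split(F.linear(x, w_qkv), hidden_size, dim=2)

    # The fused projection reproduces the three separate projections exactly.
    assert torch.allclose(q, F.linear(x, w_q), atol=1e-6)
    assert torch.allclose(k, F.linear(x, w_k), atol=1e-6)
    assert torch.allclose(v, F.linear(x, w_v), atol=1e-6)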
""" print('Turning off matmul cache ...') matmul_utils_4bit.cache_buffer = False for name, m in model.named_modules(): if not isinstance(m, LlamaAttention): continue q_proj = m.q_proj k_proj = m.k_proj v_proj = m.v_proj qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None qkv_layer = Autograd4bitQuantLinear(q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features, q_proj.groupsize, is_v1_model=False) qkv_layer.qweight = qweights qkv_layer.qzeros = qzeros qkv_layer.scales = scales qkv_layer.g_idx = g_idx qkv_layer.bias = bias attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb) if '.' in name: parent_name = name.rsplit('.', 1)[0] child_name = name[len(parent_name) + 1:] parent = model.get_submodule(parent_name) else: parent_name = '' parent = model child_name = name #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") setattr(parent, child_name, attn) class QuantLlamaMLP(nn.Module): def __init__(self, old_module): super().__init__() gate_proj = old_module.gate_proj up_proj = old_module.up_proj qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1) qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1) scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1) g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0) bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features, gate_proj.out_features + up_proj.out_features, gate_proj.groupsize, is_v1_model=False) self.gate_up_proj.qweight = qweights self.gate_up_proj.qzeros = qzeros self.gate_up_proj.scales = scales self.gate_up_proj.g_idx = g_idx self.gate_up_proj.bias = bias self.down_proj = old_module.down_proj self.act_fn = old_module.act_fn self.intermediate_size = gate_proj.out_features def forward(self, x): intermediate = self.gate_up_proj(x) gate, up = torch.split(intermediate, self.intermediate_size, dim=-1) return self.down_proj(self.act_fn(gate) * up) def make_fused_mlp(m, parent_name=''): """ Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. 
""" if isinstance(m, LlamaMLP): return QuantLlamaMLP(m) for name, child in m.named_children(): child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}") if isinstance(child, QuantLlamaMLP): setattr(m, name, child) return m class CustomLoraLayerMerged(torch.nn.Module): def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v): super().__init__() self.lora_A_q = lora_A_q self.lora_B_q = lora_B_q self.lora_A_v = lora_A_v self.lora_B_v = lora_B_v self.scaling = scaling def forward(self, x): q = self.lora_B_q(self.lora_A_q(x)) * self.scaling v = self.lora_B_v(self.lora_A_v(x)) * self.scaling return q, v def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16): print('Device: {}, dtype: {}'.format(device, dtype)) with open(lora_path + '/adapter_config.json', 'r') as file: lora_config = json.load(file) scaling = lora_config['lora_alpha'] / lora_config['r'] lora_weight_dic = {} dic = torch.load(lora_path + '/adapter_model.bin') for k, v in dic.items(): k_new = k.replace('base_model.model.', '') prefix = re.findall('^model\.layers\.\d+\.', k_new)[0] k_new = k_new.replace(prefix, '') if prefix not in lora_weight_dic.keys(): lora_weight_dic[prefix] = {} lora_weight_dic[prefix][k_new] = v lora_layers = {} for prefix, lora_weight_dic_tmp in lora_weight_dic.items(): k1 = 'self_attn.q_proj.lora_A.weight' k2 = 'self_attn.q_proj.lora_B.weight' k3 = 'self_attn.v_proj.lora_A.weight' k4 = 'self_attn.v_proj.lora_B.weight' weight = lora_weight_dic_tmp[k1] l_dim = weight.shape[0] r_dim = weight.shape[1] lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False) lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False) weight = lora_weight_dic_tmp[k2] l_dim = weight.shape[0] r_dim = weight.shape[1] lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False) lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False) weight = lora_weight_dic_tmp[k3] l_dim = weight.shape[0] r_dim = weight.shape[1] lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False) lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False) weight = lora_weight_dic_tmp[k4] l_dim = weight.shape[0] r_dim = weight.shape[1] lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False) lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False) lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v) lora_layer = lora_layer.to(device=device, dtype=dtype) lora_layers[prefix] = lora_layer # Injection for n, m in model.named_modules(): if 'qkv_proj' in n and isinstance(m, Autograd4bitQuantLinear): # restoring forward if hasattr(m, 'is_lora_injected') and m.is_lora_injected: m.forward = m.forward_before_lora prefix = re.findall('^model\.layers\.\d+\.', n)[0] lora_layer = lora_layers[prefix] m.forward_before_lora = m.forward def forward_with_lora(self, x): result = self.forward_before_lora(x) q, v = lora_layer(x) dim = self.out_features // 3 result[:, :, :dim] += q result[:, :, -dim:] += v return result m.forward = types.MethodType(forward_with_lora, m) m.is_lora_injected = True print('Lora Injected.')