import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, LlamaMLP
from autograd_4bit import Autograd4bitQuantLinear
import matmul_utils_4bit
import re
import json
import types

class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, hidden_size, num_heads, qkv_proj, o_proj, rotary_emb):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_emb = rotary_emb

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None,
                output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""
        bsz, q_len, _ = hidden_states.size()

        # Single fused projection, then split back into query / key / value.
        qkv_states = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        is_causal = past_key_value is None
        if past_key_value is not None:
            # reuse cached k, v for self-attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        with torch.backends.cuda.sdp_kernel(enable_math=False):
            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
                                                         is_causal=is_causal)

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # scaled_dot_product_attention does not expose attention weights, so None is returned
        # even when output_attentions is requested.
        attn_weights = None

        return attn_output, attn_weights, past_key_value
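
    # Illustrative sketch (not part of the original module): a toy, CPU-only check of the
    # reshaping done in QuantLlamaAttention.forward. A fused QKV projection emits a
    # [bsz, q_len, 3 * hidden_size] tensor; splitting it at hidden_size and viewing each chunk
    # as [bsz, num_heads, q_len, head_dim] recovers the per-head layout expected by
    # scaled_dot_product_attention. Sizes below are arbitrary toy values.
    @staticmethod
    def _example_fused_qkv_split(bsz=2, q_len=5, hidden_size=32, num_heads=4):
        head_dim = hidden_size // num_heads
        qkv = torch.randn(bsz, q_len, 3 * hidden_size)      # fused projection output
        q, k, v = torch.split(qkv, hidden_size, dim=2)       # split back into q / k / v
        q = q.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        assert q.shape == (bsz, num_heads, q_len, head_dim)
        return q.shape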


def make_quant_attn(model, is_v1_model=False):
    """
    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
    """
    print('Turning off matmul cache ...')
    matmul_utils_4bit.cache_buffer = False
    for name, m in model.named_modules():
        if not isinstance(m, LlamaAttention):
            continue

        q_proj = m.q_proj
        k_proj = m.k_proj
        v_proj = m.v_proj

        if not is_v1_model:
            # v2 (groupwise) format: concatenate qweight/qzeros/scales along the output
            # dimension and g_idx along the input dimension, freeing the originals as we go.
            qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
            del q_proj.qweight
            del k_proj.qweight
            del v_proj.qweight
            qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
            del q_proj.qzeros
            del k_proj.qzeros
            del v_proj.qzeros
            scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
            del q_proj.scales
            del k_proj.scales
            del v_proj.scales
            g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
            del q_proj.g_idx
            del k_proj.g_idx
            del v_proj.g_idx
            bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
            if q_proj.bias is not None:
                del q_proj.bias
                del k_proj.bias
                del v_proj.bias
            torch.cuda.empty_cache()

            qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
                                                out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
                                                groupsize=q_proj.groupsize,
                                                is_v1_model=False)
            qkv_layer.qweight = qweights
            qkv_layer.qzeros = qzeros
            qkv_layer.scales = scales
            qkv_layer.g_idx = g_idx
            qkv_layer.bias = bias
        else:
            # v1 format: zeros and scales are concatenated along dim 0, qweight along dim 1.
            qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
            del q_proj.qweight
            del k_proj.qweight
            del v_proj.qweight
            zeros = torch.cat([q_proj.zeros, k_proj.zeros, v_proj.zeros], dim=0)
            del q_proj.zeros
            del k_proj.zeros
            del v_proj.zeros
            scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
            del q_proj.scales
            del k_proj.scales
            del v_proj.scales
            bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
            if q_proj.bias is not None:
                del q_proj.bias
                del k_proj.bias
                del v_proj.bias
            torch.cuda.empty_cache()

            qkv_layer = Autograd4bitQuantLinear(in_features=q_proj.in_features,
                                                out_features=q_proj.out_features + k_proj.out_features + v_proj.out_features,
                                                groupsize=-1,
                                                is_v1_model=True)
            qkv_layer.qweight = qweights
            qkv_layer.zeros = zeros
            qkv_layer.scales = scales
            qkv_layer.bias = bias

        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)

        if '.' in name:
            parent_name = name.rsplit('.', 1)[0]
            child_name = name[len(parent_name) + 1:]
            parent = model.get_submodule(parent_name)
        else:
            parent_name = ''
            parent = model
            child_name = name

        # print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")

        setattr(parent, child_name, attn)
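

# Usage sketch (assumption, not from the original file): make_quant_attn mutates the model in
# place, so it can be applied to any already-loaded 4-bit LLaMA model whose attention projections
# are Autograd4bitQuantLinear layers. The `model` argument is assumed to come from whatever
# loader the surrounding project uses.
def _example_fuse_attention(model, is_v1_model=False):
    make_quant_attn(model, is_v1_model=is_v1_model)   # replaces every LlamaAttention in place
    n_fused = sum(1 for _, m in model.named_modules() if isinstance(m, QuantLlamaAttention))
    print(f'{n_fused} attention blocks fused.')
    return model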


class QuantLlamaMLP(nn.Module):

    def __init__(self, old_module, is_v1_model=False):
        super().__init__()

        gate_proj = old_module.gate_proj
        up_proj = old_module.up_proj

        if not is_v1_model:
            # v2 (groupwise) format: fuse gate_proj and up_proj into one quantized linear.
            qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
            del gate_proj.qweight
            del up_proj.qweight
            qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1)
            del gate_proj.qzeros
            del up_proj.qzeros
            scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1)
            del gate_proj.scales
            del up_proj.scales
            g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0)
            del gate_proj.g_idx
            del up_proj.g_idx
            bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
            if gate_proj.bias is not None:
                del gate_proj.bias
                del up_proj.bias
            torch.cuda.empty_cache()

            self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
                                                        out_features=gate_proj.out_features + up_proj.out_features,
                                                        groupsize=gate_proj.groupsize,
                                                        is_v1_model=False)
            self.gate_up_proj.qweight = qweights
            self.gate_up_proj.qzeros = qzeros
            self.gate_up_proj.scales = scales
            self.gate_up_proj.g_idx = g_idx
            self.gate_up_proj.bias = bias
        else:
            # v1 format: zeros and scales are concatenated along dim 0, qweight along dim 1.
            qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
            del gate_proj.qweight
            del up_proj.qweight
            zeros = torch.cat([gate_proj.zeros, up_proj.zeros], dim=0)
            del gate_proj.zeros
            del up_proj.zeros
            scales = torch.cat([gate_proj.scales, up_proj.scales], dim=0)
            del gate_proj.scales
            del up_proj.scales
            bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
            if gate_proj.bias is not None:
                del gate_proj.bias
                del up_proj.bias
            torch.cuda.empty_cache()

            self.gate_up_proj = Autograd4bitQuantLinear(in_features=gate_proj.in_features,
                                                        out_features=gate_proj.out_features + up_proj.out_features,
                                                        groupsize=gate_proj.groupsize,
                                                        is_v1_model=True)
            self.gate_up_proj.qweight = qweights
            self.gate_up_proj.zeros = zeros
            self.gate_up_proj.scales = scales
            self.gate_up_proj.bias = bias

        self.down_proj = old_module.down_proj
        self.act_fn = old_module.act_fn
        self.intermediate_size = gate_proj.out_features

    def forward(self, x):
        intermediate = self.gate_up_proj(x)
        gate, up = torch.split(intermediate, self.intermediate_size, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)


def make_fused_mlp(m, parent_name='', is_v1_model=False):
    """
    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
    """
    if isinstance(m, LlamaMLP):
        return QuantLlamaMLP(m, is_v1_model=is_v1_model)

    for name, child in m.named_children():
        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}", is_v1_model=is_v1_model)

        if isinstance(child, QuantLlamaMLP):
            setattr(m, name, child)
    return m
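

# Usage sketch (assumption, not from the original file): unlike make_quant_attn, make_fused_mlp
# returns the (possibly replaced) module, so the return value should be reassigned. Applying it
# to the whole model fuses every LlamaMLP's gate_proj/up_proj pair into one quantized projection.
def _example_fuse_mlp(model, is_v1_model=False):
    model = make_fused_mlp(model, is_v1_model=is_v1_model)
    n_fused = sum(1 for _, m in model.named_modules() if isinstance(m, QuantLlamaMLP))
    print(f'{n_fused} MLP blocks fused.')
    return model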


class CustomLoraLayerMerged(torch.nn.Module):

    def __init__(self, lora_A, lora_B):
        super().__init__()
        self.lora_A = torch.nn.Parameter(lora_A, requires_grad=False)
        self.lora_B = torch.nn.Parameter(lora_B, requires_grad=False)

    def forward(self, x):
        # x: [bsz, seq, in_features]; lora_A: [n, r, in_features]; lora_B: [n, out_features, r].
        # The einsum applies both stacked adapters (n = 2, for q and v) in one shot,
        # producing [n, bsz, seq, out_features].
        out = torch.einsum('bjm,ndm,nkd->nbjk', x, self.lora_A, self.lora_B)
        return out
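

# Illustrative sketch (not part of the original file): a toy check of the merged-LoRA einsum
# with arbitrary sizes. Two rank-r adapters (for q_proj and v_proj) are stacked along a leading
# dimension and applied together; the result has one slice per adapter.
def _example_merged_lora_shapes(bsz=2, seq=5, in_features=16, out_features=16, r=4):
    x = torch.randn(bsz, seq, in_features)
    lora_A = torch.randn(2, r, in_features)      # stacked A matrices for q and v
    lora_B = torch.randn(2, out_features, r)     # stacked B matrices for q and v
    layer = CustomLoraLayerMerged(lora_A, lora_B)
    out = layer(x)                               # [2, bsz, seq, out_features]
    assert out.shape == (2, bsz, seq, out_features)
    return out.shape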


class LoraInjectionWrapper:

    def __init__(self, module, lora_layer):
        self.module = module
        self.lora_layer = lora_layer

    def apply(self):
        # Keep the original forward around so the injection can be undone or re-applied.
        self.module.forward_before_lora = self.module.forward
        self.module.forward = self.forward_with_lora
        self.module.is_lora_injected = True

    def forward_with_lora(self, x):
        result = self.module.forward_before_lora(x)
        lora_out = self.lora_layer(x)
        q, v = lora_out[0], lora_out[1]
        # The fused qkv_proj output is laid out as [q | k | v], each of width out_features // 3:
        # add the q delta to the first third and the v delta to the last third.
        dim = self.module.out_features // 3
        result[:, :, :dim] += q
        result[:, :, -dim:] += v
        return result


def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):

    print('Device: {}, dtype: {}'.format(device, dtype))

    with open(lora_path + '/adapter_config.json', 'r') as file:
        lora_config = json.load(file)
    scaling = lora_config['lora_alpha'] / lora_config['r']

    # Group the LoRA weights by layer prefix (e.g. 'model.layers.0.').
    lora_weight_dic = {}
    dic = torch.load(lora_path + '/adapter_model.bin')
    for k, v in dic.items():
        k_new = k.replace('base_model.model.', '')
        prefix = re.findall(r'^model\.layers\.\d+\.', k_new)[0]
        k_new = k_new.replace(prefix, '')
        if prefix not in lora_weight_dic:
            lora_weight_dic[prefix] = {}
        lora_weight_dic[prefix][k_new] = v

    # Build one merged LoRA layer per transformer layer, stacking the q and v adapters.
    lora_layers = {}
    for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
        k1 = 'self_attn.q_proj.lora_A.weight'
        k2 = 'self_attn.q_proj.lora_B.weight'
        k3 = 'self_attn.v_proj.lora_A.weight'
        k4 = 'self_attn.v_proj.lora_B.weight'

        lora_A_q = lora_weight_dic_tmp[k1].to(device=device, dtype=dtype)
        lora_B_q = lora_weight_dic_tmp[k2].to(device=device, dtype=dtype)
        lora_A_v = lora_weight_dic_tmp[k3].to(device=device, dtype=dtype)
        lora_B_v = lora_weight_dic_tmp[k4].to(device=device, dtype=dtype)

        loraA_weight = torch.concat([lora_A_q.unsqueeze(0), lora_A_v.unsqueeze(0)], dim=0)
        loraB_weight = torch.concat([lora_B_q.unsqueeze(0), lora_B_v.unsqueeze(0)], dim=0)
        loraA_weight *= scaling

        lora_layer = CustomLoraLayerMerged(loraA_weight, loraB_weight)
        lora_layer = lora_layer.to(device=device, dtype=dtype)
        lora_layers[prefix] = lora_layer

    # Injection
    wrappers = []
    for n, m in model.named_modules():
        if 'qkv_proj' in n and isinstance(m, Autograd4bitQuantLinear):
            # restore the original forward if a previous injection is still active
            if hasattr(m, 'is_lora_injected') and m.is_lora_injected:
                m.forward = m.forward_before_lora
            prefix = re.findall(r'^model\.layers\.\d+\.', n)[0]
            lora_layer = lora_layers[prefix]
            wrapper = LoraInjectionWrapper(m, lora_layer)
            wrapper.apply()
            wrappers.append(wrapper)

    print('Lora Injected.')
    return wrappers
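

# End-to-end usage sketch (assumption, not from the original file): fuse the attention and MLP
# blocks of an already-loaded 4-bit LLaMA model, then inject a merged q/v LoRA adapter from a
# PEFT-style directory containing adapter_config.json and adapter_model.bin. The loader that
# produces `model` is whatever the surrounding project uses; it is not defined here.
def _example_pipeline(model, lora_path, is_v1_model=False, device='cuda', dtype=torch.float16):
    make_quant_attn(model, is_v1_model=is_v1_model)
    model = make_fused_mlp(model, is_v1_model=is_v1_model)
    wrappers = inject_lora_layers(model, lora_path, device=device, dtype=dtype)
    return model, wrappers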