optimized attention and mlp for performance, add lora monkey patch for models here and GPTQ_For_Llama models using optimization

This commit is contained in:
John Smith 2023-04-22 15:18:54 +08:00
parent 35caccd376
commit de3c91834e
4 changed files with 365 additions and 3 deletions

3
.gitignore vendored
View File

@ -5,3 +5,6 @@ llama-13b-4bit
llama-13b-4bit.pt
text-generation-webui/
repository/
build/
dist/
*.egg-info*

View File

@ -9,9 +9,12 @@ use_new = True
auto_switch = True
auto_switch_thd = 8
debug = False
faster = True
cache_buffer = True
def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
if not cache_buffer:
return torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
if shape_of_qweight not in buffer_mat_dic.keys():
buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
else:
@ -62,8 +65,12 @@ def _matmul4bit_v2(x, qweight, scales, zeros, g_idx):
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
dtype = x.dtype
x = x.half()
quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, g_idx, x.shape[-1] // 2)
if faster:
x = x.half()
quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, g_idx, x.shape[-1] // 2)
else:
x = x.float()
quant_cuda.vecquant4matmul(x, qweight, y, scales, zeros, g_idx)
y = y.to(dtype)
return y.reshape(outshape)

256
model_attn_mlp_patch.py Normal file
View File

@ -0,0 +1,256 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, LlamaMLP
from autograd_4bit import Autograd4bitQuantLinear
import matmul_utils_4bit
import re
import json
import types
class QuantLlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self,hidden_size,num_heads,qkv_proj,o_proj,rotary_emb,):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"f" and `num_heads`: {num_heads}).")
self.qkv_proj = qkv_proj
self.o_proj = o_proj
self.rotary_emb = rotary_emb
def _shape(self, tensor, seq_len, bsz):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(self,hidden_states,past_key_value = None,attention_mask = None,position_ids = None, output_attentions = False,use_cache= False):
"""Input shape: Batch x Time x Channel"""
bsz, q_len, _ = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
is_causal = past_key_value is None
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
with torch.backends.cuda.sdp_kernel(enable_math=False):
attn_output = F.scaled_dot_product_attention(query_states,key_states,value_states,is_causal=is_causal)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def make_quant_attn(model):
"""
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
"""
print('Turning off matmul cache ...')
matmul_utils_4bit.cache_buffer = False
for name, m in model.named_modules():
if not isinstance(m, LlamaAttention):
continue
q_proj = m.q_proj
k_proj = m.k_proj
v_proj = m.v_proj
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
qkv_layer = Autograd4bitQuantLinear(q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.groupsize,
is_v1_model=False)
qkv_layer.qweight = qweights
qkv_layer.qzeros = qzeros
qkv_layer.scales = scales
qkv_layer.g_idx = g_idx
qkv_layer.bias = bias
attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
if '.' in name:
parent_name = name.rsplit('.', 1)[0]
child_name = name[len(parent_name) + 1:]
parent = model.get_submodule(parent_name)
else:
parent_name = ''
parent = model
child_name = name
#print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
setattr(parent, child_name, attn)
class QuantLlamaMLP(nn.Module):
def __init__(self, old_module):
super().__init__()
gate_proj = old_module.gate_proj
up_proj = old_module.up_proj
qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1)
scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1)
g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0)
bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features,
gate_proj.out_features + up_proj.out_features,
gate_proj.groupsize,
is_v1_model=False)
self.gate_up_proj.qweight = qweights
self.gate_up_proj.qzeros = qzeros
self.gate_up_proj.scales = scales
self.gate_up_proj.g_idx = g_idx
self.gate_up_proj.bias = bias
self.down_proj = old_module.down_proj
self.act_fn = old_module.act_fn
self.intermediate_size = gate_proj.out_features
def forward(self, x):
intermediate = self.gate_up_proj(x)
gate, up = torch.split(intermediate, self.intermediate_size, dim=-1)
return self.down_proj(self.act_fn(gate) * up)
def make_fused_mlp(m, parent_name=''):
"""
Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
"""
if isinstance(m, LlamaMLP):
return QuantLlamaMLP(m)
for name, child in m.named_children():
child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
if isinstance(child, QuantLlamaMLP):
setattr(m, name, child)
return m
class CustomLoraLayerMerged(torch.nn.Module):
def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v):
super().__init__()
self.lora_A_q = lora_A_q
self.lora_B_q = lora_B_q
self.lora_A_v = lora_A_v
self.lora_B_v = lora_B_v
self.scaling = scaling
def forward(self, x):
q = self.lora_B_q(self.lora_A_q(x)) * self.scaling
v = self.lora_B_v(self.lora_A_v(x)) * self.scaling
return q, v
def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
print('Device: {}, dtype: {}'.format(device, dtype))
with open(lora_path + '/adapter_config.json', 'r') as file:
lora_config = json.load(file)
scaling = lora_config['lora_alpha'] / lora_config['r']
lora_weight_dic = {}
dic = torch.load(lora_path + '/adapter_model.bin')
for k, v in dic.items():
k_new = k.replace('base_model.model.', '')
prefix = re.findall('^model\.layers\.\d+\.', k_new)[0]
k_new = k_new.replace(prefix, '')
if prefix not in lora_weight_dic.keys():
lora_weight_dic[prefix] = {}
lora_weight_dic[prefix][k_new] = v
lora_layers = {}
for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
k1 = 'self_attn.q_proj.lora_A.weight'
k2 = 'self_attn.q_proj.lora_B.weight'
k3 = 'self_attn.v_proj.lora_A.weight'
k4 = 'self_attn.v_proj.lora_B.weight'
weight = lora_weight_dic_tmp[k1]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False)
weight = lora_weight_dic_tmp[k2]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False)
weight = lora_weight_dic_tmp[k3]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False)
weight = lora_weight_dic_tmp[k4]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False)
lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v)
lora_layer = lora_layer.to(device=device, dtype=dtype)
lora_layers[prefix] = lora_layer
# Injection
for n, m in model.named_modules():
if 'qkv_proj' in n and isinstance(m, Autograd4bitQuantLinear):
# restoring forward
if hasattr(m, 'is_lora_injected') and m.is_lora_injected:
m.forward = m.forward_before_lora
prefix = re.findall('^model\.layers\.\d+\.', n)[0]
lora_layer = lora_layers[prefix]
m.forward_before_lora = m.forward
def forward_with_lora(self, x):
result = self.forward_before_lora(x)
q, v = lora_layer(x)
dim = self.out_features // 3
result[:, :, :dim] += q
result[:, :, -dim:] += v
return result
m.forward = types.MethodType(forward_with_lora, m)
m.is_lora_injected = True
print('Lora Injected.')

View File

@ -0,0 +1,96 @@
import torch
import re
import json
from quant.quant_linear import QuantLinear # from GPTQ FOR LLAMA
import types
class CustomLoraLayerMerged(torch.nn.Module):
def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v):
super().__init__()
self.lora_A_q = lora_A_q
self.lora_B_q = lora_B_q
self.lora_A_v = lora_A_v
self.lora_B_v = lora_B_v
self.scaling = scaling
def forward(self, x):
q = self.lora_B_q(self.lora_A_q(x)) * self.scaling
v = self.lora_B_v(self.lora_A_v(x)) * self.scaling
return q, v
def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
print('Device: {}, dtype: {}'.format(device, dtype))
with open(lora_path + '/adapter_config.json', 'r') as file:
lora_config = json.load(file)
scaling = lora_config['lora_alpha'] / lora_config['r']
lora_weight_dic = {}
dic = torch.load(lora_path + '/adapter_model.bin')
for k, v in dic.items():
k_new = k.replace('base_model.model.', '')
prefix = re.findall('^model\.layers\.\d+\.', k_new)[0]
k_new = k_new.replace(prefix, '')
if prefix not in lora_weight_dic.keys():
lora_weight_dic[prefix] = {}
lora_weight_dic[prefix][k_new] = v
lora_layers = {}
for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
k1 = 'self_attn.q_proj.lora_A.weight'
k2 = 'self_attn.q_proj.lora_B.weight'
k3 = 'self_attn.v_proj.lora_A.weight'
k4 = 'self_attn.v_proj.lora_B.weight'
weight = lora_weight_dic_tmp[k1]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False)
weight = lora_weight_dic_tmp[k2]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False)
weight = lora_weight_dic_tmp[k3]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False)
weight = lora_weight_dic_tmp[k4]
l_dim = weight.shape[0]
r_dim = weight.shape[1]
lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False)
lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v)
lora_layer = lora_layer.to(device=device, dtype=dtype)
lora_layers[prefix] = lora_layer
# Injection
for n, m in model.named_modules():
if 'qkv_proj' in n and isinstance(m, QuantLinear):
# restoring forward
if hasattr(m, 'is_lora_injected') and m.is_lora_injected:
m.forward = m.forward_before_lora
prefix = re.findall('^model\.layers\.\d+\.', n)[0]
lora_layer = lora_layers[prefix]
m.forward_before_lora = m.forward
def forward_with_lora(self, x):
result = self.forward_before_lora(x)
q, v = lora_layer(x)
dim = self.outfeatures // 3
result[:, :, :dim] += q
result[:, :, -dim:] += v
return result
m.forward = types.MethodType(forward_with_lora, m)
m.is_lora_injected = True
print('Lora Injected.')