diff --git a/.gitignore b/.gitignore
index e531c1e..49faac3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,6 @@ llama-13b-4bit
 llama-13b-4bit.pt
 text-generation-webui/
 repository/
+build/
+dist/
+*.egg-info*
diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py
index 2aaa0ad..16575c6 100644
--- a/matmul_utils_4bit.py
+++ b/matmul_utils_4bit.py
@@ -9,9 +9,12 @@ use_new = True
 auto_switch = True
 auto_switch_thd = 8
 debug = False
-
+faster = True
+cache_buffer = True
 
 def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
+    if not cache_buffer:
+        return torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
     if shape_of_qweight not in buffer_mat_dic.keys():
         buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
     else:
@@ -62,8 +65,12 @@ def _matmul4bit_v2(x, qweight, scales, zeros, g_idx):
     x = x.reshape(-1, x.shape[-1])
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
-    x = x.half()
-    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, g_idx, x.shape[-1] // 2)
+    if faster:
+        x = x.half()
+        quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, g_idx, x.shape[-1] // 2)
+    else:
+        x = x.float()
+        quant_cuda.vecquant4matmul(x, qweight, y, scales, zeros, g_idx)
     y = y.to(dtype)
     return y.reshape(outshape)
 
diff --git a/model_attn_mlp_patch.py b/model_attn_mlp_patch.py
new file mode 100644
index 0000000..63e80cb
--- /dev/null
+++ b/model_attn_mlp_patch.py
@@ -0,0 +1,256 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.cuda.amp import custom_bwd, custom_fwd
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, LlamaMLP
+from autograd_4bit import Autograd4bitQuantLinear
+import matmul_utils_4bit
+import re
+import json
+import types
+
+
+class QuantLlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, hidden_size, num_heads, qkv_proj, o_proj, rotary_emb):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+
+        if (self.head_dim * num_heads) != self.hidden_size:
+            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                             f" and `num_heads`: {num_heads}).")
+        self.qkv_proj = qkv_proj
+        self.o_proj = o_proj
+        self.rotary_emb = rotary_emb
+
+    def _shape(self, tensor, seq_len, bsz):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None,
+                output_attentions=False, use_cache=False):
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.qkv_proj(hidden_states)
+        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        # [bsz, nh, t, hd]
+
+        is_causal = past_key_value is None
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        with torch.backends.cuda.sdp_kernel(enable_math=False):
+            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        # F.scaled_dot_product_attention never materializes attention weights
+        attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+def make_quant_attn(model):
+    """
+    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
+    """
+    print('Turning off matmul cache ...')
+    matmul_utils_4bit.cache_buffer = False
+    for name, m in model.named_modules():
+        if not isinstance(m, LlamaAttention):
+            continue
+
+        q_proj = m.q_proj
+        k_proj = m.k_proj
+        v_proj = m.v_proj
+
+        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+        g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
+        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+        qkv_layer = Autograd4bitQuantLinear(q_proj.in_features,
+                                            q_proj.out_features + k_proj.out_features + v_proj.out_features,
+                                            q_proj.groupsize,
+                                            is_v1_model=False)
+        qkv_layer.qweight = qweights
+        qkv_layer.qzeros = qzeros
+        qkv_layer.scales = scales
+        qkv_layer.g_idx = g_idx
+        qkv_layer.bias = bias
+
+        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
+
+        if '.' in name:
+            parent_name = name.rsplit('.', 1)[0]
+            child_name = name[len(parent_name) + 1:]
+            parent = model.get_submodule(parent_name)
+        else:
+            parent_name = ''
+            parent = model
+            child_name = name
+
+        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
+
+        setattr(parent, child_name, attn)
+
+
+class QuantLlamaMLP(nn.Module):
+    def __init__(self, old_module):
+        super().__init__()
+
+        gate_proj = old_module.gate_proj
+        up_proj = old_module.up_proj
+
+        qweights = torch.cat([gate_proj.qweight, up_proj.qweight], dim=1)
+        qzeros = torch.cat([gate_proj.qzeros, up_proj.qzeros], dim=1)
+        scales = torch.cat([gate_proj.scales, up_proj.scales], dim=1)
+        g_idx = torch.cat([gate_proj.g_idx, up_proj.g_idx], dim=0)
+        bias = torch.cat([gate_proj.bias, up_proj.bias], dim=0) if gate_proj.bias is not None else None
+
+        self.gate_up_proj = Autograd4bitQuantLinear(gate_proj.in_features,
+                                                    gate_proj.out_features + up_proj.out_features,
+                                                    gate_proj.groupsize,
+                                                    is_v1_model=False)
+        self.gate_up_proj.qweight = qweights
+        self.gate_up_proj.qzeros = qzeros
+        self.gate_up_proj.scales = scales
+        self.gate_up_proj.g_idx = g_idx
+        self.gate_up_proj.bias = bias
+
+        self.down_proj = old_module.down_proj
+        self.act_fn = old_module.act_fn
+        self.intermediate_size = gate_proj.out_features
+
+    def forward(self, x):
+        intermediate = self.gate_up_proj(x)
+        gate, up = torch.split(intermediate, self.intermediate_size, dim=-1)
+        return self.down_proj(self.act_fn(gate) * up)
+
+
+def make_fused_mlp(m, parent_name=''):
+    """
+    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
+    """
+    if isinstance(m, LlamaMLP):
+        return QuantLlamaMLP(m)
+
+    for name, child in m.named_children():
+        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
+
+        if isinstance(child, QuantLlamaMLP):
+            setattr(m, name, child)
+    return m
+
+
+class CustomLoraLayerMerged(torch.nn.Module):
+
+    def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v):
+        super().__init__()
+        self.lora_A_q = lora_A_q
+        self.lora_B_q = lora_B_q
+        self.lora_A_v = lora_A_v
+        self.lora_B_v = lora_B_v
+        self.scaling = scaling
+
+    def forward(self, x):
+        q = self.lora_B_q(self.lora_A_q(x)) * self.scaling
+        v = self.lora_B_v(self.lora_A_v(x)) * self.scaling
+        return q, v
+
+
+def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
+
+    print('Device: {}, dtype: {}'.format(device, dtype))
+
+    with open(lora_path + '/adapter_config.json', 'r') as file:
+        lora_config = json.load(file)
+    scaling = lora_config['lora_alpha'] / lora_config['r']
+
+    lora_weight_dic = {}
+    dic = torch.load(lora_path + '/adapter_model.bin')
+    for k, v in dic.items():
+        k_new = k.replace('base_model.model.', '')
+        prefix = re.findall(r'^model\.layers\.\d+\.', k_new)[0]
+        k_new = k_new.replace(prefix, '')
+        if prefix not in lora_weight_dic.keys():
+            lora_weight_dic[prefix] = {}
+        lora_weight_dic[prefix][k_new] = v
+
+    lora_layers = {}
+    for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
+        k1 = 'self_attn.q_proj.lora_A.weight'
+        k2 = 'self_attn.q_proj.lora_B.weight'
+        k3 = 'self_attn.v_proj.lora_A.weight'
+        k4 = 'self_attn.v_proj.lora_B.weight'
+
+        weight = lora_weight_dic_tmp[k1]
+        l_dim = weight.shape[0]
+        r_dim = weight.shape[1]
+        lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
+        lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        weight = lora_weight_dic_tmp[k2]
+        l_dim = weight.shape[0]
+        r_dim = weight.shape[1]
+        lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
+        lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        weight = lora_weight_dic_tmp[k3]
+        l_dim = weight.shape[0]
+        r_dim = weight.shape[1]
+        lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
+        lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        weight = lora_weight_dic_tmp[k4]
+        l_dim = weight.shape[0]
+        r_dim = weight.shape[1]
+        lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
+        lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v)
+        lora_layer = lora_layer.to(device=device, dtype=dtype)
+        lora_layers[prefix] = lora_layer
+
+    # Injection
+    for n, m in model.named_modules():
+        if 'qkv_proj' in n and isinstance(m, Autograd4bitQuantLinear):
+            # restoring forward
+            if hasattr(m, 'is_lora_injected') and m.is_lora_injected:
+                m.forward = m.forward_before_lora
+            prefix = re.findall(r'^model\.layers\.\d+\.', n)[0]
+            lora_layer = lora_layers[prefix]
+            m.forward_before_lora = m.forward
+            # bind this layer's lora_layer via a default argument so each patched
+            # forward keeps its own adapter (a plain closure would late-bind to the
+            # last value assigned in the loop)
+            def forward_with_lora(self, x, lora_layer=lora_layer):
+                result = self.forward_before_lora(x)
+                q, v = lora_layer(x)
+                dim = self.out_features // 3
+                result[:, :, :dim] += q
+                result[:, :, -dim:] += v
+                return result
+            m.forward = types.MethodType(forward_with_lora, m)
+            m.is_lora_injected = True
+
+    print('Lora Injected.')
+
\ No newline at end of file
diff --git a/monkeypatch/gptq_for_llala_lora_monkey_patch.py b/monkeypatch/gptq_for_llala_lora_monkey_patch.py
new file mode 100644
index 0000000..29802c4
--- /dev/null
+++ b/monkeypatch/gptq_for_llala_lora_monkey_patch.py
@@ -0,0 +1,96 @@
+import torch
+import re
+import json
+from quant.quant_linear import QuantLinear  # from GPTQ-for-LLaMa
+import types
+
+
+class CustomLoraLayerMerged(torch.nn.Module):
+
+    def __init__(self, scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v):
+        super().__init__()
+        self.lora_A_q = lora_A_q
+        self.lora_B_q = lora_B_q
+        self.lora_A_v = lora_A_v
+        self.lora_B_v = lora_B_v
+        self.scaling = scaling
+
+    def forward(self, x):
+        q = self.lora_B_q(self.lora_A_q(x)) * self.scaling
+        v = self.lora_B_v(self.lora_A_v(x)) * self.scaling
+        return q, v
+
+
+def inject_lora_layers(model, lora_path, device='cuda', dtype=torch.float16):
+
+    print('Device: {}, dtype: {}'.format(device, dtype))
+
+    with open(lora_path + '/adapter_config.json', 'r') as file:
+        lora_config = json.load(file)
+    scaling = lora_config['lora_alpha'] / lora_config['r']
+
+    lora_weight_dic = {}
+    dic = torch.load(lora_path + '/adapter_model.bin')
+    for k, v in dic.items():
+        k_new = k.replace('base_model.model.', '')
+        prefix = re.findall(r'^model\.layers\.\d+\.', k_new)[0]
+        k_new = k_new.replace(prefix, '')
+        if prefix not in lora_weight_dic.keys():
+            lora_weight_dic[prefix] = {}
+        lora_weight_dic[prefix][k_new] = v
+
+    lora_layers = {}
+    for prefix, lora_weight_dic_tmp in lora_weight_dic.items():
+        k1 = 'self_attn.q_proj.lora_A.weight'
+        k2 = 'self_attn.q_proj.lora_B.weight'
+        k3 = 'self_attn.v_proj.lora_A.weight'
+        k4 = 'self_attn.v_proj.lora_B.weight'
+
+        weight = lora_weight_dic_tmp[k1]
+        l_dim = weight.shape[0]
+        r_dim = weight.shape[1]
+        lora_A_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
+        lora_A_q.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        weight = lora_weight_dic_tmp[k2]
+        l_dim = weight.shape[0]
+        r_dim = weight.shape[1]
+        lora_B_q = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
+        lora_B_q.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        weight = lora_weight_dic_tmp[k3]
+        l_dim = weight.shape[0]
+        r_dim = weight.shape[1]
+        lora_A_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
+        lora_A_v.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        weight = lora_weight_dic_tmp[k4]
+        l_dim = weight.shape[0]
+        r_dim = weight.shape[1]
+        lora_B_v = torch.nn.Linear(in_features=r_dim, out_features=l_dim, bias=False)
+        lora_B_v.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        lora_layer = CustomLoraLayerMerged(scaling, lora_A_q, lora_B_q, lora_A_v, lora_B_v)
+        lora_layer = lora_layer.to(device=device, dtype=dtype)
+        lora_layers[prefix] = lora_layer
+
+    # Injection
+    for n, m in model.named_modules():
+        if 'qkv_proj' in n and isinstance(m, QuantLinear):
+            # restoring forward
+            if hasattr(m, 'is_lora_injected') and m.is_lora_injected:
+                m.forward = m.forward_before_lora
+            prefix = re.findall(r'^model\.layers\.\d+\.', n)[0]
+            lora_layer = lora_layers[prefix]
+            m.forward_before_lora = m.forward
+            # bind this layer's lora_layer via a default argument so each patched
+            # forward keeps its own adapter (a plain closure would late-bind to the
+            # last value assigned in the loop)
+            def forward_with_lora(self, x, lora_layer=lora_layer):
+                result = self.forward_before_lora(x)
+                q, v = lora_layer(x)
+                dim = self.outfeatures // 3
+                result[:, :, :dim] += q
+                result[:, :, -dim:] += v
+                return result
+            m.forward = types.MethodType(forward_with_lora, m)
+            m.is_lora_injected = True
+
+    print('Lora Injected.')
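
Usage sketch (not part of the diff): the new helpers in model_attn_mlp_patch.py are expected to be applied to an already-loaded 4-bit LLaMA model, roughly in the order below. The LoRA path and the pre-loaded `model` are illustrative assumptions, not something this patch provides; `model` is assumed to be a transformers LlamaForCausalLM whose projection layers were already replaced with Autograd4bitQuantLinear (v2 / act-order) modules by this repo's loader.

    import torch
    from model_attn_mlp_patch import make_quant_attn, make_fused_mlp, inject_lora_layers

    # fuse q_proj/k_proj/v_proj into a single qkv_proj per attention block
    # (also disables the matmul buffer cache, see matmul_utils_4bit.cache_buffer)
    make_quant_attn(model)

    # fuse gate_proj/up_proj into a single gate_up_proj per MLP block
    model = make_fused_mlp(model)

    # optional: merge a trained LoRA adapter (adapter_config.json + adapter_model.bin)
    # into the fused qkv projections; must run after make_quant_attn
    inject_lora_layers(model, 'path/to/lora_dir', device='cuda', dtype=torch.float16)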