From bff039de95a0b3f038515c253e04eaaf56a5d5c3 Mon Sep 17 00:00:00 2001
From: John Smith
Date: Tue, 28 Mar 2023 20:33:55 +0800
Subject: [PATCH] add v2 model support

---
 Finetune4bConfig.py                          |   4 +-
 arg_parser.py                                |   6 +-
 autograd_4bit.py                             | 200 +++--------------
 finetune.py                                  |   5 +-
 inference.py                                 |   8 +-
 matmul_utils_4bit.py                         | 139 +++++++++++++
 text-generation-webui/custom_monkey_patch.py |   5 +-
 7 files changed, 213 insertions(+), 154 deletions(-)
 create mode 100644 matmul_utils_4bit.py

diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py
index 3514060..b583e4a 100644
--- a/Finetune4bConfig.py
+++ b/Finetune4bConfig.py
@@ -13,7 +13,7 @@ class Finetune4bConfig:
                  gradient_checkpointing: bool, gradient_checkpointing_ratio: float,
                  warmup_steps: int, save_steps: int, save_total_limit: int,
                  logging_steps: int,
-                 checkpoint: bool, skip: bool
+                 checkpoint: bool, skip: bool, groupsize: int
     ):
         """
         Args:
@@ -40,6 +40,7 @@ class Finetune4bConfig:
             logging_steps (int): Logging steps
             checkpoint (bool): Produce checkpoint instead of LoRA
             skip (bool): Don't train model
+            groupsize (int): Group size of V2 model, use -1 to load V1 model
         """
         self.dataset = dataset
         self.ds_type = ds_type
@@ -71,6 +72,7 @@ class Finetune4bConfig:
         self.device_map = "auto" if not self.ddp else {"": self.local_rank}
         if self.ddp:
             self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
+        self.groupsize = groupsize
 
 
     def __str__(self) -> str:
diff --git a/arg_parser.py b/arg_parser.py
index bb25086..0f80eb4 100644
--- a/arg_parser.py
+++ b/arg_parser.py
@@ -53,6 +53,9 @@ def parse_commandline():
 
     parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s")
     parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s")
+    # V2 model support
+    parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")
+
     return vars(parser.parse_args())
 
 
@@ -81,5 +84,6 @@ def get_config() -> Finetune4bConfig:
         save_total_limit=args["save_total_limit"],
         logging_steps=args["logging_steps"],
         checkpoint=args["checkpoint"],
-        skip=args["skip"]
+        skip=args["skip"],
+        groupsize=args["groupsize"]
     )
diff --git a/autograd_4bit.py b/autograd_4bit.py
index 731ad73..3b12af8 100644
--- a/autograd_4bit.py
+++ b/autograd_4bit.py
@@ -1,169 +1,70 @@
-from gptq_llama import quant
+import matmul_utils_4bit as mm4b
 import torch
-import numpy as np
 import torch.nn as nn
 import time
+import math
+from safetensors import safe_open
 
 
-# Global Buffer
-buffer_mat_dic = {}
-use_new = True
-auto_switch = True
-auto_switch_thd = 16
-
-
-def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
-    if shape_of_qweight not in buffer_mat_dic.keys():
-        buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
-    elif buffer_mat_dic[shape_of_qweight].device != device:
-        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
-    return buffer_mat_dic[shape_of_qweight]
-
-
-def matmul4bit(x, qweight, scales, zeros):
-    """
-    input x: (n, m)
-    qweight: (j, k)
-    where m == j*8
-
-    perform x @ qweight
-
-    return y:
-    """
-    assert qweight.shape[0] * 8 == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
-    x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
-    dtype = x.dtype
-    x = x.float()
-    quant.quant_cuda.vecquant4matmul(x, qweight, y, scales, zeros)
-    y = y.to(dtype)
-    return y.reshape(outshape)
-
-
-def matmul4bit_transpose(x, qweight, scales, zeros):
-    """
-    input x: (n, m)
-    qweight: (j, k)
-    where m == k
-
-    perform qweight @ x.T
-
-    return y:
-    """
-    assert qweight.shape[1] == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
-    x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=torch.float32, device=x.device)
-    dtype = x.dtype
-    x = x.float()
-    quant.quant_cuda.vecquant4transposematmul(x, qweight, y, scales, zeros)
-    y = y.to(dtype)
-    return y.reshape(outshape)
-
-
-def matmul4bit_half(x, qweight, scales, zeros):
-    """
-    input x: (n, m)
-    qweight: (j, k)
-    where m == j*8
-
-    perform x @ qweight
-
-    return y:
-    """
-    assert qweight.shape[0] * 8 == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
-    x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=x.dtype, device=x.device)
-    dtype = x.dtype
-    quant.quant_cuda.vecquant4matmul_half(x, qweight, y, scales, zeros)
-    y = y.to(dtype)
-    return y.reshape(outshape)
-
-
-def matmul4bit_transpose_half(x, qweight, scales, zeros):
-    """
-    input x: (n, m)
-    qweight: (j, k)
-    where m == k
-
-    perform qweight @ x.T
-
-    return y:
-    """
-    assert qweight.shape[1] == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
-    x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=x.dtype, device=x.device)
-    dtype = x.dtype
-    quant.quant_cuda.vecquant4transposematmul_half(x, qweight, y, scales, zeros)
-    y = y.to(dtype)
-    return y.reshape(outshape)
-
-
-def fast_4bit_forward(x, qweight, scales, zeros, bias):
-    use_new_flag = use_new
-    if auto_switch:
-        if x.shape[1] > auto_switch_thd:
-            use_new_flag = True
-        else:
-            use_new_flag = False
-    if use_new_flag:
-        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
-        output = torch.matmul(x, buffer)
-    else:
-        output = matmul4bit(x, qweight, scales.float(), zeros.float())
-    output += bias
-    return output
-
-
 class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, x, qweight, scales, zeros):
+    def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
         ctx.save_for_backward(qweight, scales, zeros)
-        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
-        output = torch.matmul(x, buffer).clone()
+        ctx.groupsize = groupsize
+        if groupsize == -1:
+            output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
+        else:
+            output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros = ctx.saved_tensors
-        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
-        grad = torch.matmul(grad_output, buffer.T)
-        return grad, None, None, None
+        if ctx.groupsize == -1:
+            grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
+        else:
+            grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, ctx.groupsize, transpose=True)
+        return grad, None, None, None, None
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
 class Autograd4bitQuantLinear(nn.Module):
 
-    def __init__(self, infeatures, outfeatures):
+    def __init__(self, infeatures, outfeatures, groupsize=-1):
         super().__init__()
         bits = 4
         self.in_features = infeatures
         self.out_features = outfeatures
         self.bits = bits
-        self.register_buffer('zeros', torch.empty((outfeatures, 1)))
-        self.register_buffer('scales', torch.empty((outfeatures, 1)))
+        self.groupsize = groupsize
+        if groupsize == -1:
+            self.register_buffer('zeros', torch.empty((outfeatures, 1)))
+            self.register_buffer('scales', torch.empty((outfeatures, 1)))
+        else:
+            self.register_buffer('qzeros',
+                torch.empty((math.ceil(infeatures/groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
+            )
+            self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize),outfeatures)))
         self.register_buffer('bias', torch.empty(outfeatures))
         self.register_buffer(
             'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
         )
 
+
     def forward(self, x):
         if torch.is_grad_enabled():
-            out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
+            out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
+                                           self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
             out += self.bias
         else:
-            out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
+            out = mm4b.matmul4bit(x, self.qweight, self.scales,
+                                  self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
+            out += self.bias
         return out
 
 
-def make_quant_for_4bit_autograd(module, names, name=''):
+def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
     if isinstance(module, Autograd4bitQuantLinear):
         return
     for attr in dir(module):
@@ -171,17 +72,18 @@ def make_quant_for_4bit_autograd(module, names, name=''):
         name1 = name + '.' + attr if name != '' else attr
         if name1 in names:
             setattr(
-                module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features)
+                module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize)
             )
     for name1, child in module.named_children():
-        make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1)
+        make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize)
 
 
 def model_to_half(model):
     model.half()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            m.zeros = m.zeros.half()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
     print('Converted as Half.')
@@ -191,34 +93,40 @@ def model_to_float(model):
     model.float()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            m.zeros = m.zeros.float()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.float()
             m.scales = m.scales.float()
             m.bias = m.bias.float()
     print('Converted as Float.')
 
 
-def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
-    import transformers
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+    if type(module) in layers:
+        return {name: module}
+    res = {}
+    for name1, child in module.named_children():
+        res.update(find_layers(
+            child, layers=layers, name=name + '.' + name1 if name != '' else name1
+        ))
+    return res
+
+
+def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
-    from gptq_llama.modelutils import find_layers
 
     print("Loading Model ...")
    t0 = time.time()
 
     with accelerate.init_empty_weights():
         config = LlamaConfig.from_pretrained(config_path)
-        torch.set_default_dtype(torch.half)
-        transformers.modeling_utils._init_weights = False
-        torch.set_default_dtype(torch.half)
         model = LlamaForCausalLM(config)
-        torch.set_default_dtype(torch.float)
         model = model.eval()
         layers = find_layers(model)
         for name in ['lm_head']:
             if name in layers:
                 del layers[name]
-        make_quant_for_4bit_autograd(model, layers)
+        make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
     model = accelerate.load_checkpoint_and_dispatch(
         model=model,
         checkpoint=model_path,
@@ -226,7 +134,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_ma
         no_split_module_classes=["LlamaDecoderLayer"]
     )
 
-    model.seqlen = 2048
+    model.seqlen = seqlen
 
     if half:
         model_to_half(model)
@@ -237,4 +145,4 @@
 
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
-    
+    
\ No newline at end of file
diff --git a/finetune.py b/finetune.py
index 32124f9..522ac90 100644
--- a/finetune.py
+++ b/finetune.py
@@ -42,7 +42,10 @@ if ft_config.gradient_checkpointing:
     print('Disable Dropout.')
 
 # Load Basic Model
-model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)
+model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir,
+                                                 ft_config.llama_q4_model,
+                                                 device_map=ft_config.device_map,
+                                                 groupsize=ft_config.groupsize)
 
 # Config Lora
 lora_config = LoraConfig(
diff --git a/inference.py b/inference.py
index 7b172ae..84ade3d 100644
--- a/inference.py
+++ b/inference.py
@@ -2,16 +2,18 @@ import os
 import sys
 import time
 import torch
-from autograd_4bit import load_llama_model_4bit_low_ram
+from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
 
 config_path = './llama-13b-4bit/'
 model_path = './llama-13b-4bit.pt'
 model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
 print('Fitting 4bit scales and zeros to half')
 for n, m in model.named_modules():
-    if '4bit' in str(type(m)):
-        m.zeros = m.zeros.half()
+    if isinstance(m, Autograd4bitQuantLinear):
+        if m.groupsize == -1:
+            m.zeros = m.zeros.half()
         m.scales = m.scales.half()
+        m.bias = m.bias.half()
 
 prompt = '''I think the meaning of life is'''
 batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py
new file mode 100644
index 0000000..31bb6a4
--- /dev/null
+++ b/matmul_utils_4bit.py
@@ -0,0 +1,139 @@
+import torch
+import numpy as np
+import quant_cuda
+
+
+# Global Buffer
+buffer_mat_dic = {}
+use_new = True
+auto_switch = True
+auto_switch_thd = 8
+debug = False
+
+
+def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
+    if shape_of_qweight not in buffer_mat_dic.keys():
+        buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
+    else:
+        if buffer_mat_dic[shape_of_qweight].device != device:
+            buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
+        if buffer_mat_dic[shape_of_qweight].dtype != dtype:
+            buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(dtype=dtype)
+    return buffer_mat_dic[shape_of_qweight]
+
+
+def _matmul4bit_v1(x, qweight, scales, zeros):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == j*8
+
+    perform x @ qweight
+
+    return y:
+    """
+    if debug:
+        print('_matmul4bit_v1')
+    assert qweight.shape[0] * 8 == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
+    dtype = x.dtype
+    x = x.half()
+    quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
+def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == j*8
+
+    perform x @ qweight
+
+    return y:
+    """
+    if debug:
+        print('_matmul4bit_v2')
+    assert qweight.shape[0] * 8 == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
+    dtype = x.dtype
+    x = x.half()
+    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, group_size, x.shape[-1] // 2)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
+def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
+    if debug:
+        print('_matmul4bit_v1_recons')
+    if not transpose:
+        assert qweight.shape[0] * 8 == x.shape[-1]
+    else:
+        assert qweight.shape[1] == x.shape[-1]
+    buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+    quant_cuda.vecquant4recons_v1(qweight, buffer, scales, zeros)
+    if not transpose:
+        output = torch.matmul(x, buffer)
+    else:
+        output = torch.matmul(x, buffer.T)
+    return output
+
+
+def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False):
+    if debug:
+        print('_matmul4bit_v2_recons')
+    if not transpose:
+        assert qweight.shape[0] * 8 == x.shape[-1]
+    else:
+        assert qweight.shape[1] == x.shape[-1]
+    buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, group_size)
+    if not transpose:
+        output = torch.matmul(x, buffer)
+    else:
+        output = torch.matmul(x, buffer.T)
+    return output
+
+
+def matmul4bit(x, qweight, scales, zeros, group_size=-1):
+    if group_size == -1:
+        # use v1
+        if use_new:
+            if auto_switch:
+                if np.prod(x.shape[:-1]) > auto_switch_thd:
+                    output = _matmul4bit_v1_recons(x, qweight, scales, zeros)
+                else:
+                    output = _matmul4bit_v1(x, qweight, scales, zeros)
+        else:
+            output = _matmul4bit_v1(x, qweight, scales, zeros)
+    else:
+        # use v2
+        if use_new:
+            if auto_switch:
+                if np.prod(x.shape[:-1]) > auto_switch_thd:
+                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size)
+                else:
+                    output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+        else:
+            output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+    return output
+
+
+def v2_to_v1(scales, zeros):
+    """
+    Convert zeros in V2 model to V1 model when group_num = 1, for debugging
+    """
+    assert zeros.shape[0] == 1
+    z_mat = torch.zeros((zeros.shape[1], 256), dtype=torch.int, device=zeros.device) + zeros.reshape((-1,1))
+    z_buffer = torch.zeros((z_mat.shape[0] * 8, z_mat.shape[1]), dtype=torch.float16, device=zeros.device)
+    z_zeros = torch.zeros(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
+    z_scales = torch.ones(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
+    quant_cuda.vecquant4recons_v1(z_mat, z_buffer, z_scales, z_zeros)
+    z_buffer = z_buffer[:,0]
+    zeros_recons = z_buffer * scales + scales
+    return zeros_recons
diff --git a/text-generation-webui/custom_monkey_patch.py b/text-generation-webui/custom_monkey_patch.py
index 139d4f6..6f586e3 100644
--- a/text-generation-webui/custom_monkey_patch.py
+++ b/text-generation-webui/custom_monkey_patch.py
@@ -14,7 +14,7 @@ def load_model_llama(*args, **kwargs):
 
     print("Loading {} ...".format(model_path))
     t0 = time.time()
-    model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
+    model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1)
     model = PeftModel.from_pretrained(model, lora_path, device_map={'': 0}, torch_dtype=torch.float32)
     print('{} Lora Applied.'.format(lora_path))
 
@@ -22,7 +22,8 @@ def load_model_llama(*args, **kwargs):
     print('Apply auto switch and half')
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
-            m.zeros = m.zeros.half()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
     autograd_4bit.use_new = True
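
A minimal usage sketch of the new groupsize parameter added by this patch. The checkpoint paths and the group size of 128 below are placeholder assumptions, not files shipped with the repository; groupsize=-1 (the default) keeps the original v1 loading path.

    import torch
    from autograd_4bit import load_llama_model_4bit_low_ram, model_to_half

    # Hypothetical v2 (grouped-quantization) checkpoint quantized with group size 128.
    config_path = './llama-13b-4bit-128g/'
    model_path = './llama-13b-4bit-128g.pt'

    # groupsize selects between the v1 kernels (-1) and the new v2 kernels (e.g. 128).
    model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=128)
    model_to_half(model)

    batch = tokenizer("I think the meaning of life is", return_tensors="pt")
    with torch.no_grad():
        generated = model.generate(input_ids=batch["input_ids"].cuda(), max_new_tokens=20)
    print(tokenizer.decode(generated[0]))

The same switch is exposed on the fine-tuning entry point, e.g. python finetune.py --groupsize 128 ...; leave it at the default of -1 when training against a v1 checkpoint.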