diff --git a/GPTQ-for-LLaMa/autograd_4bit.py b/GPTQ-for-LLaMa/autograd_4bit.py
index 34bf9d7..09202f7 100644
--- a/GPTQ-for-LLaMa/autograd_4bit.py
+++ b/GPTQ-for-LLaMa/autograd_4bit.py
@@ -2,6 +2,7 @@ import quant
 import torch
 import numpy as np
 import torch.nn as nn
+import time
 
 
 def matmul4bit(x, qweight, scales, zeros):
@@ -46,24 +47,75 @@ def matmul4bit_transpose(x, qweight, scales, zeros):
     return y.reshape(outshape)
 
 
+def matmul4bit_half(x, qweight, scales, zeros):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == j*8
+
+    perform x @ qweight
+
+    return y: (n, k)
+    """
+    assert qweight.shape[0] * 8 == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=x.dtype, device=x.device)
+    dtype = x.dtype
+    quant.quant_cuda.vecquant4matmul_half(x, qweight, y, scales, zeros)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
+def matmul4bit_transpose_half(x, qweight, scales, zeros):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == k
+
+    perform qweight @ x.T
+
+    return y: (n, j*8)
+    """
+    assert qweight.shape[1] == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=x.dtype, device=x.device)
+    dtype = x.dtype
+    quant.quant_cuda.vecquant4transposematmul_half(x, qweight, y, scales, zeros)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
 class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros):
         ctx.save_for_backward(qweight, scales, zeros)
-        output = matmul4bit(x, qweight, scales, zeros).clone()
-        return output # equals to torch.matmul(x, qweight)
+
+        # equivalent to torch.matmul(x, qweight)
+        if x.dtype == torch.float32:
+            output = matmul4bit(x, qweight, scales, zeros).clone()
+        elif x.dtype == torch.float16:
+            output = matmul4bit_half(x, qweight, scales, zeros).clone()
+        else:
+            raise ValueError('Only float and half are supported.')
+
+        return output
 
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros = ctx.saved_tensors
-        # print(grad_output.shape, A.shape, B.shape)
         # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
-        grad1 = matmul4bit_transpose(grad_output, qweight, scales, zeros)
-        # grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
+        if grad_output.dtype == torch.float32:
+            grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
+        elif grad_output.dtype == torch.float16:
+            grad = matmul4bit_transpose_half(grad_output, qweight, scales, zeros)
+        else:
+            raise ValueError('Only float and half are supported.')
 
-        return grad1, None, None, None
+        return grad, None, None, None
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
@@ -102,7 +154,27 @@ def make_quant_for_4bit_autograd(module, names, name=''):
            make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1)
 
 
-def load_llama_model_4bit_low_ram(config_path, model_path):
+def model_to_half(model):
+    model.half()
+    for n, m in model.named_modules():
+        if isinstance(m, Autograd4bitQuantLinear):
+            m.zeros = m.zeros.half()
+            m.scales = m.scales.half()
+            m.bias = m.bias.half()
+    print('Converted to half.')
+
+
+def model_to_float(model):
+    model.float()
+    for n, m in model.named_modules():
+        if isinstance(m, Autograd4bitQuantLinear):
+            m.zeros = m.zeros.float()
+            m.scales = m.scales.float()
+            m.bias = m.bias.float()
+    print('Converted to float.')
+
+
+def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
     import transformers
     import accelerate
     from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer
@@ -127,6 +199,9 @@ def load_llama_model_4bit_low_ram(config_path, model_path):
     model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
     model.cuda()
     model.seqlen = 2048
+
+    if half:
+        model_to_half(model)
 
     tokenizer = LLaMATokenizer.from_pretrained(config_path)
     tokenizer.truncation_side = 'left'
@@ -134,4 +209,4 @@ def load_llama_model_4bit_low_ram(config_path, model_path):
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
 
     return model, tokenizer
-
+
\ No newline at end of file
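
A minimal usage sketch of the new half-precision path. It assumes the patched file is importable as autograd_4bit; the config directory and checkpoint paths below are placeholders, not part of the diff.

import torch
from autograd_4bit import load_llama_model_4bit_low_ram, model_to_float

# Placeholder paths: point these at your own HF config directory and 4-bit checkpoint.
config_path = './llama-7b-hf/'
model_path = './llama-7b-4bit.pt'

# half=True loads the model and converts weights, scales, zeros and biases to fp16.
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, half=True)

# Run a quick fp16 forward pass.
batch = tokenizer('Hello, my name is', return_tensors='pt')
batch = {k: v.cuda() for k, v in batch.items()}
with torch.no_grad():
    out = model(**batch)
print(out.logits.dtype)  # expected torch.float16 when half=True

# Convert back to fp32 if a later stage needs float weights.
model_to_float(model)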