diff --git a/autograd_4bit.py b/autograd_4bit.py
index 3b12af8..77befb7 100644
--- a/autograd_4bit.py
+++ b/autograd_4bit.py
@@ -4,13 +4,14 @@ import torch.nn as nn
 import time
 import math
 from safetensors import safe_open
+import numpy as np
 
 
 class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
-        ctx.save_for_backward(qweight, scales, zeros, groupsize)
+        ctx.save_for_backward(qweight, scales, zeros, torch.from_numpy(np.array([groupsize])).cuda())
         if groupsize == -1:
             output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
         else:
@@ -21,6 +22,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros, groupsize = ctx.saved_tensors
+        groupsize = groupsize.cpu().numpy()[0]
         if groupsize == -1:
             grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
         else:
diff --git a/finetune.py b/finetune.py
index 522ac90..36cfaea 100644
--- a/finetune.py
+++ b/finetune.py
@@ -66,7 +66,8 @@ else:
     print('Fitting 4bit scales and zeros to half')
     for n, m in model.named_modules():
         if '4bit' in str(type(m)):
-            m.zeros = m.zeros.half()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.half()
             m.scales = m.scales.half()
 
 # Set tokenizer
diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py
index 31bb6a4..b897f70 100644
--- a/matmul_utils_4bit.py
+++ b/matmul_utils_4bit.py
@@ -45,7 +45,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     return y.reshape(outshape)
 
 
-def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
+def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
     """
     input x: (n, m)
     qweight: (j, k)
@@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
     x = x.half()
-    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, group_size, x.shape[-1] // 2)
+    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, groupsize, x.shape[-1] // 2)
     y = y.to(dtype)
     return y.reshape(outshape)
 
@@ -84,7 +84,7 @@ def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
     return output
 
 
-def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False):
+def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False):
     if debug:
         print('_matmul4bit_v2_recons')
     if not transpose:
@@ -92,7 +92,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
     else:
         assert qweight.shape[1] == x.shape[-1]
     buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, group_size)
+    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
     if not transpose:
         output = torch.matmul(x, buffer)
     if transpose:
@@ -100,8 +100,8 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
     return output
 
 
-def matmul4bit(x, qweight, scales, zeros, group_size=-1):
-    if group_size == -1:
+def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
+    if groupsize == -1:
         # use v1
         if use_new:
             if auto_switch:
@@ -116,11 +116,11 @@ def matmul4bit(x, qweight, scales, zeros, group_size=-1):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size)
+                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
                 else:
-                    output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+                    output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
         else:
-            output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+            output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
     return output
 
 
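Note on the autograd_4bit.py hunks: ctx.save_for_backward only accepts tensors, which is why the patch smuggles the integer groupsize through as a one-element CUDA tensor in forward and unpacks it again in backward. A minimal sketch of the common alternative is shown below; it assumes nothing beyond the forward/backward structure visible in the hunks (the actual 4-bit matmul calls are stubbed out), and simply stores the plain int on ctx instead:

import torch

class Matmul4bitSketch(torch.autograd.Function):
    # Illustrative only, not the patch's implementation.

    @staticmethod
    def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
        ctx.save_for_backward(qweight, scales, zeros)  # tensors only
        ctx.groupsize = groupsize                      # plain int, no numpy/CUDA round trip
        # the real forward would call mm4b._matmul4bit_v1_recons / _v2_recons here
        return x.clone()                               # placeholder output

    @staticmethod
    def backward(ctx, grad_output):
        qweight, scales, zeros = ctx.saved_tensors
        groupsize = ctx.groupsize                      # read back without .cpu().numpy()
        # the real backward would reconstruct the weight and matmul with grad_output here
        return grad_output, None, None, None, None     # one gradient slot per forward input

The finetune.py hunk fits the same distinction: with groupsize == -1 (the v1 format) zeros is a float tensor that can safely be cast to half, whereas grouped (v2) checkpoints store packed integer zeros, which is presumably why only the groupsize == -1 case still performs the cast.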