This commit is contained in:
John Smith 2023-03-28 21:12:51 +08:00
parent bff039de95
commit 211af574b6
3 changed files with 14 additions and 11 deletions

View File

@ -4,13 +4,14 @@ import torch.nn as nn
import time
import math
from safetensors import safe_open
import numpy as np
class AutogradMatmul4bit(torch.autograd.Function):
@staticmethod
def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
ctx.save_for_backward(qweight, scales, zeros, groupsize)
ctx.save_for_backward(qweight, scales, zeros, torch.from_numpy(np.array([groupsize])).cuda())
if groupsize == -1:
output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
else:
@ -21,6 +22,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
qweight, scales, zeros, groupsize = ctx.saved_tensors
groupsize = groupsize.cpu().numpy()[0]
if groupsize == -1:
grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
else:

View File

@ -66,7 +66,8 @@ else:
print('Fitting 4bit scales and zeros to half')
for n, m in model.named_modules():
if '4bit' in str(type(m)):
m.zeros = m.zeros.half()
if m.groupsize == -1:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
# Set tokenizer

View File

@ -45,7 +45,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
return y.reshape(outshape)
def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
"""
input x: (n, m)
qweight: (j, k)
@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
dtype = x.dtype
x = x.half()
quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, group_size, x.shape[-1] // 2)
quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, groupsize, x.shape[-1] // 2)
y = y.to(dtype)
return y.reshape(outshape)
@ -84,7 +84,7 @@ def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
return output
def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False):
def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False):
if debug:
print('_matmul4bit_v2_recons')
if not transpose:
@ -92,7 +92,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
else:
assert qweight.shape[1] == x.shape[-1]
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, group_size)
quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
if not transpose:
output = torch.matmul(x, buffer)
if transpose:
@ -100,8 +100,8 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
return output
def matmul4bit(x, qweight, scales, zeros, group_size=-1):
if group_size == -1:
def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
if groupsize == -1:
# use v1
if use_new:
if auto_switch:
@ -116,11 +116,11 @@ def matmul4bit(x, qweight, scales, zeros, group_size=-1):
if use_new:
if auto_switch:
if np.prod(x.shape[:-1]) > auto_switch_thd:
output = _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size)
output = _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
else:
output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
else:
output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
return output