fix bug: pass groupsize to save_for_backward as a tensor; rename group_size → groupsize
This commit is contained in:
parent  bff039de95
commit  211af574b6
@@ -4,13 +4,14 @@ import torch.nn as nn
 import time
 import math
 from safetensors import safe_open
+import numpy as np
 
 
 class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
-        ctx.save_for_backward(qweight, scales, zeros, groupsize)
+        ctx.save_for_backward(qweight, scales, zeros, torch.from_numpy(np.array([groupsize])).cuda())
         if groupsize == -1:
             output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
         else:
@@ -21,6 +22,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros, groupsize = ctx.saved_tensors
+        groupsize = groupsize.cpu().numpy()[0]
         if groupsize == -1:
             grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
         else:
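Note on the fix above: ctx.save_for_backward accepts only tensors, so saving the plain int groupsize raises a TypeError; the commit wraps it in a one-element tensor and unwraps it in backward. A minimal self-contained sketch of the pattern (hypothetical Scale function, not from this repo; CPU tensor instead of .cuda()):

import torch
import numpy as np

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, k=2):
        # ctx.save_for_backward(x, k)  # would fail: save_for_backward only accepts tensors
        ctx.save_for_backward(x, torch.from_numpy(np.array([k])))
        return x * k

    @staticmethod
    def backward(ctx, grad_output):
        x, k = ctx.saved_tensors
        k = int(k.cpu().numpy()[0])  # unwrap back to a plain int, as the commit does
        return grad_output * k, None

y = Scale.apply(torch.ones(3, requires_grad=True))
y.sum().backward()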
@@ -66,7 +66,8 @@ else:
     print('Fitting 4bit scales and zeros to half')
     for n, m in model.named_modules():
         if '4bit' in str(type(m)):
-            m.zeros = m.zeros.half()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.half()
             m.scales = m.scales.half()
 
 # Set tokenizer
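On the guard above: a plausible reading is that with groupsize == -1 (the v1 layout) zeros holds floating-point zero points and can safely follow scales to fp16, while the grouped v2 layout typically stores zeros packed as integers, which the CUDA kernels expect unmodified. The new logic, restated with comments:

for n, m in model.named_modules():
    if '4bit' in str(type(m)):
        if m.groupsize == -1:        # v1 layout: float zero points
            m.zeros = m.zeros.half()
        m.scales = m.scales.half()   # scales are floating point in both layouts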
@@ -45,7 +45,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     return y.reshape(outshape)
 
 
-def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
+def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
     """
     input x: (n, m)
     qweight: (j, k)
@@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
     x = x.half()
-    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, group_size, x.shape[-1] // 2)
+    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, groupsize, x.shape[-1] // 2)
     y = y.to(dtype)
     return y.reshape(outshape)
 
@@ -84,7 +84,7 @@ def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
     return output
 
 
-def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False):
+def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False):
     if debug:
         print('_matmul4bit_v2_recons')
     if not transpose:
@@ -92,7 +92,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
     else:
         assert qweight.shape[1] == x.shape[-1]
     buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, group_size)
+    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
     if not transpose:
         output = torch.matmul(x, buffer)
     if transpose:
@@ -100,8 +100,8 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
     return output
 
 
-def matmul4bit(x, qweight, scales, zeros, group_size=-1):
-    if group_size == -1:
+def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
+    if groupsize == -1:
         # use v1
         if use_new:
             if auto_switch:
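For context on the dispatch in this function: np.prod(x.shape[:-1]) counts the input rows, and auto_switch_thd picks between the two v2 paths. Above the threshold it is generally cheaper to reconstruct the fp16 weight once and run a dense matmul (_matmul4bit_v2_recons), while small decode-time inputs use the fused quantized kernel (_matmul4bit_v2). A sketch of the decision, assuming the module-level auto_switch and auto_switch_thd flags visible in the diff:

rows = int(np.prod(x.shape[:-1]))  # tokens in this multiply
if auto_switch and rows > auto_switch_thd:
    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
else:
    output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)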
@@ -116,11 +116,11 @@ def matmul4bit(x, qweight, scales, zeros, group_size=-1):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size)
+                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
                 else:
-                    output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+                    output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
         else:
-            output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+            output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
     return output
 
 
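One practical consequence of the group_size → groupsize rename across _matmul4bit_v2, _matmul4bit_v2_recons, and matmul4bit: positional callers are unaffected, but keyword call sites must be updated, e.g. (hypothetical call):

y = matmul4bit(x, qweight, scales, zeros, groupsize=128)  # was group_size=128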