fix bug
parent bff039de95
commit 211af574b6
@@ -4,13 +4,14 @@ import torch.nn as nn
 import time
 import math
 from safetensors import safe_open
+import numpy as np
 
 
 class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
-        ctx.save_for_backward(qweight, scales, zeros, groupsize)
+        ctx.save_for_backward(qweight, scales, zeros, torch.from_numpy(np.array([groupsize])).cuda())
         if groupsize == -1:
             output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
         else:
@@ -21,6 +22,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros, groupsize = ctx.saved_tensors
+        groupsize = groupsize.cpu().numpy()[0]
         if groupsize == -1:
             grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
         else:
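The two hunks above are the actual bug fix: torch.autograd.Function.save_for_backward only accepts tensors, so passing the plain Python int groupsize fails; the commit wraps it in a one-element tensor in forward and unwraps it again in backward. A minimal, self-contained sketch of that pattern follows; the class name and the dummy multiply-by-two op are hypothetical stand-ins for the real 4-bit matmul kernels.

import numpy as np
import torch

class SaveIntForBackward(torch.autograd.Function):
    # Hypothetical stand-in for AutogradMatmul4bit: shows only how a
    # Python int is carried through save_for_backward(), which accepts
    # tensors only.

    @staticmethod
    def forward(ctx, x, groupsize=-1):
        # Wrap the int in a 1-element tensor so it can be saved.
        ctx.save_for_backward(torch.from_numpy(np.array([groupsize])))
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        (groupsize_t,) = ctx.saved_tensors
        groupsize = int(groupsize_t.numpy()[0])  # unwrap back to a plain int
        # groupsize == -1 would select the v1 kernel here, otherwise v2.
        return grad_output * 2.0, None

x = torch.randn(4, requires_grad=True)
SaveIntForBackward.apply(x, 128).sum().backward()  # x.grad is now all 2s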
@@ -66,7 +66,8 @@ else:
     print('Fitting 4bit scales and zeros to half')
     for n, m in model.named_modules():
         if '4bit' in str(type(m)):
-            m.zeros = m.zeros.half()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.half()
             m.scales = m.scales.half()
 
 # Set tokenizer
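The hunk above touches the model-loading script: zeros is now only cast to half precision for the ungrouped v1 format (groupsize == -1), presumably because grouped v2 checkpoints keep their zero points in a packed integer layout that an fp16 cast would corrupt; the diff itself only shows the guard. A commented restatement of the new logic, with model, the '4bit' type check and the .groupsize/.zeros/.scales attributes taken from the context lines and the comments being interpretation:

for n, m in model.named_modules():
    if '4bit' in str(type(m)):
        if m.groupsize == -1:
            # v1 layout: zeros are plain per-channel values, safe to hold in fp16
            m.zeros = m.zeros.half()
        # v2 layout (groupsize set): zeros are left untouched here
        m.scales = m.scales.half()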
@@ -45,7 +45,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     return y.reshape(outshape)
 
 
-def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
+def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
     """
     input x: (n, m)
     qweight: (j, k)
@@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
     x = x.half()
-    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, group_size, x.shape[-1] // 2)
+    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, groupsize, x.shape[-1] // 2)
     y = y.to(dtype)
     return y.reshape(outshape)
 
@@ -84,7 +84,7 @@ def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
     return output
 
 
-def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False):
+def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False):
     if debug:
         print('_matmul4bit_v2_recons')
     if not transpose:
@@ -92,7 +92,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
     else:
         assert qweight.shape[1] == x.shape[-1]
     buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, group_size)
+    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
     if not transpose:
         output = torch.matmul(x, buffer)
     if transpose:
@@ -100,8 +100,8 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
     return output
 
 
-def matmul4bit(x, qweight, scales, zeros, group_size=-1):
-    if group_size == -1:
+def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
+    if groupsize == -1:
         # use v1
         if use_new:
             if auto_switch:
@@ -116,11 +116,11 @@ def matmul4bit(x, qweight, scales, zeros, group_size=-1):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size)
+                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
                 else:
-                    output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+                    output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
         else:
-            output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+            output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
     return output
 
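The remaining hunks are a mechanical rename of the group_size parameter to groupsize throughout the matmul utilities, so the keyword matches the name already used by AutogradMatmul4bit and by the module attributes. A hedged usage sketch of the public entry point after the rename; matmul4bit is assumed to be imported from the utility module shown above, and the quantized tensors are placeholders standing in for weights loaded from a real 4-bit checkpoint.

import torch

# Placeholders: in practice qweight/scales/zeros come from a GPTQ-style
# 4-bit checkpoint and live on the GPU next to the quant_cuda kernels.
x = torch.randn(1, 4096, dtype=torch.float16, device='cuda')

# v1 weights (no grouping): leave groupsize at its default of -1
y_v1 = matmul4bit(x, qweight, scales, zeros)

# v2 weights quantized with groups of 128: the keyword is now
# groupsize, matching the rest of the code base
y_v2 = matmul4bit(x, qweight, scales, zeros, groupsize=128)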