John Smith 2023-03-28 21:12:51 +08:00
parent bff039de95
commit 211af574b6
3 changed files with 14 additions and 11 deletions

View File

@@ -4,13 +4,14 @@ import torch.nn as nn
 import time
 import math
 from safetensors import safe_open
+import numpy as np


 class AutogradMatmul4bit(torch.autograd.Function):

     @staticmethod
     def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
-        ctx.save_for_backward(qweight, scales, zeros, groupsize)
+        ctx.save_for_backward(qweight, scales, zeros, torch.from_numpy(np.array([groupsize])).cuda())
         if groupsize == -1:
             output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
         else:
@@ -21,6 +22,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros, groupsize = ctx.saved_tensors
+        groupsize = groupsize.cpu().numpy()[0]
         if groupsize == -1:
             grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
         else:
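A note on the change above: ctx.save_for_backward only accepts tensors, which is why the plain Python int groupsize is now wrapped into a one-element tensor in forward and unwrapped in backward. Below is a minimal sketch of the same pattern; the class name and the doubling op are illustrative, not from this repo. torch.tensor([groupsize]) would avoid the NumPy round-trip (and the .cuda() call in the diff assumes a CUDA device), and the other common idiom is to stash non-tensor state directly on ctx (ctx.groupsize = groupsize).

import torch

class SaveIntExample(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, groupsize=-1):
        # save_for_backward only takes tensors, so wrap the int first
        ctx.save_for_backward(torch.tensor([groupsize], device=x.device))
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        (groupsize_t,) = ctx.saved_tensors
        groupsize = int(groupsize_t.item())  # back to a Python int
        # the real code branches on groupsize here; the grad of x * 2 is just 2
        return grad_output * 2.0, None

y = SaveIntExample.apply(torch.randn(3, requires_grad=True), 128)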

View File

@@ -66,7 +66,8 @@ else:
     print('Fitting 4bit scales and zeros to half')
     for n, m in model.named_modules():
         if '4bit' in str(type(m)):
-            m.zeros = m.zeros.half()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.half()
             m.scales = m.scales.half()

 # Set tokenizer
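The new guard skips the half() cast for grouped modules, presumably because only the v1 (groupsize == -1) layout stores zeros as a float tensor; in the grouped v2 layout the zero points are packed into integers, and a dtype cast would destroy the packing. A quick illustration of the failure mode the guard avoids (the packed value here is an assumption, just eight 4-bit nibbles in one int32):

import torch

packed_zeros = torch.tensor([0x77777777], dtype=torch.int32)  # eight packed nibbles
print(packed_zeros.half())  # value-converts the integer (overflows to inf in fp16)
                            # instead of preserving the packed 4-bit fields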

View File

@@ -45,7 +45,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     return y.reshape(outshape)


-def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
+def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
     """
     input x: (n, m)
     qweight: (j, k)
@@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
     x = x.half()
-    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, group_size, x.shape[-1] // 2)
+    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, groupsize, x.shape[-1] // 2)
     y = y.to(dtype)
     return y.reshape(outshape)
@@ -84,7 +84,7 @@ def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
     return output


-def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False):
+def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False):
     if debug:
         print('_matmul4bit_v2_recons')
     if not transpose:
@@ -92,7 +92,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
     else:
         assert qweight.shape[1] == x.shape[-1]
     buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, group_size)
+    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
     if not transpose:
         output = torch.matmul(x, buffer)
     if transpose:
@@ -100,8 +100,8 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False
     return output


-def matmul4bit(x, qweight, scales, zeros, group_size=-1):
-    if group_size == -1:
+def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
+    if groupsize == -1:
         # use v1
         if use_new:
             if auto_switch:
@@ -116,11 +116,11 @@ def matmul4bit(x, qweight, scales, zeros, group_size=-1):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size)
+                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
                 else:
-                    output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+                    output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
             else:
-                output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+                output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize)
     return output
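For reference, the dispatch matmul4bit performs after this group_size -> groupsize rename can be summarized as below. This is a simplified sketch: use_new, auto_switch, and auto_switch_thd are module-level flags in this file, and the default threshold value used here is illustrative, not taken from the source.

import numpy as np

def pick_kernel(x_shape, groupsize, use_new=True, auto_switch=True, auto_switch_thd=8):
    # groupsize == -1 selects the v1 layout, anything else the grouped v2 layout
    version = 'v1' if groupsize == -1 else 'v2'
    if use_new and auto_switch and np.prod(x_shape[:-1]) > auto_switch_thd:
        # larger batches: reconstruct fp16 weights once, then run a dense matmul
        return f'_matmul4bit_{version}_recons'
    # small batches: fused quantized matmul kernel
    return f'_matmul4bit_{version}'

print(pick_kernel((1, 4096), groupsize=128))   # _matmul4bit_v2
print(pick_kernel((32, 4096), groupsize=128))  # _matmul4bit_v2_recons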