# alpaca_lora_4bit/matmul_utils_4bit.py


import torch
import numpy as np
from gptq_llama import quant_cuda
# Global cache of dequantization buffers, keyed by qweight shape
buffer_mat_dic = {}
use_new = True        # allow the reconstruct-then-matmul path
auto_switch = True    # pick a kernel based on the flattened batch size
auto_switch_thd = 8   # above this many input rows, use the recons path
debug = False         # print which kernel is dispatched
faster = True         # v2 only: use the fp16 "faster" CUDA kernel
cache_buffer = True   # reuse buffers across calls instead of reallocating

def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
    # Returns a (rows * 8, cols) buffer for the dequantized weight: each int32
    # row of qweight unpacks into 8 rows of 4-bit values.
    if not cache_buffer:
        return torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
    if shape_of_qweight not in buffer_mat_dic.keys():
        buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
    else:
        if buffer_mat_dic[shape_of_qweight].device != device:
            buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
        if buffer_mat_dic[shape_of_qweight].dtype != dtype:
            buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(dtype=dtype)
    return buffer_mat_dic[shape_of_qweight]
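
# Caching behavior sketch (hypothetical shapes): with cache_buffer enabled,
# repeated calls for the same qweight shape hand back the same tensor, so the
# allocation cost is paid once per shape:
#
#   a = get_buffer((512, 4096))
#   b = get_buffer((512, 4096))
#   assert a is b
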
def _matmul4bit_v1(x, qweight, scales, zeros):
    """
    input x: (n, m)
    qweight: (j, k)
    where m == j * 8
    perform x @ dequant(qweight)
    return y: (n, k)
    """
    if debug:
        print('_matmul4bit_v1')
    assert qweight.shape[0] * 8 == x.shape[-1]
    outshape = x.shape[:-1] + (qweight.shape[1],)
    x = x.reshape(-1, x.shape[-1])
    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
    dtype = x.dtype
    x = x.half()
    quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
    y = y.to(dtype)
    return y.reshape(outshape)
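
# Packing note (per the usual GPTQ convention this file appears to follow):
# each int32 in qweight stacks eight 4-bit input features, so feature i of
# output column c unpacks as (qweight[i // 8, c] >> (4 * (i % 8))) & 0xF
# before scales/zeros are applied; hence the `* 8` in the shape checks above.
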
def _matmul4bit_v2(x, qweight, scales, zeros, g_idx):
    """
    input x: (n, m)
    qweight: (j, k)
    where m == j * 8
    perform x @ dequant(qweight)
    return y: (n, k)
    """
    if debug:
        print('_matmul4bit_v2')
    assert qweight.shape[0] * 8 == x.shape[-1]
    outshape = x.shape[:-1] + (qweight.shape[1],)
    x = x.reshape(-1, x.shape[-1])
    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
    dtype = x.dtype
    if faster:
        x = x.half()
        quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, g_idx, x.shape[-1] // 2)
    else:
        x = x.float()
        quant_cuda.vecquant4matmul(x, qweight, y, scales, zeros, g_idx)
    y = y.to(dtype)
    return y.reshape(outshape)
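
# In the v2 (grouped) format, g_idx maps each of the j * 8 input features to
# its quantization group, so scales and zeros are looked up per group rather
# than once per output column as in v1.
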
def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
    # Dequantize the full weight into a buffer, then use a regular matmul.
    if debug:
        print('_matmul4bit_v1_recons')
    if not transpose:
        assert qweight.shape[0] * 8 == x.shape[-1]
    else:
        assert qweight.shape[1] == x.shape[-1]
    buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
    quant_cuda.vecquant4recons_v1(qweight, buffer, scales, zeros)
    if not transpose:
        output = torch.matmul(x, buffer)
    else:
        output = torch.matmul(x, buffer.T)
    return output
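
# The transpose=True path multiplies by the reconstructed weight's transpose,
# which is what a backward pass through the quantized layer would need (how
# an autograd wrapper around these kernels would typically use it).
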
def _matmul4bit_v2_recons(x, qweight, scales, zeros, g_idx, transpose=False):
    # Same reconstruct-then-matmul strategy as v1, using the grouped kernel.
    if debug:
        print('_matmul4bit_v2_recons')
    if not transpose:
        assert qweight.shape[0] * 8 == x.shape[-1]
    else:
        assert qweight.shape[1] == x.shape[-1]
    buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, g_idx)
    if not transpose:
        output = torch.matmul(x, buffer)
    else:
        output = torch.matmul(x, buffer.T)
    return output
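
# Trade-off behind auto_switch below: the fused vecquant kernels win for very
# small batches, while reconstructing the fp16 weight once and letting
# torch.matmul dispatch to cuBLAS amortizes better as the row count grows.
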
def matmul4bit(x, qweight, scales, zeros, g_idx=None):
    # Format detection: v1 checkpoints store zeros as float16, while v2
    # (grouped) checkpoints store them packed as int32.
    if zeros.dtype != torch.int32:
        # use v1
        if use_new and auto_switch and np.prod(x.shape[:-1]) > auto_switch_thd:
            output = _matmul4bit_v1_recons(x.half(), qweight, scales.half(), zeros.half())
        else:
            output = _matmul4bit_v1(x, qweight, scales, zeros)
    else:
        # use v2
        if g_idx is None:
            g_idx = torch.zeros(qweight.shape[0] * 8, dtype=torch.int32, device=x.device)
        if use_new and auto_switch and np.prod(x.shape[:-1]) > auto_switch_thd:
            output = _matmul4bit_v2_recons(x.half(), qweight, scales, zeros, g_idx)
        else:
            output = _matmul4bit_v2(x, qweight, scales, zeros, g_idx)
    return output
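
# Usage sketch (hypothetical attribute names, following the GPTQ v2 layout
# assumed throughout this file):
#
#   y = matmul4bit(x, layer.qweight, layer.scales, layer.qzeros, layer.g_idx)
#
# where x is a (batch, seq, in_features) fp16 CUDA tensor and y comes back
# with the same leading dimensions and out_features as its last dimension.
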
def v2_to_v1(scales, zeros):
    """
    Convert zeros in a V2 model to the V1 format when group_num == 1, for debugging.
    Deprecated.
    """
    assert zeros.shape[0] == 1
    z_mat = torch.zeros((zeros.shape[1], 256), dtype=torch.int, device=zeros.device) + zeros.reshape((-1, 1))
    z_buffer = torch.zeros((z_mat.shape[0] * 8, z_mat.shape[1]), dtype=torch.float16, device=zeros.device)
    z_zeros = torch.zeros(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
    z_scales = torch.ones(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
    quant_cuda.vecquant4recons_v1(z_mat, z_buffer, z_scales, z_zeros)
    z_buffer = z_buffer[:, 0]
    zeros_recons = z_buffer * scales + scales
    return zeros_recons
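
# Minimal smoke test, a sketch only: it assumes the gptq_llama CUDA extension
# is built and a GPU is available, and it fabricates random v2-format tensors
# (a single quantization group), so only the output shape is meaningful.
if __name__ == '__main__':
    in_features, out_features = 256, 256
    x = torch.randn(4, in_features, dtype=torch.float16, device='cuda')
    # Packed weights: 8 input features per int32 row (assumed layout).
    qweight = torch.randint(0, 2**31 - 1, (in_features // 8, out_features), dtype=torch.int32, device='cuda')
    scales = torch.ones(1, out_features, dtype=torch.float16, device='cuda')
    # int32 zeros select the v2 branch; 8 values packed per int32 (assumed).
    qzeros = torch.randint(0, 2**31 - 1, (1, out_features // 8), dtype=torch.int32, device='cuda')
    g_idx = torch.zeros(in_features, dtype=torch.int32, device='cuda')
    y = matmul4bit(x, qweight, scales, qzeros, g_idx)
    print(y.shape)  # expected: torch.Size([4, 256])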