add half support

This commit is contained in:
John Smith 2023-03-20 09:37:51 +00:00
parent 5b64833390
commit dd0d5a31f7
1 changed files with 83 additions and 8 deletions

View File

@ -2,6 +2,7 @@ import quant
import torch
import numpy as np
import torch.nn as nn
import time
def matmul4bit(x, qweight, scales, zeros):
@ -46,24 +47,75 @@ def matmul4bit_transpose(x, qweight, scales, zeros):
return y.reshape(outshape)
def matmul4bit_half(x, qweight, scales, zeros):
"""
input x: (n, m)
qweight: (j, k)
where m == j*8
perform x @ qweight
return y:
"""
assert qweight.shape[0] * 8 == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=x.dtype, device=x.device)
dtype = x.dtype
quant.quant_cuda.vecquant4matmul_half(x, qweight, y, scales, zeros)
y = y.to(dtype)
return y.reshape(outshape)
def matmul4bit_transpose_half(x, qweight, scales, zeros):
"""
input x: (n, m)
qweight: (j, k)
where m == k
perform qweight @ x.T
return y:
"""
assert qweight.shape[1] == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=x.dtype, device=x.device)
dtype = x.dtype
quant.quant_cuda.vecquant4transposematmul_half(x, qweight, y, scales, zeros)
y = y.to(dtype)
return y.reshape(outshape)
class AutogradMatmul4bit(torch.autograd.Function):
@staticmethod
def forward(ctx, x, qweight, scales, zeros):
ctx.save_for_backward(qweight, scales, zeros)
output = matmul4bit(x, qweight, scales, zeros).clone()
return output # equals to torch.matmul(x, qweight)
# equals to torch.matmul(x, qweight)
if x.dtype == torch.float32:
output = matmul4bit(x, qweight, scales, zeros).clone()
elif x.dtype == torch.float16:
output = matmul4bit_half(x, qweight, scales, zeros).clone()
else:
raise ValueError('Only float and half are supportted.')
return output
@staticmethod
def backward(ctx, grad_output):
qweight, scales, zeros = ctx.saved_tensors
# print(grad_output.shape, A.shape, B.shape)
# compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
grad1 = matmul4bit_transpose(grad_output, qweight, scales, zeros)
# grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
if grad_output.dtype == torch.float32:
grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
elif grad_output.dtype == torch.float16:
grad = matmul4bit_transpose_half(grad_output, qweight, scales, zeros)
else:
raise ValueError('Only float and half are supportted.')
return grad1, None, None, None
return grad, None, None, None
# Assumes layer is perfectly divisible into 256 * 256 blocks
@ -102,7 +154,27 @@ def make_quant_for_4bit_autograd(module, names, name=''):
make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1)
def load_llama_model_4bit_low_ram(config_path, model_path):
def model_to_half(model):
model.half()
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear):
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
print('Converted as Half.')
def model_to_float(model):
model.float()
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear):
m.zeros = m.zeros.float()
m.scales = m.scales.float()
m.bias = m.bias.float()
print('Converted as Float.')
def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
import transformers
import accelerate
from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer
@ -128,6 +200,9 @@ def load_llama_model_4bit_low_ram(config_path, model_path):
model.cuda()
model.seqlen = 2048
if half:
model_to_half(model)
tokenizer = LLaMATokenizer.from_pretrained(config_path)
tokenizer.truncation_side = 'left'