add half support
parent 5b64833390
commit dd0d5a31f7
@@ -2,6 +2,7 @@ import quant
 import torch
 import numpy as np
 import torch.nn as nn
+import time
 
 
 def matmul4bit(x, qweight, scales, zeros):
@@ -46,24 +47,75 @@ def matmul4bit_transpose(x, qweight, scales, zeros):
     return y.reshape(outshape)
 
 
+def matmul4bit_half(x, qweight, scales, zeros):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == j*8
+
+    perform x @ qweight
+
+    return y:
+    """
+    assert qweight.shape[0] * 8 == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=x.dtype, device=x.device)
+    dtype = x.dtype
+    quant.quant_cuda.vecquant4matmul_half(x, qweight, y, scales, zeros)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
+def matmul4bit_transpose_half(x, qweight, scales, zeros):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == k
+
+    perform qweight @ x.T
+
+    return y:
+    """
+    assert qweight.shape[1] == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=x.dtype, device=x.device)
+    dtype = x.dtype
+    quant.quant_cuda.vecquant4transposematmul_half(x, qweight, y, scales, zeros)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
 class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros):
         ctx.save_for_backward(qweight, scales, zeros)
-        output = matmul4bit(x, qweight, scales, zeros).clone()
-        return output # equals to torch.matmul(x, qweight)
+
+        # equals to torch.matmul(x, qweight)
+        if x.dtype == torch.float32:
+            output = matmul4bit(x, qweight, scales, zeros).clone()
+        elif x.dtype == torch.float16:
+            output = matmul4bit_half(x, qweight, scales, zeros).clone()
+        else:
+            raise ValueError('Only float and half are supported.')
+
+        return output
 
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros = ctx.saved_tensors
         # print(grad_output.shape, A.shape, B.shape)
 
         # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
-        grad1 = matmul4bit_transpose(grad_output, qweight, scales, zeros)
         # grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
+        if grad_output.dtype == torch.float32:
+            grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
+        elif grad_output.dtype == torch.float16:
+            grad = matmul4bit_transpose_half(grad_output, qweight, scales, zeros)
+        else:
+            raise ValueError('Only float and half are supported.')
 
-        return grad1, None, None, None
+        return grad, None, None, None
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
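
For context, a minimal sketch of how the new fp16 path might be exercised, assuming the compiled quant_cuda extension from this repo is built, a CUDA device is available, and the file is importable as autograd_4bit (module name assumed). The shapes, dummy qweight, and scales/zeros layout below are illustrative only:

import torch
from autograd_4bit import AutogradMatmul4bit  # module name assumed for the file in this diff

# Illustrative shapes: hidden size 4096, so qweight packs 8 int4 values per int32 row (512 * 8 == 4096).
x = torch.randn(1, 4096, dtype=torch.float16, device='cuda')
qweight = torch.zeros((512, 4096), dtype=torch.int32, device='cuda')  # packed 4-bit weights
scales = torch.ones(4096, dtype=torch.float16, device='cuda')
zeros = torch.zeros(4096, dtype=torch.float16, device='cuda')

# Dispatch happens on x.dtype inside AutogradMatmul4bit.forward:
# float32 -> matmul4bit, float16 -> matmul4bit_half.
y = AutogradMatmul4bit.apply(x, qweight, scales, zeros)
print(y.shape)  # (1, 4096)
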
@@ -102,7 +154,27 @@ def make_quant_for_4bit_autograd(module, names, name=''):
         make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1)
 
 
-def load_llama_model_4bit_low_ram(config_path, model_path):
+def model_to_half(model):
+    model.half()
+    for n, m in model.named_modules():
+        if isinstance(m, Autograd4bitQuantLinear):
+            m.zeros = m.zeros.half()
+            m.scales = m.scales.half()
+            m.bias = m.bias.half()
+    print('Converted as Half.')
+
+
+def model_to_float(model):
+    model.float()
+    for n, m in model.named_modules():
+        if isinstance(m, Autograd4bitQuantLinear):
+            m.zeros = m.zeros.float()
+            m.scales = m.scales.float()
+            m.bias = m.bias.float()
+    print('Converted as Float.')
+
+
+def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
    import transformers
    import accelerate
    from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer
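
A hedged sketch of toggling an already-loaded model between precisions with the new module-level helpers. It assumes model came from load_llama_model_4bit_low_ram and that the file is importable as autograd_4bit (module name assumed):

from autograd_4bit import model_to_half, model_to_float  # module name assumed

# The dtype dispatch in AutogradMatmul4bit keys on the activation dtype, so
# converting the whole model is what makes x arrive as float16 at each layer.
model_to_half(model)   # model.half() plus explicit casts of each layer's zeros/scales/bias

# ... run fp16 inference or fine-tuning here ...

model_to_float(model)  # cast back to fp32; the packed int32 qweight is untouched either way
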
@@ -127,6 +199,9 @@ def load_llama_model_4bit_low_ram(config_path, model_path):
     model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
     model.cuda()
     model.seqlen = 2048
+
+    if half:
+        model_to_half(model)
 
     tokenizer = LLaMATokenizer.from_pretrained(config_path)
     tokenizer.truncation_side = 'left'
@@ -134,4 +209,4 @@ def load_llama_model_4bit_low_ram(config_path, model_path):
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
 
     return model, tokenizer
 
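
Finally, a hedged usage sketch of the new half flag, assuming the file is importable as autograd_4bit, a 4-bit checkpoint and a LLaMA config/tokenizer directory exist locally (paths below are placeholders), and the old LLaMA* class names from the transformers fork of that era are in use:

import torch
from autograd_4bit import load_llama_model_4bit_low_ram  # module name assumed

config_path = './llama-7b-hf'      # placeholder: directory with the HF config and tokenizer
model_path = './llama-7b-4bit.pt'  # placeholder: path to the 4-bit checkpoint

# half=True routes through model_to_half after loading, so activations and the
# per-layer scales/zeros are fp16 and AutogradMatmul4bit takes the *_half kernels.
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, half=True)

prompt = tokenizer('Hello, my name is', return_tensors='pt')
with torch.no_grad():
    out = model.generate(input_ids=prompt['input_ids'].cuda(), max_new_tokens=20)
print(tokenizer.decode(out[0]))
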