add half support

parent 5b64833390
commit dd0d5a31f7

@@ -2,6 +2,7 @@ import quant
 import torch
 import numpy as np
 import torch.nn as nn
+import time
 
 
 def matmul4bit(x, qweight, scales, zeros):
@@ -46,24 +47,75 @@ def matmul4bit_transpose(x, qweight, scales, zeros):
     return y.reshape(outshape)
 
 
+def matmul4bit_half(x, qweight, scales, zeros):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == j*8
+
+    perform x @ qweight
+
+    return y: (n, k)
+    """
+    assert qweight.shape[0] * 8 == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=x.dtype, device=x.device)
+    dtype = x.dtype
+    quant.quant_cuda.vecquant4matmul_half(x, qweight, y, scales, zeros)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
+def matmul4bit_transpose_half(x, qweight, scales, zeros):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == k
+
+    perform qweight @ x.T
+
+    return y: (n, j*8)
+    """
+    assert qweight.shape[1] == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=x.dtype, device=x.device)
+    dtype = x.dtype
+    quant.quant_cuda.vecquant4transposematmul_half(x, qweight, y, scales, zeros)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
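
For orientation, a minimal usage sketch of the new fp16 path (not part of the commit; the shapes follow the docstrings, and the per-output-column layout of scales and zeros is an assumption based on the packing of eight 4-bit values per int32):

    import torch
    import quant  # compiled quant_cuda extension, as imported at the top of this file

    n, m, k = 4, 4096, 4096  # batch, in_features, out_features
    x = torch.randn(n, m, dtype=torch.float16, device='cuda')
    # packed weight: each int32 holds eight 4-bit values, so qweight has m // 8 rows
    qweight = torch.randint(0, 2**31 - 1, (m // 8, k), dtype=torch.int32, device='cuda')
    scales = torch.rand(k, dtype=torch.float16, device='cuda')  # assumed per-output-column
    zeros = torch.rand(k, dtype=torch.float16, device='cuda')   # assumed per-output-column

    y = matmul4bit_half(x, qweight, scales, zeros)
    assert y.shape == (n, k) and y.dtype == torch.float16
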
 class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros):
         ctx.save_for_backward(qweight, scales, zeros)
-        output = matmul4bit(x, qweight, scales, zeros).clone()
-        return output # equals to torch.matmul(x, qweight)
+        # equals to torch.matmul(x, qweight)
+        if x.dtype == torch.float32:
+            output = matmul4bit(x, qweight, scales, zeros).clone()
+        elif x.dtype == torch.float16:
+            output = matmul4bit_half(x, qweight, scales, zeros).clone()
+        else:
+            raise ValueError('Only float and half are supported.')
+
+        return output
 
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros = ctx.saved_tensors
-        # print(grad_output.shape, A.shape, B.shape)
         # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
-        grad1 = matmul4bit_transpose(grad_output, qweight, scales, zeros)
-        # grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
-        return grad1, None, None, None
+        if grad_output.dtype == torch.float32:
+            grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
+        elif grad_output.dtype == torch.float16:
+            grad = matmul4bit_transpose_half(grad_output, qweight, scales, zeros)
+        else:
+            raise ValueError('Only float and half are supported.')
+
+        return grad, None, None, None
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
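
The backward rule here uses the identity grad_x = grad_output @ W.T = (W @ grad_output.T).T, which is exactly what the transpose kernels compute, so only x receives a gradient and the frozen quantization tensors get None. A hedged end-to-end sketch, reusing the hypothetical tensors from the sketch above:

    x = torch.randn(4, 4096, dtype=torch.float16, device='cuda', requires_grad=True)
    y = AutogradMatmul4bit.apply(x, qweight, scales, zeros)  # fp16 input -> matmul4bit_half
    y.sum().backward()  # fp16 grad_output -> matmul4bit_transpose_half
    assert x.grad.shape == x.shape  # qweight, scales, zeros receive no gradient
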
@@ -102,7 +154,27 @@ def make_quant_for_4bit_autograd(module, names, name=''):
         make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1)
 
 
-def load_llama_model_4bit_low_ram(config_path, model_path):
+def model_to_half(model):
+    model.half()
+    for n, m in model.named_modules():
+        if isinstance(m, Autograd4bitQuantLinear):
+            m.zeros = m.zeros.half()
+            m.scales = m.scales.half()
+            m.bias = m.bias.half()
+    print('Converted as Half.')
+
+
+def model_to_float(model):
+    model.float()
+    for n, m in model.named_modules():
+        if isinstance(m, Autograd4bitQuantLinear):
+            m.zeros = m.zeros.float()
+            m.scales = m.scales.float()
+            m.bias = m.bias.float()
+    print('Converted as Float.')
+
+
+def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
     import transformers
     import accelerate
     from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer
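
A note on the new conversion helpers: Module.half() casts parameters and registered floating-point buffers, so the explicit per-module loop suggests zeros, scales and bias live on Autograd4bitQuantLinear as plain tensor attributes that half() would not touch. A hypothetical round trip (assuming the loader returns the model and tokenizer, as the surrounding code implies):

    config_path = 'path/to/llama-config'  # placeholder
    model_path = 'path/to/llama-4bit.pt'  # placeholder
    model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
    model_to_half(model)   # fp16 weights plus fp16 quantization tensors
    # ... run fp16 inference or training ...
    model_to_float(model)  # back to fp32 if the half kernels are not wanted
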
@@ -128,6 +200,9 @@ def load_llama_model_4bit_low_ram(config_path, model_path):
     model.cuda()
     model.seqlen = 2048
 
+    if half:
+        model_to_half(model)
+
     tokenizer = LLaMATokenizer.from_pretrained(config_path)
     tokenizer.truncation_side = 'left'
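
With the new flag the same conversion can also happen inside the loader; a minimal sketch (paths are placeholders):

    model, tokenizer = load_llama_model_4bit_low_ram(
        'path/to/llama-config',   # placeholder config_path
        'path/to/llama-4bit.pt',  # placeholder model_path
        half=True,                # new in this commit: applies model_to_half(model)
    )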