add v2 model support

John Smith 2023-03-28 20:33:55 +08:00
parent 667e43cb5b
commit bff039de95
7 changed files with 213 additions and 154 deletions

View File

@@ -13,7 +13,7 @@ class Finetune4bConfig:
gradient_checkpointing: bool,
gradient_checkpointing_ratio: float,
warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
checkpoint: bool, skip: bool
checkpoint: bool, skip: bool, groupsize: int
):
"""
Args:
@@ -40,6 +40,7 @@ class Finetune4bConfig:
logging_steps (int): Logging steps
checkpoint (bool): Produce checkpoint instead of LoRA
skip (bool): Don't train model
groupsize (int): Group size of the V2 model; use -1 to load a V1 model
"""
self.dataset = dataset
self.ds_type = ds_type
@@ -71,6 +72,7 @@ class Finetune4bConfig:
self.device_map = "auto" if not self.ddp else {"": self.local_rank}
if self.ddp:
self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
self.groupsize = groupsize
def __str__(self) -> str:

View File

@@ -53,6 +53,9 @@ def parse_commandline():
parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s")
parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s")
# V2 model support
parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")
return vars(parser.parse_args())
@@ -81,5 +84,6 @@ def get_config() -> Finetune4bConfig:
save_total_limit=args["save_total_limit"],
logging_steps=args["logging_steps"],
checkpoint=args["checkpoint"],
skip=args["skip"]
skip=args["skip"],
groupsize=args["groupsize"]
)
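For reference, a minimal, self-contained sketch of how the new flag flows through argument parsing. Only the new option is shown, and the invocation values are illustrative; the real parser defines many more arguments.

import argparse

# Hypothetical standalone check of the new option.
parser = argparse.ArgumentParser()
parser.add_argument("--groupsize", type=int, default=-1,
                    help="Group size of the V2 model; use -1 to load a V1 model.")

args = vars(parser.parse_args(["--groupsize", "128"]))
assert args["groupsize"] == 128                          # V2 checkpoint quantized with group size 128
assert vars(parser.parse_args([]))["groupsize"] == -1    # default keeps the V1 code path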

View File

@@ -1,169 +1,70 @@
from gptq_llama import quant
import matmul_utils_4bit as mm4b
import torch
import numpy as np
import torch.nn as nn
import time
import math
from safetensors import safe_open
# Global Buffer
buffer_mat_dic = {}
use_new = True
auto_switch = True
auto_switch_thd = 16
def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
if shape_of_qweight not in buffer_mat_dic.keys():
buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
elif buffer_mat_dic[shape_of_qweight].device != device:
buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
return buffer_mat_dic[shape_of_qweight]
def matmul4bit(x, qweight, scales, zeros):
"""
input x: (n, m)
qweight: (j, k)
where m == j*8
perform x @ qweight
return y:
"""
assert qweight.shape[0] * 8 == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
dtype = x.dtype
x = x.float()
quant.quant_cuda.vecquant4matmul(x, qweight, y, scales, zeros)
y = y.to(dtype)
return y.reshape(outshape)
def matmul4bit_transpose(x, qweight, scales, zeros):
"""
input x: (n, m)
qweight: (j, k)
where m == k
perform qweight @ x.T
return y:
"""
assert qweight.shape[1] == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=torch.float32, device=x.device)
dtype = x.dtype
x = x.float()
quant.quant_cuda.vecquant4transposematmul(x, qweight, y, scales, zeros)
y = y.to(dtype)
return y.reshape(outshape)
def matmul4bit_half(x, qweight, scales, zeros):
"""
input x: (n, m)
qweight: (j, k)
where m == j*8
perform x @ qweight
return y:
"""
assert qweight.shape[0] * 8 == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=x.dtype, device=x.device)
dtype = x.dtype
quant.quant_cuda.vecquant4matmul_half(x, qweight, y, scales, zeros)
y = y.to(dtype)
return y.reshape(outshape)
def matmul4bit_transpose_half(x, qweight, scales, zeros):
"""
input x: (n, m)
qweight: (j, k)
where m == k
perform qweight @ x.T
return y:
"""
assert qweight.shape[1] == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=x.dtype, device=x.device)
dtype = x.dtype
quant.quant_cuda.vecquant4transposematmul_half(x, qweight, y, scales, zeros)
y = y.to(dtype)
return y.reshape(outshape)
def fast_4bit_forward(x, qweight, scales, zeros, bias):
use_new_flag = use_new
if auto_switch:
if x.shape[1] > auto_switch_thd:
use_new_flag = True
else:
use_new_flag = False
if use_new_flag:
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
output = torch.matmul(x, buffer)
else:
output = matmul4bit(x, qweight, scales.float(), zeros.float())
output += bias
return output
class AutogradMatmul4bit(torch.autograd.Function):
@staticmethod
def forward(ctx, x, qweight, scales, zeros):
ctx.save_for_backward(qweight, scales, zeros)
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
output = torch.matmul(x, buffer).clone()
def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
ctx.save_for_backward(qweight, scales, zeros)
ctx.groupsize = groupsize  # save_for_backward only accepts tensors, so keep the int directly on ctx
if groupsize == -1:
output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
else:
output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
output = output.clone()
return output
@staticmethod
def backward(ctx, grad_output):
qweight, scales, zeros = ctx.saved_tensors
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
grad = torch.matmul(grad_output, buffer.T)
return grad, None, None, None
qweight, scales, zeros = ctx.saved_tensors
groupsize = ctx.groupsize
if groupsize == -1:
grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
else:
grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True)
return grad, None, None, None, None
# Assumes layer is perfectly divisible into 256 * 256 blocks
class Autograd4bitQuantLinear(nn.Module):
def __init__(self, infeatures, outfeatures):
def __init__(self, infeatures, outfeatures, groupsize=-1):
super().__init__()
bits = 4
self.in_features = infeatures
self.out_features = outfeatures
self.bits = bits
self.register_buffer('zeros', torch.empty((outfeatures, 1)))
self.register_buffer('scales', torch.empty((outfeatures, 1)))
self.groupsize = groupsize
if groupsize == -1:
self.register_buffer('zeros', torch.empty((outfeatures, 1)))
self.register_buffer('scales', torch.empty((outfeatures, 1)))
else:
self.register_buffer('qzeros',
torch.empty((math.ceil(infeatures/groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
)
self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize),outfeatures)))
self.register_buffer('bias', torch.empty(outfeatures))
self.register_buffer(
'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
)
def forward(self, x):
if torch.is_grad_enabled():
out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
out += self.bias
else:
out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
out = mm4b.matmul4bit(x, self.qweight, self.scales,
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
out += self.bias
return out
def make_quant_for_4bit_autograd(module, names, name=''):
def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
if isinstance(module, Autograd4bitQuantLinear):
return
for attr in dir(module):
@@ -171,17 +72,18 @@ def make_quant_for_4bit_autograd(module, names, name=''):
name1 = name + '.' + attr if name != '' else attr
if name1 in names:
setattr(
module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features)
module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize)
)
for name1, child in module.named_children():
make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1)
make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize)
def model_to_half(model):
model.half()
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear):
m.zeros = m.zeros.half()
if m.groupsize == -1:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
print('Converted as Half.')
@@ -191,34 +93,40 @@ def model_to_float(model):
model.float()
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear):
m.zeros = m.zeros.float()
if m.groupsize == -1:
m.zeros = m.zeros.float()
m.scales = m.scales.float()
m.bias = m.bias.float()
print('Converted as Float.')
def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
import transformers
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(
child, layers=layers, name=name + '.' + name1 if name != '' else name1
))
return res
def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
import accelerate
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from gptq_llama.modelutils import find_layers
print("Loading Model ...")
t0 = time.time()
with accelerate.init_empty_weights():
config = LlamaConfig.from_pretrained(config_path)
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = LlamaForCausalLM(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
for name in ['lm_head']:
if name in layers:
del layers[name]
make_quant_for_4bit_autograd(model, layers)
make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
model = accelerate.load_checkpoint_and_dispatch(
model=model,
checkpoint=model_path,
@@ -226,7 +134,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_ma
no_split_module_classes=["LlamaDecoderLayer"]
)
model.seqlen = 2048
model.seqlen = seqlen
if half:
model_to_half(model)
@@ -237,4 +145,4 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_ma
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer
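To make the two buffer layouts concrete, here is a small shape check of Autograd4bitQuantLinear in V1 and V2 mode. The feature sizes are chosen only for illustration, and it assumes the script is run from the repo root so autograd_4bit is importable.

from autograd_4bit import Autograd4bitQuantLinear

# V1 layout: one scale and one zero per output channel
v1 = Autograd4bitQuantLinear(4096, 4096)                    # groupsize defaults to -1
print(v1.qweight.shape, v1.scales.shape, v1.zeros.shape)    # (512, 4096) (4096, 1) (4096, 1)

# V2 layout: per-group scales, zeros packed into int32 qzeros
v2 = Autograd4bitQuantLinear(4096, 4096, groupsize=128)
print(v2.qweight.shape, v2.scales.shape, v2.qzeros.shape)   # (512, 4096) (32, 4096) (32, 512)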

View File

@@ -42,7 +42,10 @@ if ft_config.gradient_checkpointing:
print('Disable Dropout.')
# Load Basic Model
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir,
ft_config.llama_q4_model,
device_map=ft_config.device_map,
groupsize=ft_config.groupsize)
# Config Lora
lora_config = LoraConfig(

View File

@@ -2,16 +2,18 @@ import os
import sys
import time
import torch
from autograd_4bit import load_llama_model_4bit_low_ram
from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
config_path = './llama-13b-4bit/'
model_path = './llama-13b-4bit.pt'
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
print('Fitting 4bit scales and zeros to half')
for n, m in model.named_modules():
if '4bit' in str(type(m)):
m.zeros = m.zeros.half()
if isinstance(m, Autograd4bitQuantLinear):
if m.groupsize == -1:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
prompt = '''I think the meaning of life is'''
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
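In passing, autograd_4bit also exposes model_to_half, which performs the same per-layer conversion as the loop above (plus a blanket model.half() and a print), so the script could arguably be shortened to:

from autograd_4bit import model_to_half

# Roughly equivalent to the explicit loop above: halves scales/bias on every
# Autograd4bitQuantLinear, and zeros only for groupsize == -1 (V1) layers.
model_to_half(model)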

matmul_utils_4bit.py (new file, 139 lines)
View File

@@ -0,0 +1,139 @@
import torch
import numpy as np
import quant_cuda
# Global Buffer
buffer_mat_dic = {}
use_new = True
auto_switch = True
auto_switch_thd = 8
debug = False
def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
if shape_of_qweight not in buffer_mat_dic.keys():
buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
else:
if buffer_mat_dic[shape_of_qweight].device != device:
buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
if buffer_mat_dic[shape_of_qweight].dtype != dtype:
buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(dtype=dtype)
return buffer_mat_dic[shape_of_qweight]
def _matmul4bit_v1(x, qweight, scales, zeros):
"""
input x: (n, m)
qweight: (j, k)
where m == j*8
perform x @ qweight
return y: (n, k)
"""
if debug:
print('_matmul4bit_v1')
assert qweight.shape[0] * 8 == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
dtype = x.dtype
x = x.half()
quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
y = y.to(dtype)
return y.reshape(outshape)
def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
"""
input x: (n, m)
qweight: (j, k)
where m == j*8
perform x @ qweight
return y: (n, k)
"""
if debug:
print('_matmul4bit_v2')
assert qweight.shape[0] * 8 == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
x = x.reshape(-1, x.shape[-1])
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
dtype = x.dtype
x = x.half()
quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, group_size, x.shape[-1] // 2)
y = y.to(dtype)
return y.reshape(outshape)
def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
if debug:
print('_matmul4bit_v1_recons')
if not transpose:
assert qweight.shape[0] * 8 == x.shape[-1]
else:
assert qweight.shape[1] == x.shape[-1]
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
quant_cuda.vecquant4recons_v1(qweight, buffer, scales, zeros)
if not transpose:
output = torch.matmul(x, buffer)
else:
output = torch.matmul(x, buffer.T)
return output
def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False):
if debug:
print('_matmul4bit_v2_recons')
if not transpose:
assert qweight.shape[0] * 8 == x.shape[-1]
else:
assert qweight.shape[1] == x.shape[-1]
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, group_size)
if not transpose:
output = torch.matmul(x, buffer)
else:
output = torch.matmul(x, buffer.T)
return output
def matmul4bit(x, qweight, scales, zeros, group_size=-1):
if group_size == -1:
# use v1
if use_new:
if auto_switch:
if np.prod(x.shape[:-1]) > auto_switch_thd:
output = _matmul4bit_v1_recons(x, qweight, scales, zeros)
else:
output = _matmul4bit_v1(x, qweight, scales, zeros)
else:
output = _matmul4bit_v1(x, qweight, scales, zeros)
else:
# use v2
if use_new:
if auto_switch:
if np.prod(x.shape[:-1]) > auto_switch_thd:
output = _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size)
else:
output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
else:
output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
return output
def v2_to_v1(scales, zeros):
"""
Convert zeros in V2 model to V1 model when group_num = 1, for debugging
"""
assert zeros.shape[0] == 1
z_mat = torch.zeros((zeros.shape[1], 256), dtype=torch.int, device=zeros.device) + zeros.reshape((-1,1))
z_buffer = torch.zeros((z_mat.shape[0] * 8, z_mat.shape[1]), dtype=torch.float16, device=zeros.device)
z_zeros = torch.zeros(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
z_scales = torch.ones(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
quant_cuda.vecquant4recons_v1(z_mat, z_buffer, z_scales, z_zeros)
z_buffer = z_buffer[:,0]
zeros_recons = z_buffer * scales + scales
return zeros_recons
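As a rough mental model of what the v2 reconstruction kernel computes, here is a pure-PyTorch sketch of grouped 4-bit dequantization. It assumes the common GPTQ packing (8 nibbles per int32 along the input dimension for qweight and along the output dimension for qzeros); the exact layout used by the CUDA kernels may differ, so treat this as an illustration rather than a drop-in replacement.

import torch

def dequant4bit_v2_reference(qweight, qzeros, scales, groupsize):
    # Assumed shapes (standard GPTQ packing):
    #   qweight: (in_features // 8, out_features)  int32
    #   qzeros:  (n_groups, out_features // 8)     int32
    #   scales:  (n_groups, out_features)          float16
    shifts = torch.arange(0, 32, 4, device=qweight.device)            # 8 x 4-bit nibbles per int32
    w = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF         # (in/8, 8, out)
    w = w.reshape(-1, qweight.shape[1]).to(scales.dtype)              # (in, out) integer codes
    z = (qzeros.unsqueeze(-1) >> shifts.view(1, 1, -1)) & 0xF         # (groups, out/8, 8)
    z = z.reshape(qzeros.shape[0], -1).to(scales.dtype)               # (groups, out)
    g = torch.arange(w.shape[0], device=qweight.device) // groupsize  # group index of each input row
    # The "+ 1" mirrors the offset visible in v2_to_v1() above; some packers omit it.
    return (w - (z[g] + 1)) * scales[g]                               # (in, out) dequantized weight

Under these assumptions, x @ dequant4bit_v2_reference(qweight, qzeros, scales, 128) plays the role of _matmul4bit_v2_recons(x, qweight, scales, qzeros, 128) in the dispatch above.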

View File

@@ -14,7 +14,7 @@ def load_model_llama(*args, **kwargs):
print("Loading {} ...".format(model_path))
t0 = time.time()
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1)
model = PeftModel.from_pretrained(model, lora_path, device_map={'': 0}, torch_dtype=torch.float32)
print('{} Lora Applied.'.format(lora_path))
@@ -22,7 +22,8 @@ def load_model_llama(*args, **kwargs):
print('Apply auto switch and half')
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
m.zeros = m.zeros.half()
if m.groupsize == -1:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
autograd_4bit.use_new = True
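For context on the auto-switch logic referenced by this global: matmul_utils_4bit uses the fused vecquant kernel for small inputs and weight reconstruction plus torch.matmul for larger ones. A rough illustration of the switch, using the threshold value from the new file:

import numpy as np

auto_switch_thd = 8   # value set in matmul_utils_4bit.py above

def picks_recons_path(x_shape):
    # True  -> dequantize into an fp16 buffer and use torch.matmul (..._recons path)
    # False -> fused vecquant4matmul kernel on the quantized weight
    return np.prod(x_shape[:-1]) > auto_switch_thd

print(picks_recons_path((1, 1, 4096)))    # False: single-token decode stays on the fused kernel
print(picks_recons_path((1, 512, 4096)))  # True: prompt processing / training reconstructs the weight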