add v2 model support
parent 667e43cb5b
commit bff039de95
@@ -13,7 +13,7 @@ class Finetune4bConfig:
         gradient_checkpointing: bool,
         gradient_checkpointing_ratio: float,
         warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
-        checkpoint: bool, skip: bool
+        checkpoint: bool, skip: bool, groupsize: int
     ):
         """
         Args:
@@ -40,6 +40,7 @@ class Finetune4bConfig:
             logging_steps (int): Logging steps
             checkpoint (bool): Produce checkpoint instead of LoRA
             skip (bool): Don't train model
+            groupsize (int): Group size of V2 model, use -1 to load V1 model
         """
         self.dataset = dataset
         self.ds_type = ds_type
@@ -71,6 +72,7 @@ class Finetune4bConfig:
         self.device_map = "auto" if not self.ddp else {"": self.local_rank}
         if self.ddp:
            self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
+        self.groupsize = groupsize


     def __str__(self) -> str:
@@ -53,6 +53,9 @@ def parse_commandline():
     parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s")
     parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s")

+    # V2 model support
+    parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")
+
     return vars(parser.parse_args())


@@ -81,5 +84,6 @@ def get_config() -> Finetune4bConfig:
         save_total_limit=args["save_total_limit"],
         logging_steps=args["logging_steps"],
         checkpoint=args["checkpoint"],
-        skip=args["skip"]
+        skip=args["skip"],
+        groupsize=args["groupsize"]
     )
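For reference, a minimal standalone sketch of how the new flag behaves (the parser below is illustrative, not the project's full parse_commandline):

# Minimal sketch of the --groupsize flag added above (standalone, parser name is illustrative).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--groupsize", type=int, default=-1,
                    help="Groupsize of v2 model, use -1 to load v1 model")

print(parser.parse_args([]).groupsize)                      # -1 -> v1 model path is taken
print(parser.parse_args(["--groupsize", "128"]).groupsize)  # 128 -> v2 model quantized in 128-column groups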
autograd_4bit.py (200 changed lines)
@@ -1,169 +1,70 @@
-from gptq_llama import quant
+import matmul_utils_4bit as mm4b
 import torch
-import numpy as np
 import torch.nn as nn
 import time
+import math
+from safetensors import safe_open


-# Global Buffer
-buffer_mat_dic = {}
-use_new = True
-auto_switch = True
-auto_switch_thd = 16
-
-
-def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
-    if shape_of_qweight not in buffer_mat_dic.keys():
-        buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
-    elif buffer_mat_dic[shape_of_qweight].device != device:
-        buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
-    return buffer_mat_dic[shape_of_qweight]
-
-
-def matmul4bit(x, qweight, scales, zeros):
-    """
-    input x: (n, m)
-    qweight: (j, k)
-    where m == j*8
-
-    perform x @ qweight
-
-    return y:
-    """
-    assert qweight.shape[0] * 8 == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
-    x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
-    dtype = x.dtype
-    x = x.float()
-    quant.quant_cuda.vecquant4matmul(x, qweight, y, scales, zeros)
-    y = y.to(dtype)
-    return y.reshape(outshape)
-
-
-def matmul4bit_transpose(x, qweight, scales, zeros):
-    """
-    input x: (n, m)
-    qweight: (j, k)
-    where m == k
-
-    perform qweight @ x.T
-
-    return y:
-    """
-    assert qweight.shape[1] == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
-    x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=torch.float32, device=x.device)
-    dtype = x.dtype
-    x = x.float()
-    quant.quant_cuda.vecquant4transposematmul(x, qweight, y, scales, zeros)
-    y = y.to(dtype)
-    return y.reshape(outshape)
-
-
-def matmul4bit_half(x, qweight, scales, zeros):
-    """
-    input x: (n, m)
-    qweight: (j, k)
-    where m == j*8
-
-    perform x @ qweight
-
-    return y:
-    """
-    assert qweight.shape[0] * 8 == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
-    x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=x.dtype, device=x.device)
-    dtype = x.dtype
-    quant.quant_cuda.vecquant4matmul_half(x, qweight, y, scales, zeros)
-    y = y.to(dtype)
-    return y.reshape(outshape)
-
-
-def matmul4bit_transpose_half(x, qweight, scales, zeros):
-    """
-    input x: (n, m)
-    qweight: (j, k)
-    where m == k
-
-    perform qweight @ x.T
-
-    return y:
-    """
-    assert qweight.shape[1] == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
-    x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=x.dtype, device=x.device)
-    dtype = x.dtype
-    quant.quant_cuda.vecquant4transposematmul_half(x, qweight, y, scales, zeros)
-    y = y.to(dtype)
-    return y.reshape(outshape)
-
-
-def fast_4bit_forward(x, qweight, scales, zeros, bias):
-    use_new_flag = use_new
-    if auto_switch:
-        if x.shape[1] > auto_switch_thd:
-            use_new_flag = True
-        else:
-            use_new_flag = False
-    if use_new_flag:
-        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
-        output = torch.matmul(x, buffer)
-    else:
-        output = matmul4bit(x, qweight, scales.float(), zeros.float())
-    output += bias
-    return output
-
-
 class AutogradMatmul4bit(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, x, qweight, scales, zeros):
-        ctx.save_for_backward(qweight, scales, zeros)
-        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
-        output = torch.matmul(x, buffer).clone()
+    def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
+        ctx.save_for_backward(qweight, scales, zeros, groupsize)
+        if groupsize == -1:
+            output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
+        else:
+            output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
+        output = output.clone()
         return output

     @staticmethod
     def backward(ctx, grad_output):
-        qweight, scales, zeros = ctx.saved_tensors
-        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
-        grad = torch.matmul(grad_output, buffer.T)
-        return grad, None, None, None
+        qweight, scales, zeros, groupsize = ctx.saved_tensors
+        if groupsize == -1:
+            grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
+        else:
+            grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True)
+        return grad, None, None, None, None


 # Assumes layer is perfectly divisible into 256 * 256 blocks
 class Autograd4bitQuantLinear(nn.Module):

-    def __init__(self, infeatures, outfeatures):
+    def __init__(self, infeatures, outfeatures, groupsize=-1):
         super().__init__()
         bits = 4
         self.in_features = infeatures
         self.out_features = outfeatures
         self.bits = bits
-        self.register_buffer('zeros', torch.empty((outfeatures, 1)))
-        self.register_buffer('scales', torch.empty((outfeatures, 1)))
+        self.groupsize = groupsize
+        if groupsize == -1:
+            self.register_buffer('zeros', torch.empty((outfeatures, 1)))
+            self.register_buffer('scales', torch.empty((outfeatures, 1)))
+        else:
+            self.register_buffer('qzeros',
+                                 torch.empty((math.ceil(infeatures/groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
+                                 )
+            self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize), outfeatures)))
         self.register_buffer('bias', torch.empty(outfeatures))
         self.register_buffer(
             'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
         )

     def forward(self, x):
         if torch.is_grad_enabled():
-            out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
+            out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
+                                           self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
             out += self.bias
         else:
-            out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
+            out = mm4b.matmul4bit(x, self.qweight, self.scales,
+                                  self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
+            out += self.bias
         return out


-def make_quant_for_4bit_autograd(module, names, name=''):
+def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
     if isinstance(module, Autograd4bitQuantLinear):
         return
     for attr in dir(module):
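A small standalone sketch of the two storage layouts implied by the __init__ change above; the 4096x4096 layer size is illustrative, the shape arithmetic is copied from the diff:

# v1 keeps per-output-channel float zeros/scales; v2 packs zeros into int 'qzeros'
# and stores one scale row per group along the input dimension.
import math

def quant_buffer_shapes(infeatures, outfeatures, groupsize=-1, bits=4):
    shapes = {'qweight': (infeatures // 256 * (bits * 8), outfeatures)}
    if groupsize == -1:                      # v1 layout
        shapes['zeros'] = (outfeatures, 1)
        shapes['scales'] = (outfeatures, 1)
    else:                                    # v2 layout, grouped along the input dim
        groups = math.ceil(infeatures / groupsize)
        shapes['qzeros'] = (groups, outfeatures // 256 * (bits * 8))
        shapes['scales'] = (groups, outfeatures)
    return shapes

print(quant_buffer_shapes(4096, 4096))                  # v1
print(quant_buffer_shapes(4096, 4096, groupsize=128))   # v2, 32 groups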
@@ -171,17 +72,18 @@ def make_quant_for_4bit_autograd(module, names, name=''):
         name1 = name + '.' + attr if name != '' else attr
         if name1 in names:
             setattr(
-                module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features)
+                module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize)
             )
     for name1, child in module.named_children():
-        make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1)
+        make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize)


 def model_to_half(model):
     model.half()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            m.zeros = m.zeros.half()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
     print('Converted as Half.')
@@ -191,34 +93,40 @@ def model_to_float(model):
     model.float()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            m.zeros = m.zeros.float()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.float()
             m.scales = m.scales.float()
             m.bias = m.bias.float()
     print('Converted as Float.')


-def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
-    import transformers
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+    if type(module) in layers:
+        return {name: module}
+    res = {}
+    for name1, child in module.named_children():
+        res.update(find_layers(
+            child, layers=layers, name=name + '.' + name1 if name != '' else name1
+        ))
+    return res
+
+
+def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
-    from gptq_llama.modelutils import find_layers

     print("Loading Model ...")
     t0 = time.time()

     with accelerate.init_empty_weights():
         config = LlamaConfig.from_pretrained(config_path)
-        torch.set_default_dtype(torch.half)
-        transformers.modeling_utils._init_weights = False
-        torch.set_default_dtype(torch.half)
         model = LlamaForCausalLM(config)
-        torch.set_default_dtype(torch.float)
         model = model.eval()
         layers = find_layers(model)
         for name in ['lm_head']:
             if name in layers:
                 del layers[name]
-        make_quant_for_4bit_autograd(model, layers)
+        make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
     model = accelerate.load_checkpoint_and_dispatch(
         model=model,
         checkpoint=model_path,
@@ -226,7 +134,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
         no_split_module_classes=["LlamaDecoderLayer"]
     )

-    model.seqlen = 2048
+    model.seqlen = seqlen

     if half:
         model_to_half(model)
@@ -237,4 +145,4 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_map="auto"):
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")

     return model, tokenizer
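A hedged usage sketch of the updated loader: the config/model paths mirror the ones used in the repository's inference example, while the "128g" safetensors filename is only an assumed example of a v2 checkpoint, not something this commit ships.

# v1 checkpoint (no grouping): keep the default groupsize=-1
from autograd_4bit import load_llama_model_4bit_low_ram

model, tokenizer = load_llama_model_4bit_low_ram('./llama-13b-4bit/', './llama-13b-4bit.pt')

# v2 checkpoint quantized with group size 128: pass the matching groupsize
model, tokenizer = load_llama_model_4bit_low_ram('./llama-13b-4bit/',
                                                 './llama-13b-4bit-128g.safetensors',  # illustrative filename
                                                 groupsize=128, half=True)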
@@ -42,7 +42,10 @@ if ft_config.gradient_checkpointing:
     print('Disable Dropout.')

 # Load Basic Model
-model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model, device_map=ft_config.device_map)
+model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir,
+                                                 ft_config.llama_q4_model,
+                                                 device_map=ft_config.device_map,
+                                                 groupsize=ft_config.groupsize)

 # Config Lora
 lora_config = LoraConfig(
@@ -2,16 +2,18 @@ import os
 import sys
 import time
 import torch
-from autograd_4bit import load_llama_model_4bit_low_ram
+from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
 config_path = './llama-13b-4bit/'
 model_path = './llama-13b-4bit.pt'
 model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)

 print('Fitting 4bit scales and zeros to half')
 for n, m in model.named_modules():
-    if '4bit' in str(type(m)):
-        m.zeros = m.zeros.half()
+    if isinstance(m, Autograd4bitQuantLinear):
+        if m.groupsize == -1:
+            m.zeros = m.zeros.half()
         m.scales = m.scales.half()
+        m.bias = m.bias.half()

 prompt = '''I think the meaning of life is'''
 batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
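A hedged continuation sketch (not part of this commit): once the batch has been tokenized as above, the script would typically hand it to the standard transformers generate API; model, tokenizer and batch come from the script, the sampling parameters are only illustrative.

import torch

batch = {k: v.cuda() for k, v in batch.items()}
with torch.no_grad():
    generated = model.generate(input_ids=batch["input_ids"],
                               do_sample=True, max_new_tokens=50,
                               temperature=0.7, top_p=0.95)
print(tokenizer.decode(generated[0]))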
@@ -0,0 +1,139 @@ (new file)
+import torch
+import numpy as np
+import quant_cuda
+
+
+# Global Buffer
+buffer_mat_dic = {}
+use_new = True
+auto_switch = True
+auto_switch_thd = 8
+debug = False
+
+
+def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
+    if shape_of_qweight not in buffer_mat_dic.keys():
+        buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
+    else:
+        if buffer_mat_dic[shape_of_qweight].device != device:
+            buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(device)
+        if buffer_mat_dic[shape_of_qweight].dtype != dtype:
+            buffer_mat_dic[shape_of_qweight] = buffer_mat_dic[shape_of_qweight].to(dtype=dtype)
+    return buffer_mat_dic[shape_of_qweight]
+
+
+def _matmul4bit_v1(x, qweight, scales, zeros):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == j*8
+
+    perform x @ qweight
+
+    return y:
+    """
+    if debug:
+        print('_matmul4bit_v1')
+    assert qweight.shape[0] * 8 == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
+    dtype = x.dtype
+    x = x.half()
+    quant_cuda.vecquant4matmul_v1_faster(x, qweight, y, scales, zeros)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
+def _matmul4bit_v2(x, qweight, scales, zeros, group_size):
+    """
+    input x: (n, m)
+    qweight: (j, k)
+    where m == j*8
+
+    perform x @ qweight
+
+    return y:
+    """
+    if debug:
+        print('_matmul4bit_v2')
+    assert qweight.shape[0] * 8 == x.shape[-1]
+    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    x = x.reshape(-1, x.shape[-1])
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
+    dtype = x.dtype
+    x = x.half()
+    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, group_size, x.shape[-1] // 2)
+    y = y.to(dtype)
+    return y.reshape(outshape)
+
+
+def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
+    if debug:
+        print('_matmul4bit_v1_recons')
+    if not transpose:
+        assert qweight.shape[0] * 8 == x.shape[-1]
+    else:
+        assert qweight.shape[1] == x.shape[-1]
+    buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+    quant_cuda.vecquant4recons_v1(qweight, buffer, scales, zeros)
+    if not transpose:
+        output = torch.matmul(x, buffer)
+    else:
+        output = torch.matmul(x, buffer.T)
+    return output
+
+
+def _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size, transpose=False):
+    if debug:
+        print('_matmul4bit_v2_recons')
+    if not transpose:
+        assert qweight.shape[0] * 8 == x.shape[-1]
+    else:
+        assert qweight.shape[1] == x.shape[-1]
+    buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, group_size)
+    if not transpose:
+        output = torch.matmul(x, buffer)
+    if transpose:
+        output = torch.matmul(x, buffer.T)
+    return output
+
+
+def matmul4bit(x, qweight, scales, zeros, group_size=-1):
+    if group_size == -1:
+        # use v1
+        if use_new:
+            if auto_switch:
+                if np.prod(x.shape[:-1]) > auto_switch_thd:
+                    output = _matmul4bit_v1_recons(x, qweight, scales, zeros)
+                else:
+                    output = _matmul4bit_v1(x, qweight, scales, zeros)
+        else:
+            output = _matmul4bit_v1(x, qweight, scales, zeros)
+    else:
+        # use v2
+        if use_new:
+            if auto_switch:
+                if np.prod(x.shape[:-1]) > auto_switch_thd:
+                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, group_size)
+                else:
+                    output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+        else:
+            output = _matmul4bit_v2(x, qweight, scales, zeros, group_size)
+    return output
+
+
+def v2_to_v1(scales, zeros):
+    """
+    Convert zeros in V2 model to V1 model when group_num = 1, for debugging
+    """
+    assert zeros.shape[0] == 1
+    z_mat = torch.zeros((zeros.shape[1], 256), dtype=torch.int, device=zeros.device) + zeros.reshape((-1,1))
+    z_buffer = torch.zeros((z_mat.shape[0] * 8, z_mat.shape[1]), dtype=torch.float16, device=zeros.device)
+    z_zeros = torch.zeros(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
+    z_scales = torch.ones(z_mat.shape[1], dtype=torch.float16, device=zeros.device)
+    quant_cuda.vecquant4recons_v1(z_mat, z_buffer, z_scales, z_zeros)
+    z_buffer = z_buffer[:,0]
+    zeros_recons = z_buffer * scales + scales
+    return zeros_recons
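This new-file hunk appears to be the matmul_utils_4bit module that autograd_4bit.py now imports as mm4b (the scraped page dropped the filename). A standalone sketch of its kernel-selection rule follows, with no CUDA dependency; the globals and threshold are copied from the code above, the shapes are illustrative, and the branching is slightly simplified relative to the module's nested ifs:

import numpy as np

use_new = True
auto_switch = True
auto_switch_thd = 8

def choose_kernel(x_shape, group_size=-1):
    version = 'v1' if group_size == -1 else 'v2'
    if use_new and auto_switch and np.prod(x_shape[:-1]) > auto_switch_thd:
        return f'_matmul4bit_{version}_recons'   # dequantize into a buffer, then torch.matmul
    return f'_matmul4bit_{version}'              # fused 4-bit matvec kernel for small batches

print(choose_kernel((1, 1, 4096)))                    # single token -> fused v1 kernel
print(choose_kernel((1, 512, 4096), group_size=128))  # long sequence -> v2 reconstruction path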
@@ -14,7 +14,7 @@ def load_model_llama(*args, **kwargs):
     print("Loading {} ...".format(model_path))
     t0 = time.time()

-    model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
+    model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1)

     model = PeftModel.from_pretrained(model, lora_path, device_map={'': 0}, torch_dtype=torch.float32)
     print('{} Lora Applied.'.format(lora_path))
@@ -22,7 +22,8 @@ def load_model_llama(*args, **kwargs):
     print('Apply auto switch and half')
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
-            m.zeros = m.zeros.half()
+            if m.groupsize == -1:
+                m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
     autograd_4bit.use_new = True