add g_idx support on cuda backend
parent b73f4e5e64
commit 8cf3bd4086

@@ -14,7 +14,7 @@ class Finetune4bConfig:
         gradient_checkpointing_ratio: float,
         warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
         checkpoint: bool, skip: bool, verbose: bool,
-        txt_row_thd: int, use_eos_token: bool, groupsize: int,
+        txt_row_thd: int, use_eos_token: bool, groupsize: int, v1: bool,
         local_rank: int, flash_attention: bool, backend: str
     ):
         """
@@ -46,7 +46,8 @@ class Finetune4bConfig:
            verbose (bool): If output log of training
            txt_row_thd (int): Custom row thd for txt file
            use_eos_token (bool): Use Eos token instead of padding with 0
-           groupsize (int): Group size of V2 model, use -1 to load V1 model
+           groupsize (int): Group size of V2 model
+           v1 (bool): v1 model flag
            local_rank (int): local rank if using torch.distributed.launch
            flash_attention (bool): Enables flash attention
         """
@@ -85,6 +86,7 @@ class Finetune4bConfig:
         if self.ddp:
             self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
         self.groupsize = groupsize
+        self.v1 = v1
         self.flash_attention = flash_attention
         self.backend = backend

@@ -99,5 +101,5 @@ class Finetune4bConfig:
            f"{self.logging_steps=}\n" +\
            f"{self.checkpoint=}\n{self.skip=}\n" +\
            f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}\n" +\
-           f"{self.groupsize=}\n{self.backend=}\n"
+           f"{self.groupsize=}\n{self.v1=}\n{self.backend=}\n"
        return s.replace("self.", "")
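The config dump in the last hunk relies on Python's f-string "=" specifier plus a replace("self.", "") pass. A small, self-contained illustration (not from the commit) of how the new v1 field will show up in that printout:

class _DemoConfig:
    def __init__(self):
        self.groupsize, self.v1, self.backend = 128, False, "cuda"

    def __str__(self):
        # Same pattern as the hunk above: "{self.x=}" expands to "self.x=<repr>".
        s = f"{self.groupsize=}\n{self.v1=}\n{self.backend=}\n"
        return s.replace("self.", "")

print(_DemoConfig())
# groupsize=128
# v1=False
# backend='cuda'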
@@ -62,7 +62,8 @@ def parse_commandline():
     parser_training.add_argument("--use_eos_token", default=1, type=int, help="Use eos token instead if padding with 0. enable with 1, disable with 0.")

     # V2 model support
-    parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")
+    parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model")
+    parser_training.add_argument("--v1", action="store_true", help="Use V1 model")

     # Multi GPU Support
     parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch")
@@ -107,6 +108,7 @@ def get_config() -> Finetune4bConfig:
        txt_row_thd=args["txt_row_thd"],
        use_eos_token=args["use_eos_token"]!=0,
        groupsize=args["groupsize"],
+       v1=args["v1"],
        local_rank=args["local_rank"],
        flash_attention=args["flash_attention"],
        backend=args["backend"],
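Together these two hunks stop overloading --groupsize as the v1 switch and thread a dedicated --v1 flag into get_config(). A minimal, self-contained sketch (only the two options touched here) of how they behave:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model")
parser.add_argument("--v1", action="store_true", help="Use V1 model")

print(vars(parser.parse_args(["--groupsize", "128"])))  # {'groupsize': 128, 'v1': False} -> v2 checkpoint
print(vars(parser.parse_args(["--v1"])))                # {'groupsize': -1, 'v1': True}   -> v1 checkpoint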
@@ -12,27 +12,25 @@ class AutogradMatmul4bitCuda(torch.autograd.Function):

     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
-    def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq, groupsize=-1):
-        ctx.save_for_backward(qweight, scales, zeros)
-        ctx.groupsize = groupsize
-        if groupsize == -1:
+    def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq):
+        ctx.save_for_backward(qweight, scales, zeros, g_idx)
+        if g_idx is None:
             output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
         else:
-            output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
+            output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, g_idx)
         output = output.clone()
         return output

     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
-        qweight, scales, zeros = ctx.saved_tensors
-        groupsize = ctx.groupsize
+        qweight, scales, zeros, g_idx = ctx.saved_tensors
         if ctx.needs_input_grad[0]:
-            if groupsize == -1:
+            if g_idx is None:
                 grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
             else:
-                grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True)
-        return grad, None, None, None, None, None, None, None
+                grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, g_idx, transpose=True)
+        return grad, None, None, None, None, None, None


 try:
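The CUDA autograd path now saves g_idx with save_for_backward and keys the v1/v2 split on whether g_idx is None, instead of stashing a groupsize on ctx. A toy sketch of the same Function shape (a plain fp16 matmul stands in for the quant_cuda reconstruction kernels, so everything beyond the structure itself is an assumption):

import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class ToyMatmul(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x, weight, g_idx):
        # g_idx rides along in save_for_backward; a None g_idx marks a v1-style layer.
        ctx.save_for_backward(weight, g_idx)
        return x @ weight

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        weight, g_idx = ctx.saved_tensors
        grad = grad_output @ weight.t() if ctx.needs_input_grad[0] else None
        # One return value per forward() argument: x, weight, g_idx.
        return grad, None, None

x = torch.randn(2, 4, requires_grad=True)
w = torch.randn(4, 3)
ToyMatmul.apply(x, w, None).sum().backward()   # v1-style call: g_idx is None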
@@ -42,7 +40,7 @@ try:

         @staticmethod
         @custom_fwd(cast_inputs=torch.float16)
-        def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize=-1):
+        def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq):
             output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
             ctx.save_for_backward(qweight, scales, qzeros, g_idx)
             ctx.bits, ctx.maxq = bits, maxq
@@ -58,7 +56,7 @@ try:

             if ctx.needs_input_grad[0]:
                 grad_input = tu.triton_matmul_transpose(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
-            return grad_input, None, None, None, None, None, None, None
+            return grad_input, None, None, None, None, None, None

 except ImportError:
     print('Triton not found. Please run "pip install triton".')
@@ -86,9 +84,9 @@ def switch_backend_to(to_backend):
         raise ValueError('Backend not supported.')


-def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize):
+def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq):
     if backend == 'cuda':
-        return mm4b.matmul4bit(x, qweight, scales, qzeros, groupsize)
+        return mm4b.matmul4bit(x, qweight, scales, qzeros, g_idx)
     elif backend == 'triton':
         assert qzeros.dtype == torch.int32
         return tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
@@ -99,17 +97,20 @@ def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq, group
 # Assumes layer is perfectly divisible into 256 * 256 blocks
 class Autograd4bitQuantLinear(nn.Module):

-    def __init__(self, in_features, out_features, groupsize=-1):
+    def __init__(self, in_features, out_features, groupsize=-1, is_v1_model=False):
         super().__init__()
         bits = 4
         self.in_features = in_features
         self.out_features = out_features
         self.bits = bits
         self.maxq = 2 ** self.bits - 1
+        groupsize = groupsize if groupsize != -1 else in_features
         self.groupsize = groupsize
-        if groupsize == -1:
+        self.is_v1_model = is_v1_model
+        if is_v1_model:
             self.register_buffer('zeros', torch.empty((out_features, 1)))
             self.register_buffer('scales', torch.empty((out_features, 1)))
+            self.g_idx = None
         else:
             self.register_buffer('qzeros',
                 torch.empty((math.ceil(in_features/groupsize), out_features // 256 * (bits * 8)), dtype=torch.int32)
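In the new __init__, groupsize == -1 is normalized to in_features and only v1 layers keep g_idx = None. A hedged, self-contained sketch of the shapes this implies (the arange-based g_idx is the usual GPTQ convention rather than something shown in this hunk; the all-zeros default appears later in matmul4bit):

import math
import torch

in_features, out_features, bits = 4096, 4096, 4
groupsize = 128
groupsize = groupsize if groupsize != -1 else in_features   # same normalization as above

# Each int32 packs eight 4-bit zero points, so 4096 outputs -> 512 packed columns.
qzeros_shape = (math.ceil(in_features / groupsize), out_features // 256 * (bits * 8))
print(qzeros_shape)                                          # (32, 512)

# Conventional per-input-channel group index (assumes no act-order reordering).
g_idx = torch.arange(in_features, dtype=torch.int32) // groupsize
print(g_idx[0].item(), g_idx[-1].item())                     # 0 31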
@@ -125,19 +126,17 @@ class Autograd4bitQuantLinear(nn.Module):
     def forward(self, x):
         if torch.is_grad_enabled():
             out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
-                                           self.qzeros if self.groupsize != -1 else self.zeros,
-                                           self.g_idx, self.bits, self.maxq,
-                                           self.groupsize)
+                                           self.qzeros if not self.is_v1_model else self.zeros,
+                                           self.g_idx, self.bits, self.maxq)
         else:
             out = matmul4bit_with_backend(x, self.qweight, self.scales,
-                                          self.qzeros if self.groupsize != -1 else self.zeros,
-                                          self.g_idx, self.bits, self.maxq,
-                                          self.groupsize)
+                                          self.qzeros if not self.is_v1_model else self.zeros,
+                                          self.g_idx, self.bits, self.maxq)
         out += self.bias
         return out


-def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
+def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1, is_v1_model=False):
     if isinstance(module, Autograd4bitQuantLinear):
         return
     for attr in dir(module):
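forward() keeps its existing split between the autograd Function (when gradients are enabled) and the direct backend matmul (inference); only the argument list changes. For reference, the switch it relies on:

import torch

print(torch.is_grad_enabled())        # True  -> AutogradMatmul4bit.apply(...) path
with torch.no_grad():                 # e.g. during generation
    print(torch.is_grad_enabled())    # False -> matmul4bit_with_backend(...) path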
@@ -145,17 +144,17 @@ def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
         name1 = name + '.' + attr if name != '' else attr
         if name1 in names:
             setattr(
-                module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize)
+                module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize, is_v1_model=is_v1_model)
             )
     for name1, child in module.named_children():
-        make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize)
+        make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize, is_v1_model=is_v1_model)


 def model_to_half(model):
     model.half()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            if m.groupsize == -1:
+            if m.is_v1_model:
                 m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
@@ -166,7 +165,7 @@ def model_to_float(model):
     model.float()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            if m.groupsize == -1:
+            if m.is_v1_model:
                 m.zeros = m.zeros.float()
             m.scales = m.scales.float()
             m.bias = m.bias.float()
@@ -184,7 +183,7 @@ def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
     return res


-def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
+def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048, is_v1_model=False):
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

@@ -199,7 +198,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa
     for name in ['lm_head']:
         if name in layers:
             del layers[name]
-    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
+    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize, is_v1_model=is_v1_model)
     model = accelerate.load_checkpoint_and_dispatch(
         model=model,
         checkpoint=model_path,
@@ -219,7 +218,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa

     return model, tokenizer

-def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
+def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None, is_v1_model=False):
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

@@ -237,7 +236,7 @@ def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path
     for name in ['lm_head']:
         if name in layers:
             del layers[name]
-    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
+    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize, is_v1_model=is_v1_model)
     accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'})

     # rotary_emb fix
@@ -258,7 +257,7 @@ def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path
     print('Apply half ...')
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
-            if m.groupsize == -1:
+            if m.is_v1_model:
                 m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
@@ -44,8 +44,6 @@ from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftMode
 # ! Config
 import train_data
-
-

 # * Show loaded parameters
 if ft_config.local_rank == 0:
     print(f"{ft_config}\n")
@@ -57,7 +55,8 @@ if ft_config.gradient_checkpointing:
 model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir,
                                                  ft_config.llama_q4_model,
                                                  device_map=ft_config.device_map,
-                                                 groupsize=ft_config.groupsize)
+                                                 groupsize=ft_config.groupsize,
+                                                 is_v1_model=ft_config.v1)

 # Config Lora
 lora_config = LoraConfig(
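With the loader's new keyword, the finetune script simply forwards ft_config.v1. A usage sketch (the import path and file names below are assumptions for illustration; this diff does not show them):

from autograd_4bit import load_llama_model_4bit_low_ram  # module name assumed, not shown in this diff

# v1 checkpoint: ungrouped, float zeros/scales, g_idx stays None
model, tokenizer = load_llama_model_4bit_low_ram(
    "llama-7b-4bit/", "llama-7b-4bit.pt", groupsize=-1, is_v1_model=True)

# v2 checkpoint quantized with 128-column groups: int32 qzeros plus g_idx
model, tokenizer = load_llama_model_4bit_low_ram(
    "llama-7b-4bit/", "llama-7b-4bit-128g.pt", groupsize=128, is_v1_model=False)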
@@ -45,7 +45,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     return y.reshape(outshape)


-def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
+def _matmul4bit_v2(x, qweight, scales, zeros, g_idx):
     """
     input x: (n, m)
     qweight: (j, k)
@@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
     x = x.half()
-    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, groupsize, x.shape[-1] // 2)
+    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, g_idx, x.shape[-1] // 2)
     y = y.to(dtype)
     return y.reshape(outshape)

@@ -84,7 +84,7 @@ def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
     return output


-def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False):
+def _matmul4bit_v2_recons(x, qweight, scales, zeros, g_idx, transpose=False):
     if debug:
         print('_matmul4bit_v2_recons')
     if not transpose:
@@ -92,7 +92,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False)
     else:
         assert qweight.shape[1] == x.shape[-1]
     buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
+    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, g_idx)
     if not transpose:
         output = torch.matmul(x, buffer)
     else:
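Both v2 kernels now take a per-input-channel group index rather than a fixed groupsize. A conceptual, self-contained sketch of what g_idx selects during reconstruction (an assumption of the usual GPTQ dequantization; real kernels also handle 4-bit packing and zero-point offsets, which are skipped here):

import torch

in_f, out_f, gs = 8, 4, 4
q = torch.randint(0, 16, (in_f, out_f)).float()    # pretend already-unpacked 4-bit codes
g_idx = torch.arange(in_f) // gs                   # group of each input row: 0,0,0,0,1,1,1,1
scales = torch.rand(in_f // gs, out_f)
zeros = torch.randint(0, 16, (in_f // gs, out_f)).float()

w = (q - zeros[g_idx]) * scales[g_idx]             # row i dequantized with group g_idx[i]
x = torch.randn(2, in_f)
print((x @ w).shape)                               # torch.Size([2, 4]) -- the matmul done on the buffer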
@@ -100,8 +100,9 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False)
     return output


-def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
-    if groupsize == -1:
+def matmul4bit(x, qweight, scales, zeros, g_idx=None):
+    # detect if zeros is int32
+    if zeros.dtype == torch.int32:
         # use v1
         if use_new:
             if auto_switch:
@@ -112,21 +113,24 @@ def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
         else:
             output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
     else:
+        if g_idx is None:
+            g_idx = torch.zeros(qweight.shape[0] * 8, dtype=torch.int32, device=x.device)
         # use v2
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, groupsize)
+                    output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, g_idx)
                 else:
-                    output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize)
+                    output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
         else:
-            output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize)
+            output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
     return output


 def v2_to_v1(scales, zeros):
     """
     Convert zeros in V2 model to V1 model when group_num = 1, for debugging
+    depreciated
     """
     assert zeros.shape[0] == 1
     z_mat = torch.zeros((zeros.shape[1], 256), dtype=torch.int, device=zeros.device) + zeros.reshape((-1,1))
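The fallback g_idx built in matmul4bit is an all-zeros index, i.e. a single quantization group; the factor of 8 comes from 4-bit packing, where each int32 row of qweight carries eight input channels. A small shape check (illustrative sizes only):

import torch

in_features, out_features = 4096, 11008
qweight = torch.zeros((in_features // 8, out_features), dtype=torch.int32)  # 8 nibbles per int32 row

# Same construction as the fallback above: one entry per input channel, all mapped to group 0.
g_idx = torch.zeros(qweight.shape[0] * 8, dtype=torch.int32)
assert g_idx.numel() == in_features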