From 8cf3bd40864cea9cdd29d97ad18979cd337f6421 Mon Sep 17 00:00:00 2001
From: John Smith
Date: Sun, 9 Apr 2023 12:26:22 +0800
Subject: [PATCH] add g_idx support on cuda backend

---
 Finetune4bConfig.py  |  8 +++---
 arg_parser.py        |  4 ++-
 autograd_4bit.py     | 63 ++++++++++++++++++++++----------------------
 finetune.py          |  5 ++--
 matmul_utils_4bit.py | 22 +++++++++-------
 5 files changed, 54 insertions(+), 48 deletions(-)

diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py
index 459102a..2169f46 100644
--- a/Finetune4bConfig.py
+++ b/Finetune4bConfig.py
@@ -14,7 +14,7 @@ class Finetune4bConfig:
                  gradient_checkpointing_ratio: float,
                  warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
                  checkpoint: bool, skip: bool, verbose: bool,
-                 txt_row_thd: int, use_eos_token: bool, groupsize: int,
+                 txt_row_thd: int, use_eos_token: bool, groupsize: int, v1: bool,
                  local_rank: int, flash_attention: bool, backend: str
     ):
         """
@@ -46,7 +46,8 @@ class Finetune4bConfig:
            verbose (bool): If output log of training
            txt_row_thd (int): Custom row thd for txt file
            use_eos_token (bool): Use Eos token instead of padding with 0
-           groupsize (int): Group size of V2 model, use -1 to load V1 model
+           groupsize (int): Group size of V2 model
+           v1 (bool): Whether to load the model as a V1 (no-groupsize) model
            local_rank (int): local rank if using torch.distributed.launch
            flash_attention (bool): Enables flash attention
        """
@@ -85,6 +86,7 @@ class Finetune4bConfig:
         if self.ddp:
             self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
         self.groupsize = groupsize
+        self.v1 = v1
         self.flash_attention = flash_attention
         self.backend = backend
 
@@ -99,5 +101,5 @@ class Finetune4bConfig:
             f"{self.logging_steps=}\n" +\
             f"{self.checkpoint=}\n{self.skip=}\n" +\
             f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}\n" +\
-            f"{self.groupsize=}\n{self.backend=}\n"
+            f"{self.groupsize=}\n{self.v1=}\n{self.backend=}\n"
         return s.replace("self.", "")
diff --git a/arg_parser.py b/arg_parser.py
index b83c939..02c56d2 100644
--- a/arg_parser.py
+++ b/arg_parser.py
@@ -62,7 +62,8 @@ def parse_commandline():
     parser_training.add_argument("--use_eos_token", default=1, type=int, help="Use eos token instead of padding with 0. enable with 1, disable with 0.")
 
     # V2 model support
-    parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")
+    parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model")
+    parser_training.add_argument("--v1", action="store_true", help="Use V1 model")
 
     # Multi GPU Support
     parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch")
@@ -107,6 +108,7 @@ def get_config() -> Finetune4bConfig:
         txt_row_thd=args["txt_row_thd"],
         use_eos_token=args["use_eos_token"]!=0,
         groupsize=args["groupsize"],
+        v1=args["v1"],
         local_rank=args["local_rank"],
         flash_attention=args["flash_attention"],
         backend=args["backend"],
diff --git a/autograd_4bit.py b/autograd_4bit.py
index 5d19908..544e429 100644
--- a/autograd_4bit.py
+++ b/autograd_4bit.py
@@ -12,27 +12,25 @@ class AutogradMatmul4bitCuda(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
-    def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq, groupsize=-1):
-        ctx.save_for_backward(qweight, scales, zeros)
-        ctx.groupsize = groupsize
-        if groupsize == -1:
+    def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq):
+        ctx.save_for_backward(qweight, scales, zeros, g_idx)
+        if g_idx is None:
             output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
         else:
-            output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
+            output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, g_idx)
         output = output.clone()
         return output
 
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
-        qweight, scales, zeros = ctx.saved_tensors
-        groupsize = ctx.groupsize
+        qweight, scales, zeros, g_idx = ctx.saved_tensors
         if ctx.needs_input_grad[0]:
-            if groupsize == -1:
+            if g_idx is None:
                 grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
             else:
-                grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True)
-        return grad, None, None, None, None, None, None, None
+                grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, g_idx, transpose=True)
+        return grad, None, None, None, None, None, None
 
 
 try:
@@ -42,7 +40,7 @@ try:
 
         @staticmethod
         @custom_fwd(cast_inputs=torch.float16)
-        def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize=-1):
+        def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq):
             output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
             ctx.save_for_backward(qweight, scales, qzeros, g_idx)
             ctx.bits, ctx.maxq = bits, maxq
@@ -58,7 +56,7 @@ try:
 
             if ctx.needs_input_grad[0]:
                 grad_input = tu.triton_matmul_transpose(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
-            return grad_input, None, None, None, None, None, None, None
+            return grad_input, None, None, None, None, None, None
 
 except ImportError:
     print('Triton not found. Please run "pip install triton".')
@@ -86,9 +84,9 @@ def switch_backend_to(to_backend):
         raise ValueError('Backend not supported.')
 
 
-def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize):
+def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq):
     if backend == 'cuda':
-        return mm4b.matmul4bit(x, qweight, scales, qzeros, groupsize)
+        return mm4b.matmul4bit(x, qweight, scales, qzeros, g_idx)
     elif backend == 'triton':
         assert qzeros.dtype == torch.int32
         return tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
@@ -99,17 +97,20 @@ def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq, group
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
 class Autograd4bitQuantLinear(nn.Module):
 
-    def __init__(self, in_features, out_features, groupsize=-1):
+    def __init__(self, in_features, out_features, groupsize=-1, is_v1_model=False):
         super().__init__()
         bits = 4
         self.in_features = in_features
         self.out_features = out_features
         self.bits = bits
         self.maxq = 2 ** self.bits - 1
+        groupsize = groupsize if groupsize != -1 else in_features
         self.groupsize = groupsize
-        if groupsize == -1:
+        self.is_v1_model = is_v1_model
+        if is_v1_model:
             self.register_buffer('zeros', torch.empty((out_features, 1)))
             self.register_buffer('scales', torch.empty((out_features, 1)))
+            self.g_idx = None
         else:
             self.register_buffer('qzeros',
                                  torch.empty((math.ceil(in_features/groupsize), out_features // 256 * (bits * 8)), dtype=torch.int32)
@@ -125,19 +126,17 @@ class Autograd4bitQuantLinear(nn.Module):
 
     def forward(self, x):
         if torch.is_grad_enabled():
             out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
-                                           self.qzeros if self.groupsize != -1 else self.zeros,
-                                           self.g_idx, self.bits, self.maxq,
-                                           self.groupsize)
+                                           self.qzeros if not self.is_v1_model else self.zeros,
+                                           self.g_idx, self.bits, self.maxq)
         else:
             out = matmul4bit_with_backend(x, self.qweight, self.scales,
-                                          self.qzeros if self.groupsize != -1 else self.zeros,
-                                          self.g_idx, self.bits, self.maxq,
-                                          self.groupsize)
+                                          self.qzeros if not self.is_v1_model else self.zeros,
+                                          self.g_idx, self.bits, self.maxq)
         out += self.bias
         return out
 
 
-def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
+def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1, is_v1_model=False):
     if isinstance(module, Autograd4bitQuantLinear):
         return
     for attr in dir(module):
@@ -145,17 +144,17 @@ def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
         name1 = name + '.' + attr if name != '' else attr
         if name1 in names:
             setattr(
-                module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize)
+                module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize, is_v1_model=is_v1_model)
             )
     for name1, child in module.named_children():
-        make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize)
+        make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize, is_v1_model=is_v1_model)
 
 
 def model_to_half(model):
     model.half()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            if m.groupsize == -1:
+            if m.is_v1_model:
                 m.zeros = m.zeros.half()
                 m.scales = m.scales.half()
             m.bias = m.bias.half()
@@ -166,7 +165,7 @@ def model_to_float(model):
     model.float()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            if m.groupsize == -1:
+            if m.is_v1_model:
                 m.zeros = m.zeros.float()
                 m.scales = m.scales.float()
             m.bias = m.bias.float()
@@ -184,7 +183,7 @@ def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
     return res
 
 
-def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
+def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048, is_v1_model=False):
 
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
@@ -199,7 +198,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa
     for name in ['lm_head']:
         if name in layers:
             del layers[name]
-    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
+    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize, is_v1_model=is_v1_model)
     model = accelerate.load_checkpoint_and_dispatch(
         model=model,
         checkpoint=model_path,
@@ -219,7 +218,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa
     return model, tokenizer
 
 
-def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
+def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None, is_v1_model=False):
 
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
@@ -237,7 +236,7 @@ def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path
     for name in ['lm_head']:
         if name in layers:
             del layers[name]
-    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
+    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize, is_v1_model=is_v1_model)
     accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'})
 
     # rotary_emb fix
@@ -258,7 +257,7 @@ def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path
     print('Apply half ...')
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
-            if m.groupsize == -1:
+            if m.is_v1_model:
                 m.zeros = m.zeros.half()
                 m.scales = m.scales.half()
             m.bias = m.bias.half()
diff --git a/finetune.py b/finetune.py
index 576ec31..ac72694 100644
--- a/finetune.py
+++ b/finetune.py
@@ -44,8 +44,6 @@ from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftMode
 
 # ! Config
 import train_data
-
-
 # * Show loaded parameters
 if ft_config.local_rank == 0:
     print(f"{ft_config}\n")
@@ -57,7 +55,8 @@ if ft_config.gradient_checkpointing:
 model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir,
                                                  ft_config.llama_q4_model,
                                                  device_map=ft_config.device_map,
-                                                 groupsize=ft_config.groupsize)
+                                                 groupsize=ft_config.groupsize,
+                                                 is_v1_model=ft_config.v1)
 
 # Config Lora
 lora_config = LoraConfig(
diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py
index 1be5e22..009093c 100644
--- a/matmul_utils_4bit.py
+++ b/matmul_utils_4bit.py
@@ -45,7 +45,7 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     return y.reshape(outshape)
 
 
-def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
+def _matmul4bit_v2(x, qweight, scales, zeros, g_idx):
     """
     input x: (n, m)
     qweight: (j, k)
@@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
     x = x.half()
-    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, groupsize, x.shape[-1] // 2)
+    quant_cuda.vecquant4matmul_faster(x, qweight, y, scales, zeros, g_idx, x.shape[-1] // 2)
     y = y.to(dtype)
     return y.reshape(outshape)
 
@@ -84,7 +84,7 @@ def _matmul4bit_v1_recons(x, qweight, scales, zeros, transpose=False):
     return output
 
 
-def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False):
+def _matmul4bit_v2_recons(x, qweight, scales, zeros, g_idx, transpose=False):
     if debug:
         print('_matmul4bit_v2_recons')
     if not transpose:
@@ -92,7 +92,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False)
     else:
         assert qweight.shape[1] == x.shape[-1]
     buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
-    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
+    quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, g_idx)
     if not transpose:
         output = torch.matmul(x, buffer)
     else:
@@ -100,8 +100,9 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False)
     return output
 
 
-def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
-    if groupsize == -1:
+def matmul4bit(x, qweight, scales, zeros, g_idx=None):
+    # detect if zeros is int32
+    if zeros.dtype != torch.int32:
         # use v1
         if use_new:
             if auto_switch:
@@ -112,21 +113,24 @@ def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
         else:
             output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
     else:
+        if g_idx is None:
+            g_idx = torch.zeros(qweight.shape[0] * 8, dtype=torch.int32, device=x.device)
         # use v2
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, groupsize)
+                    output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, g_idx)
                 else:
-                    output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize)
+                    output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
         else:
-            output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize)
+            output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
     return output
 
 
 def v2_to_v1(scales, zeros):
     """
     Convert zeros in V2 model to V1 model when group_num = 1, for debugging
+    deprecated
     """
     assert zeros.shape[0] == 1
     z_mat = torch.zeros((zeros.shape[1], 256), dtype=torch.int, device=zeros.device) + zeros.reshape((-1,1))
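
Usage note (illustrative sketch, not part of the patch): with this change the loaders select the checkpoint format via an explicit is_v1_model flag (exposed on the finetune CLI as --v1) instead of the old groupsize=-1 convention, and the quantized layers carry g_idx through to the CUDA kernels. The snippet below assumes the repo's autograd_4bit module is importable; the paths and group size are placeholders.

    from autograd_4bit import load_llama_model_4bit_low_ram

    # V2 (grouped) checkpoint: pass the groupsize it was quantized with.
    model, tokenizer = load_llama_model_4bit_low_ram(
        "./llama-7b-4bit/",                  # config_path (placeholder)
        "./llama-7b-4bit-128g.safetensors",  # model_path (placeholder)
        groupsize=128,
    )

    # V1 checkpoint: pass is_v1_model=True instead of groupsize=-1.
    model_v1, tokenizer_v1 = load_llama_model_4bit_low_ram(
        "./llama-7b-4bit-v1/",               # config_path (placeholder)
        "./llama-7b-4bit-v1.pt",             # model_path (placeholder)
        is_v1_model=True,
    )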