From 8791eaee9a82175f14b21d213fe5505e113d3e3c Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 30 Mar 2023 19:08:35 -0400
Subject: [PATCH] fix gpt4all training to more closely match the released
 logic, other small fixes and optimizations

---
 autograd_4bit.py     | 22 +++++-----
 finetune.py          | 36 ++++++++--------
 matmul_utils_4bit.py | 18 ++++----
 train_data.py        | 98 ++++++++++++++++++++++++++++++++++++++------
 4 files changed, 124 insertions(+), 50 deletions(-)

diff --git a/autograd_4bit.py b/autograd_4bit.py
index 166a2a1..26597a5 100644
--- a/autograd_4bit.py
+++ b/autograd_4bit.py
@@ -30,7 +30,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
-class Autograd4bitQuantLinear(nn.Module): 
+class Autograd4bitQuantLinear(nn.Module):
 
     def __init__(self, infeatures, outfeatures, groupsize=-1):
         super().__init__()
@@ -47,7 +47,7 @@ class Autograd4bitQuantLinear(nn.Module):
                 torch.empty((math.ceil(infeatures/groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
             )
             self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize),outfeatures)))
-        self.register_buffer('bias', torch.empty(outfeatures))
+        self.bias = nn.Parameter(torch.empty(outfeatures))
         self.register_buffer(
             'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
         )
@@ -57,11 +57,11 @@ class Autograd4bitQuantLinear(nn.Module):
         if torch.is_grad_enabled():
             out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
                                            self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
-            out += self.bias
+            out.add_(self.bias)
         else:
             out = mm4b.matmul4bit(x, self.qweight, self.scales,
                                   self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
-            out += self.bias
+            out.add_(self.bias)
         return out
 
@@ -115,7 +115,7 @@ def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
 def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
-    
+
     print("Loading Model ...")
     t0 = time.time()
 
@@ -136,7 +136,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa
     )
     model.seqlen = seqlen
-    
+
     if half:
         model_to_half(model)
 
@@ -144,9 +144,9 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa
     tokenizer.truncation_side = 'left'
 
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
-    
+
     return model, tokenizer
-    
+
 def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
     import accelerate
     from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
@@ -190,13 +190,13 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lo
             m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
-    
+
     print('Dispatching model ...')
     device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
     model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0)
     torch.cuda.empty_cache()
     print('Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024))
-    
+
     # rotary_emb fix
     for n, m in model.named_modules():
         if 'rotary_emb' in n:
@@ -210,7 +210,7 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lo
                 if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys():
                     hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu()
                     hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu()
-    
+
     tokenizer = LlamaTokenizer.from_pretrained(config_path)
     tokenizer.truncation_side = 'left'
diff --git a/finetune.py b/finetune.py
index 666403f..32cc104 100644
--- a/finetune.py
+++ b/finetune.py
@@ -102,27 +102,29 @@ if not ft_config.skip:
     model.is_parallelizable = True
     model.model_parallel = True
 
+    training_arguments = transformers.TrainingArguments(
+        per_device_train_batch_size=ft_config.mbatch_size,
+        gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
+        warmup_steps=ft_config.warmup_steps,
+        num_train_epochs=ft_config.epochs,
+        learning_rate=ft_config.lr,
+        fp16=True,
+        logging_steps=ft_config.logging_steps,
+        evaluation_strategy="no",
+        save_strategy="steps",
+        eval_steps=None,
+        save_steps=ft_config.save_steps,
+        output_dir=ft_config.lora_out_dir,
+        save_total_limit=ft_config.save_total_limit,
+        load_best_model_at_end=False,
+        ddp_find_unused_parameters=False if ft_config.ddp else None,
+    )
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data.train_data,
         eval_dataset=data.val_data,
-        args=transformers.TrainingArguments(
-            per_device_train_batch_size=ft_config.mbatch_size,
-            gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
-            warmup_steps=ft_config.warmup_steps,
-            num_train_epochs=ft_config.epochs,
-            learning_rate=ft_config.lr,
-            fp16=True,
-            logging_steps=ft_config.logging_steps,
-            evaluation_strategy="no",
-            save_strategy="steps",
-            eval_steps=None,
-            save_steps=ft_config.save_steps,
-            output_dir=ft_config.lora_out_dir,
-            save_total_limit=ft_config.save_total_limit,
-            load_best_model_at_end=False,
-            ddp_find_unused_parameters=False if ft_config.ddp else None,
-        ),
+        args=training_arguments,
         data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
     )
     model.config.use_cache = False
diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py
index e9d621f..1be5e22 100644
--- a/matmul_utils_4bit.py
+++ b/matmul_utils_4bit.py
@@ -27,15 +27,15 @@ def _matmul4bit_v1(x, qweight, scales, zeros):
     input x: (n, m)
           qweight: (j, k)
          where m == j*8
-    
+
     perform x @ qweight
-    
-    return y: 
+
+    return y:
     """
     if debug:
         print('_matmul4bit_v1')
     assert qweight.shape[0] * 8 == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    outshape = x.shape[:-1] + (qweight.shape[1],)
     x = x.reshape(-1, x.shape[-1])
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
@@ -50,15 +50,15 @@ def _matmul4bit_v2(x, qweight, scales, zeros, groupsize):
     input x: (n, m)
           qweight: (j, k)
          where m == j*8
-    
+
    perform x @ qweight
-    
-    return y: 
+
+    return y:
     """
     if debug:
         print('_matmul4bit_v2')
     assert qweight.shape[0] * 8 == x.shape[-1]
-    outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
+    outshape = x.shape[:-1] + (qweight.shape[1],)
     x = x.reshape(-1, x.shape[-1])
     y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
     dtype = x.dtype
@@ -95,7 +95,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize, transpose=False)
     quant_cuda.vecquant4recons_v2(qweight, buffer, scales, zeros, groupsize)
     if not transpose:
         output = torch.matmul(x, buffer)
-    if transpose:
+    else:
         output = torch.matmul(x, buffer.T)
     return output
diff --git a/train_data.py b/train_data.py
index 6ac3d9e..3e3817e 100644
--- a/train_data.py
+++ b/train_data.py
@@ -1,6 +1,10 @@
+import torch
+
 from abc import ABC, abstractmethod
 from typing import Dict, Any
 from datasets import load_dataset, Dataset
+from torch.utils.data import DataLoader
+from transformers import DefaultDataCollator
 import os
 
 
@@ -126,7 +130,7 @@ class TrainTxt(ATrainData):
 class TrainSAD(ATrainData):
     def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len) -> None:
         super().__init__(dataset, val_set_size, tokenizer, cutoff_len)
-    
+
     def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]:
         # there's probably a way to do this with the tokenizer settings
         # but again, gotta move fast
@@ -186,16 +190,84 @@ class TrainSAD(ATrainData):
         return self.tokenize(prompt, **kwargs)
 
 # GPT4All-like Data
-class TrainGPT4All(TrainSAD):
-    # Auxiliary methods
-    def generate_prompt(self, data_point, **kwargs):
-        return "{0}\n\n{1}\n{2}\n\n{3}\n{4}\n\n{5}\n{6}".format(
-            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
-            "### Instruction:",
-            data_point["prompt"],
-            "### Input:",
-            "",
-            "### Response:",
-            data_point["response"]
-        )
+class TrainGPT4All(ATrainData):
+    def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len) -> None:
+        super().__init__(dataset, val_set_size, tokenizer, cutoff_len)
+    def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]:
+        pass
+
+    def tokenize_inputs(self, examples):
+        max_length = self.cutoff_len
+        input_ids = torch.full((len(examples["prompt"]), max_length), self.tokenizer.pad_token_id)
+        # ignore bos
+        newline_tokens = self.tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
+
+        out = {"labels": [], "attention_mask": []}
+        for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
+            input_tokens = self.tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
+            if input_tokens.dim() == 0:
+                input_tokens = input_tokens.unsqueeze(0)
+
+            input_len = len(input_tokens)
+
+            # plus one since we remove bos from response
+            # but we subtract one since we want to add eos token
+            remaining_tokens = max_length - input_len - len(newline_tokens) + 1
+            # remove bos
+            target_tokens = self.tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
+
+            input_ids[i, :input_len] = input_tokens
+            # add newline between prompt and response
+            newline_plus_inputs = input_len + len(newline_tokens)
+            input_ids[i, input_len: newline_plus_inputs] = newline_tokens
+
+            # add target tokens, remove bos
+            input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
+            # add eos token, enforce stopping if we don't truncate
+            # we don't want long code to stop generating if truncated during training
+            if newline_plus_inputs + len(target_tokens) < max_length:
+                input_ids[i, newline_plus_inputs + len(target_tokens)] = self.tokenizer.eos_token_id
+
+            labels = input_ids[i].clone()
+            labels[: newline_plus_inputs] = -100
+            labels[labels == self.tokenizer.pad_token_id] = -100
+            # to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
+
+            attention_mask = input_ids[i].ne(self.tokenizer.pad_token_id).int()
+
+            out["labels"].append(labels)
+            out["attention_mask"].append(attention_mask)
+
+        out["input_ids"] = input_ids
+
+        out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
+
+        return out
+
+    def prepare_data(self, **kwargs) -> None:
+        dataset = load_dataset("json", data_files=self.dataset)
+
+        self.val_data = None
+        if self.val_set_size > 0:
+            dataset = dataset["train"].train_test_split(
+                test_size=self.val_set_size, shuffle=True, seed=42  # ! Seed = 42 (?)
+            )
+            train_dataset, val_dataset = dataset["train"], dataset["test"]
+
+            # tokenize inputs and return labels and attention mask
+            val_dataset = val_dataset.map(
+                lambda ele: self.tokenize_inputs(ele),
+                batched=True,
+                remove_columns=["source", "prompt"],
+            )
+            self.val_data = val_dataset.with_format("torch")
+        else:
+            train_dataset = dataset["train"]
+
+        train_dataset = train_dataset.map(
+            lambda ele: self.tokenize_inputs(ele),
+            batched=True,
+            remove_columns=["source", "prompt"],
+        )
+        self.train_data = train_dataset.with_format("torch")
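
For reference only (not part of the patch): the packing and label masking done inside the batched TrainGPT4All.tokenize_inputs map call can be easier to follow as a minimal single-example sketch. The sketch below assumes a LLaMA-style tokenizer that prepends a BOS token, has pad_token_id set, and uses a pad token distinct from eos; the helper name pack_example is illustrative and does not exist in the repository.

    import torch

    def pack_example(tokenizer, prompt, response, max_length=512):
        # Sequence layout: [prompt tokens] [newline tokens] [response tokens] [eos] [pad ...]
        pad_id = tokenizer.pad_token_id
        input_ids = torch.full((max_length,), pad_id)

        # drop the BOS the tokenizer prepends to "\n"
        newline = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
        prompt_ids = tokenizer(prompt, truncation=True, max_length=max_length // 2,
                               return_tensors="pt")["input_ids"].squeeze(0)
        # leave room for the newline and the trailing eos; the response BOS is dropped below
        budget = max_length - len(prompt_ids) - len(newline) + 1
        response_ids = tokenizer(response, truncation=True, max_length=budget,
                                 return_tensors="pt")["input_ids"].squeeze(0)[1:]

        start = len(prompt_ids)
        input_ids[:start] = prompt_ids
        input_ids[start:start + len(newline)] = newline
        start += len(newline)
        input_ids[start:start + len(response_ids)] = response_ids
        if start + len(response_ids) < max_length:
            # only add eos when the response was not truncated
            input_ids[start + len(response_ids)] = tokenizer.eos_token_id

        # loss is computed only on the response (and eos); prompt, newline, and padding are masked
        labels = input_ids.clone()
        labels[:start] = -100
        labels[labels == pad_id] = -100
        attention_mask = input_ids.ne(pad_id).int()
        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

Under those assumptions, decoding labels[labels != -100] with the same tokenizer should reproduce the response (plus the trailing eos), which mirrors the debugging hint in the inline comment of tokenize_inputs.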