From 2a1cb4296661935936876d1c5ae9e11913111450 Mon Sep 17 00:00:00 2001
From: John Smith
Date: Wed, 29 Mar 2023 11:20:16 +0800
Subject: [PATCH] add padding support as an option

---
 Finetune4bConfig.py |   4 +-
 arg_parser.py       |   6 ++-
 finetune.py         |   2 +-
 train_data.py       | 101 ++++++++++++++++++++++++++++----------------
 4 files changed, 74 insertions(+), 39 deletions(-)

diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py
index fa01c1f..57fce77 100644
--- a/Finetune4bConfig.py
+++ b/Finetune4bConfig.py
@@ -13,7 +13,7 @@ class Finetune4bConfig:
         gradient_checkpointing: bool, gradient_checkpointing_ratio: float,
         warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
-        checkpoint: bool, skip: bool, txt_row_thd: int, groupsize: int
+        checkpoint: bool, skip: bool, txt_row_thd: int, use_eos_token: bool, groupsize: int
     ):
         """
         Args:
@@ -41,6 +41,7 @@ class Finetune4bConfig:
             checkpoint (bool): Produce checkpoint instead of LoRA
             skip (bool): Don't train model
             txt_row_thd (int): Custom row thd for txt file
+            use_eos_token (bool): Use EOS token instead of padding with 0
             groupsize (int): Group size of V2 model, use -1 to load V1 model
         """
         self.dataset = dataset
@@ -68,6 +69,7 @@ class Finetune4bConfig:
         self.checkpoint = checkpoint
         self.skip = skip
         self.txt_row_thd = txt_row_thd
+        self.use_eos_token = use_eos_token
         self.world_size = int(os.environ.get("WORLD_SIZE", 1))
         self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
         self.ddp = self.world_size != 1
diff --git a/arg_parser.py b/arg_parser.py
index f77456e..d6449a8 100644
--- a/arg_parser.py
+++ b/arg_parser.py
@@ -52,8 +52,11 @@ def parse_commandline():
     parser_training.add_argument("--logging_steps", default=10, type=int, help="Default: %(default)s")
     parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s")
     parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s")
-    parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")
+    # Data args
+    parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")
+    parser_training.add_argument("--use_eos_token", default=1, type=int, help="Use EOS token instead of padding with 0. Enable with 1, disable with 0.")
enable with 1, disable with 0.") + # V2 model support parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model") @@ -87,5 +90,6 @@ def get_config() -> Finetune4bConfig: checkpoint=args["checkpoint"], skip=args["skip"], txt_row_thd=args["txt_row_thd"], + use_eos_token=args["use_eos_token"]!=0, groupsize=args["groupsize"] ) diff --git a/finetune.py b/finetune.py index f4c8073..eb7bbbb 100644 --- a/finetune.py +++ b/finetune.py @@ -85,7 +85,7 @@ if not ft_config.skip: data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len) else: raise NotImplementedError("ERROR: Unknown dataset format") - data.prepare_data(thd=ft_config.txt_row_thd) + data.prepare_data(thd=ft_config.txt_row_thd, use_eos_token=ft_config.use_eos_token) #### # Use gradient checkpointing diff --git a/train_data.py b/train_data.py index 673a014..806743f 100644 --- a/train_data.py +++ b/train_data.py @@ -49,20 +49,37 @@ class TrainTxt(ATrainData): self.cutoff_len = cutoff_len self.exceed_count = 0 - def tokenize(self, prompt: str) -> Dict[str, Any]: + def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]: # there's probably a way to do this with the tokenizer settings # but again, gotta move fast - prompt = prompt['input'] - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.cutoff_len + 1, - padding="max_length", - ) - d = { - "input_ids": result["input_ids"][:-1], - "attention_mask": result["attention_mask"][:-1], - } + if use_eos_token: + result = self.tokenizer( + prompt + self.tokenizer.eos_token, + truncation=True, + max_length=self.cutoff_len, + padding=False, + ) + d = { + "input_ids": result["input_ids"], + "attention_mask": result["attention_mask"], + } + if ( + d["input_ids"][-1] != self.tokenizer.eos_token_id + and len(d["input_ids"]) < self.cutoff_len + ): + d["input_ids"].append(self.tokenizer.eos_token_id) + d["attention_mask"].append(1) + else: + result = self.tokenizer( + prompt, + truncation=True, + max_length=self.cutoff_len + 1, + padding="max_length", + ) + d = { + "input_ids": result["input_ids"][:-1], + "attention_mask": result["attention_mask"][:-1], + } if sum(d['attention_mask']) >= self.cutoff_len: self.exceed_count += 1 return d @@ -84,7 +101,7 @@ class TrainTxt(ATrainData): r_b = '' return new_rows - def prepare_data(self, thd=-1, **kwargs): + def prepare_data(self, thd=-1, use_eos_token=True, **kwargs): if os.path.isdir(self.dataset): rows = [] for filename in os.listdir(self.dataset): @@ -100,7 +117,7 @@ class TrainTxt(ATrainData): if thd != -1: rows = self.format_new_rows(rows, thd=thd) data = Dataset.from_dict({"input": rows}) - data = data.shuffle().map(lambda x: self.tokenize(x)) + data = data.shuffle().map(lambda x: self.tokenize(x["input"], use_eos_token=use_eos_token)) print('Train Data: {:.2f}%'.format(self.exceed_count / len(data) * 100), 'outliers') self.train_data = data @@ -110,38 +127,50 @@ class TrainSAD(ATrainData): def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len) -> None: super().__init__(dataset, val_set_size, tokenizer, cutoff_len) - def tokenize(self, prompt: str) -> Dict[str, Any]: + def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]: # there's probably a way to do this with the tokenizer settings # but again, gotta move fast - result = self.tokenizer( - prompt + self.tokenizer.eos_token, - truncation=True, - max_length=self.cutoff_len, - padding=False, - ) - if ( - 
-            and len(result["input_ids"]) < self.cutoff_len
-        ):
-            result["input_ids"].append(tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-        return result
+        if use_eos_token:
+            result = self.tokenizer(
+                prompt + self.tokenizer.eos_token,
+                truncation=True,
+                max_length=self.cutoff_len,
+                padding=False,
+            )
+            if (
+                result["input_ids"][-1] != self.tokenizer.eos_token_id
+                and len(result["input_ids"]) < self.cutoff_len
+            ):
+                result["input_ids"].append(self.tokenizer.eos_token_id)
+                result["attention_mask"].append(1)
+            return result
+        else:
+            result = self.tokenizer(
+                prompt,
+                truncation=True,
+                max_length=self.cutoff_len + 1,
+                padding="max_length",
+            )
+            return {
+                "input_ids": result["input_ids"][:-1],
+                "attention_mask": result["attention_mask"][:-1],
+            }
 
-    def prepare_data(self, **kwargs) -> None:
+    def prepare_data(self, use_eos_token=True, **kwargs) -> None:
         data = load_dataset("json", data_files=self.dataset)
 
         if self.val_set_size > 0:
             train_val = data["train"].train_test_split(
                 test_size=self.val_set_size, shuffle=True, seed=42  # ! Seed = 42 (?)
             )
-            self.train_data = train_val["train"].shuffle().map(self.generate_and_tokenize_prompt)
-            self.val_data = train_val["test"].shuffle().map(self.generate_and_tokenize_prompt)
+            self.train_data = train_val["train"].shuffle().map(lambda x: self.generate_and_tokenize_prompt(x, use_eos_token=use_eos_token))
+            self.val_data = train_val["test"].shuffle().map(lambda x: self.generate_and_tokenize_prompt(x, use_eos_token=use_eos_token))
         else:
-            self.train_data = data["train"].shuffle().map(self.generate_and_tokenize_prompt)
+            self.train_data = data["train"].shuffle().map(lambda x: self.generate_and_tokenize_prompt(x, use_eos_token=use_eos_token))
             self.val_data = None
 
     # Auxiliary methods
-    def generate_prompt(self, data_point):
+    def generate_prompt(self, data_point, **kwargs):
         return "{0}\n\n{1}\n{2}\n\n{3}\n{4}\n\n{5}\n{6}".format(
             "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
             "### Instruction:",
@@ -152,6 +181,6 @@ class TrainSAD(ATrainData):
             data_point["output"]
         )
 
-    def generate_and_tokenize_prompt(self, data_point):
-        prompt = self.generate_prompt(data_point)
-        return self.tokenize(prompt)
+    def generate_and_tokenize_prompt(self, data_point, **kwargs):
+        prompt = self.generate_prompt(data_point, **kwargs)
+        return self.tokenize(prompt, **kwargs)
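
Note for reviewers (not part of the patch): below is a minimal, self-contained sketch of the two tokenization paths that --use_eos_token toggles, so their outputs can be compared without wiring up a full training run. The model id and cutoff length are placeholders chosen for illustration, not values taken from this repo; in the repo the tokenizer is loaded in finetune.py and the cutoff comes from ft_config.cutoff_len, and setting the pad id to 0 follows the flag's help text ("padding with 0").

# sketch_use_eos_token.py - illustrative only; mirrors the branches added above
from transformers import AutoTokenizer

MODEL_ID = "huggyllama/llama-7b"   # placeholder model id
CUTOFF_LEN = 256                   # placeholder; the repo uses ft_config.cutoff_len

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token_id = 0         # "padding with 0", per the new flag's help text

def tokenize(prompt: str, use_eos_token: bool = True) -> dict:
    if use_eos_token:
        # --use_eos_token 1 (default): keep the sample variable-length,
        # truncate to the cutoff, and make sure it ends in EOS when it fits.
        result = tokenizer(
            prompt + tokenizer.eos_token,
            truncation=True,
            max_length=CUTOFF_LEN,
            padding=False,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < CUTOFF_LEN
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)
        return {"input_ids": result["input_ids"], "attention_mask": result["attention_mask"]}
    # --use_eos_token 0: the previous behaviour - pad every sample to a fixed
    # length and drop the last position.
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    return {
        "input_ids": result["input_ids"][:-1],
        "attention_mask": result["attention_mask"][:-1],
    }

if __name__ == "__main__":
    for flag in (True, False):
        out = tokenize("Hello world.", use_eos_token=flag)
        print(flag, len(out["input_ids"]), out["input_ids"][-5:])

With the default, the length varies per prompt and the ids end in eos_token_id whenever the text fits under the cutoff; with --use_eos_token 0 every sample comes back at exactly the cutoff length, reproducing the old fixed-length padding for runs that still rely on it.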