From daad59f8efb0c3962c264a0c7f4d4ad8d85097da Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 28 Mar 2023 20:53:16 -0400
Subject: [PATCH] properly include the eos token so inference doesn't blabber
 on

---
 train_data.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/train_data.py b/train_data.py
index 76cb2e8..673a014 100644
--- a/train_data.py
+++ b/train_data.py
@@ -114,15 +114,18 @@ class TrainSAD(ATrainData):
         # there's probably a way to do this with the tokenizer settings
         # but again, gotta move fast
         result = self.tokenizer(
-            prompt,
+            prompt + self.tokenizer.eos_token,
             truncation=True,
-            max_length=self.cutoff_len + 1,
-            padding="max_length",
+            max_length=self.cutoff_len,
+            padding=False,
         )
-        return {
-            "input_ids": result["input_ids"][:-1],
-            "attention_mask": result["attention_mask"][:-1],
-        }
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.cutoff_len
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+        return result
 
     def prepare_data(self, **kwargs) -> None:
         data = load_dataset("json", data_files=self.dataset)