properly include the eos token so inference doesn't blabber on

This commit is contained in:
Wing Lian 2023-03-28 20:53:16 -04:00
parent 1043ded7d9
commit daad59f8ef
1 changed file with 10 additions and 7 deletions

View File

@@ -114,15 +114,18 @@ class TrainSAD(ATrainData):
         # there's probably a way to do this with the tokenizer settings
         # but again, gotta move fast
         result = self.tokenizer(
-            prompt,
+            prompt + self.tokenizer.eos_token,
             truncation=True,
-            max_length=self.cutoff_len + 1,
-            padding="max_length",
+            max_length=self.cutoff_len,
+            padding=False,
         )
-        return {
-            "input_ids": result["input_ids"][:-1],
-            "attention_mask": result["attention_mask"][:-1],
-        }
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.cutoff_len
+        ):
+            result["input_ids"].append(tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+        return result

     def prepare_data(self, **kwargs) -> None:
         data = load_dataset("json", data_files=self.dataset)