Merge pull request #39 from winglian/fix-prompt-eos-token
properly include the eos token so inference doesn't blabber on
commit cff57ebfa4
@@ -114,15 +114,18 @@ class TrainSAD(ATrainData):
         # there's probably a way to do this with the tokenizer settings
         # but again, gotta move fast
         result = self.tokenizer(
-            prompt,
+            prompt + self.tokenizer.eos_token,
             truncation=True,
-            max_length=self.cutoff_len + 1,
-            padding="max_length",
+            max_length=self.cutoff_len,
+            padding=False,
         )
-        return {
-            "input_ids": result["input_ids"][:-1],
-            "attention_mask": result["attention_mask"][:-1],
-        }
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.cutoff_len
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+        return result
 
     def prepare_data(self, **kwargs) -> None:
         data = load_dataset("json", data_files=self.dataset)
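For context, the following is a minimal, self-contained sketch of the tokenization behaviour this hunk produces. It is an illustration only: the "gpt2" tokenizer, the cutoff_len value, and the tokenize_prompt helper are placeholders standing in for the attributes and method of the real TrainSAD class.

# Sketch of the post-change tokenization step, assuming a generic Hugging Face
# tokenizer; "gpt2", cutoff_len=256, and tokenize_prompt are placeholders, not
# names from this PR.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
cutoff_len = 256


def tokenize_prompt(prompt: str) -> dict:
    # Append the EOS token text so the model is trained to emit it and
    # inference stops instead of rambling past the end of the response.
    result = tokenizer(
        prompt + tokenizer.eos_token,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
    )
    # Guard: if the last token is still not the EOS id (the tokenizer may not
    # map the appended text to the special id) and the sequence is under the
    # cutoff, append the id and extend the attention mask to match.
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < cutoff_len
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)
    return result


example = tokenize_prompt("### Instruction:\nSay hi.\n\n### Response:\nhi")
assert example["input_ids"][-1] == tokenizer.eos_token_id

The switch from padding="max_length" to padding=False presumably leaves padding to a later batching/collation step; the explicit append covers the case where the appended EOS text is not encoded as the final token while the example still fits under the cutoff.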