add padding support as an option
This commit is contained in:
parent
cff57ebfa4
commit
2a1cb42966
@@ -13,7 +13,7 @@ class Finetune4bConfig:
         gradient_checkpointing: bool,
         gradient_checkpointing_ratio: float,
         warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
-        checkpoint: bool, skip: bool, txt_row_thd: int, groupsize: int
+        checkpoint: bool, skip: bool, txt_row_thd: int, use_eos_token: bool, groupsize: int
     ):
         """
         Args:
@@ -41,6 +41,7 @@ class Finetune4bConfig:
             checkpoint (bool): Produce checkpoint instead of LoRA
             skip (bool): Don't train model
             txt_row_thd (int): Custom row thd for txt file
+            use_eos_token (bool): Use Eos token instead of padding with 0
             groupsize (int): Group size of V2 model, use -1 to load V1 model
         """
         self.dataset = dataset
@@ -68,6 +69,7 @@ class Finetune4bConfig:
         self.checkpoint = checkpoint
         self.skip = skip
         self.txt_row_thd = txt_row_thd
+        self.use_eos_token = use_eos_token
         self.world_size = int(os.environ.get("WORLD_SIZE", 1))
         self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
         self.ddp = self.world_size != 1
@@ -52,8 +52,11 @@ def parse_commandline():
     parser_training.add_argument("--logging_steps", default=10, type=int, help="Default: %(default)s")
     parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s")
     parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s")
-    parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")
+
+    # Data args
+    parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")
+    parser_training.add_argument("--use_eos_token", default=1, type=int, help="Use eos token instead if padding with 0. enable with 1, disable with 0.")
 
     # V2 model support
     parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")
 
@@ -87,5 +90,6 @@ def get_config() -> Finetune4bConfig:
         checkpoint=args["checkpoint"],
         skip=args["skip"],
         txt_row_thd=args["txt_row_thd"],
+        use_eos_token=args["use_eos_token"]!=0,
         groupsize=args["groupsize"]
     )
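Note (not part of the commit): the new --use_eos_token option is an integer switch rather than a store_true flag, and get_config() turns it into a bool with args["use_eos_token"] != 0. A minimal standalone sketch of that conversion, with an illustrative argv:

import argparse

# Hypothetical reproduction of the flag handling added above.
parser = argparse.ArgumentParser()
parser.add_argument("--use_eos_token", default=1, type=int,
                    help="Use eos token instead of padding with 0. 1 enables, 0 disables.")

args = vars(parser.parse_args(["--use_eos_token", "0"]))  # user passes 0 to disable
use_eos_token = args["use_eos_token"] != 0                # the int becomes the config bool
print(use_eos_token)                                      # -> False: fall back to zero-padding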
@@ -85,7 +85,7 @@ if not ft_config.skip:
         data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
     else:
         raise NotImplementedError("ERROR: Unknown dataset format")
-    data.prepare_data(thd=ft_config.txt_row_thd)
+    data.prepare_data(thd=ft_config.txt_row_thd, use_eos_token=ft_config.use_eos_token)
     ####
 
     # Use gradient checkpointing
train_data.py
@@ -49,20 +49,37 @@ class TrainTxt(ATrainData):
         self.cutoff_len = cutoff_len
         self.exceed_count = 0
 
-    def tokenize(self, prompt: str) -> Dict[str, Any]:
+    def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]:
         # there's probably a way to do this with the tokenizer settings
         # but again, gotta move fast
         prompt = prompt['input']
-        result = self.tokenizer(
-            prompt,
-            truncation=True,
-            max_length=self.cutoff_len + 1,
-            padding="max_length",
-        )
-        d = {
-            "input_ids": result["input_ids"][:-1],
-            "attention_mask": result["attention_mask"][:-1],
-        }
+        if use_eos_token:
+            result = self.tokenizer(
+                prompt + self.tokenizer.eos_token,
+                truncation=True,
+                max_length=self.cutoff_len,
+                padding=False,
+            )
+            d = {
+                "input_ids": result["input_ids"],
+                "attention_mask": result["attention_mask"],
+            }
+            if (
+                d["input_ids"][-1] != self.tokenizer.eos_token_id
+                and len(d["input_ids"]) < self.cutoff_len
+            ):
+                d["input_ids"].append(self.tokenizer.eos_token_id)
+                d["attention_mask"].append(1)
+        else:
+            result = self.tokenizer(
+                prompt,
+                truncation=True,
+                max_length=self.cutoff_len + 1,
+                padding="max_length",
+            )
+            d = {
+                "input_ids": result["input_ids"][:-1],
+                "attention_mask": result["attention_mask"][:-1],
+            }
         if sum(d['attention_mask']) >= self.cutoff_len:
             self.exceed_count += 1
         return d
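For context on what the two branches above produce, here is a minimal sketch (not part of the commit) contrasting the EOS-terminated path with the pad-to-max_length path. It assumes a HuggingFace tokenizer; "gpt2" and pad_token_id = 0 are illustrative stand-ins, not the model or settings this repo targets.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint only
tokenizer.pad_token_id = 0                         # assumption: pad with id 0, as the option name implies
cutoff_len = 16
prompt = "Hello world"

# use_eos_token=True: append EOS, no padding -> short, variable-length sequences
with_eos = tokenizer(prompt + tokenizer.eos_token, truncation=True,
                     max_length=cutoff_len, padding=False)

# use_eos_token=False: pad to cutoff_len + 1 with the pad id, then drop the last token
padded = tokenizer(prompt, truncation=True,
                   max_length=cutoff_len + 1, padding="max_length")
padded = {"input_ids": padded["input_ids"][:-1],
          "attention_mask": padded["attention_mask"][:-1]}

print(len(with_eos["input_ids"]), len(padded["input_ids"]))  # e.g. 3 vs 16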
@@ -84,7 +101,7 @@ class TrainTxt(ATrainData):
                 r_b = ''
         return new_rows
 
-    def prepare_data(self, thd=-1, **kwargs):
+    def prepare_data(self, thd=-1, use_eos_token=True, **kwargs):
         if os.path.isdir(self.dataset):
             rows = []
             for filename in os.listdir(self.dataset):
@@ -100,7 +117,7 @@ class TrainTxt(ATrainData):
         if thd != -1:
             rows = self.format_new_rows(rows, thd=thd)
         data = Dataset.from_dict({"input": rows})
-        data = data.shuffle().map(lambda x: self.tokenize(x))
+        data = data.shuffle().map(lambda x: self.tokenize(x["input"], use_eos_token=use_eos_token))
         print('Train Data: {:.2f}%'.format(self.exceed_count / len(data) * 100), 'outliers')
         self.train_data = data
 
@@ -110,38 +127,50 @@ class TrainSAD(ATrainData):
     def __init__(self, dataset: str, val_set_size: int, tokenizer, cutoff_len) -> None:
         super().__init__(dataset, val_set_size, tokenizer, cutoff_len)
 
-    def tokenize(self, prompt: str) -> Dict[str, Any]:
+    def tokenize(self, prompt: str, use_eos_token=True, **kwargs) -> Dict[str, Any]:
         # there's probably a way to do this with the tokenizer settings
         # but again, gotta move fast
-        result = self.tokenizer(
-            prompt + self.tokenizer.eos_token,
-            truncation=True,
-            max_length=self.cutoff_len,
-            padding=False,
-        )
-        if (
-            result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.cutoff_len
-        ):
-            result["input_ids"].append(tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-        return result
+        if use_eos_token:
+            result = self.tokenizer(
+                prompt + self.tokenizer.eos_token,
+                truncation=True,
+                max_length=self.cutoff_len,
+                padding=False,
+            )
+            if (
+                result["input_ids"][-1] != self.tokenizer.eos_token_id
+                and len(result["input_ids"]) < self.cutoff_len
+            ):
+                result["input_ids"].append(self.tokenizer.eos_token_id)
+                result["attention_mask"].append(1)
+            return result
+        else:
+            result = self.tokenizer(
+                prompt,
+                truncation=True,
+                max_length=self.cutoff_len + 1,
+                padding="max_length",
+            )
+            return {
+                "input_ids": result["input_ids"][:-1],
+                "attention_mask": result["attention_mask"][:-1],
+            }
 
-    def prepare_data(self, **kwargs) -> None:
+    def prepare_data(self, use_eos_token=True, **kwargs) -> None:
         data = load_dataset("json", data_files=self.dataset)
 
         if self.val_set_size > 0:
             train_val = data["train"].train_test_split(
                 test_size=self.val_set_size, shuffle=True, seed=42 # ! Seed = 42 (?)
             )
-            self.train_data = train_val["train"].shuffle().map(self.generate_and_tokenize_prompt)
-            self.val_data = train_val["test"].shuffle().map(self.generate_and_tokenize_prompt)
+            self.train_data = train_val["train"].shuffle().map(lambda x: self.generate_and_tokenize_prompt(x, use_eos_token=use_eos_token))
+            self.val_data = train_val["test"].shuffle().map(lambda x: self.generate_and_tokenize_prompt(x, use_eos_token=use_eos_token))
         else:
-            self.train_data = data["train"].shuffle().map(self.generate_and_tokenize_prompt)
+            self.train_data = data["train"].shuffle().map(lambda x: self.generate_and_tokenize_prompt(x, use_eos_token=use_eos_token))
             self.val_data = None
 
     # Auxiliary methods
-    def generate_prompt(self, data_point):
+    def generate_prompt(self, data_point, **kwargs):
         return "{0}\n\n{1}\n{2}\n\n{3}\n{4}\n\n{5}\n{6}".format(
             "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
             "### Instruction:",
@@ -152,6 +181,6 @@ class TrainSAD(ATrainData):
             data_point["output"]
         )
 
-    def generate_and_tokenize_prompt(self, data_point):
-        prompt = self.generate_prompt(data_point)
-        return self.tokenize(prompt)
+    def generate_and_tokenize_prompt(self, data_point, **kwargs):
+        prompt = self.generate_prompt(data_point, **kwargs)
+        return self.tokenize(prompt, **kwargs)
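Note (not part of the commit): throughout train_data.py the .map() calls switch from passing the bound method directly to wrapping it in a lambda, because Dataset.map() hands the function only a single example, so extra options such as use_eos_token must be closed over (functools.partial would work equally well). A minimal sketch of that pattern with the datasets library, using toy data and a toy tokenize function:

from datasets import Dataset

def tokenize(example, use_eos_token=True):
    # toy stand-in for the real tokenization: just report the whitespace token count
    text = example["input"] + (" </s>" if use_eos_token else "")
    return {"n_tokens": len(text.split())}

data = Dataset.from_dict({"input": ["hello world", "good morning"]})
# the lambda threads the extra keyword argument through to tokenize()
data = data.shuffle().map(lambda x: tokenize(x, use_eos_token=False))
print(data[0])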