update finetune data format
parent 8a6c8661df
commit 0768d0fdff
@@ -13,7 +13,7 @@ class Finetune4bConfig:
         gradient_checkpointing: bool,
         gradient_checkpointing_ratio: float,
         warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
-        checkpoint: bool, skip: bool, groupsize: int
+        checkpoint: bool, skip: bool, txt_row_thd: int, groupsize: int
     ):
         """
         Args:
@@ -40,6 +40,7 @@ class Finetune4bConfig:
             logging_steps (int): Logging steps
             checkpoint (bool): Produce checkpoint instead of LoRA
             skip (bool): Don't train model
+            txt_row_thd (int): Custom row thd for txt file
             groupsize (int): Group size of V2 model, use -1 to load V1 model
         """
         self.dataset = dataset
@@ -66,6 +67,7 @@ class Finetune4bConfig:
         self.logging_steps = logging_steps
         self.checkpoint = checkpoint
         self.skip = skip
+        self.txt_row_thd = txt_row_thd
         self.world_size = int(os.environ.get("WORLD_SIZE", 1))
         self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
         self.ddp = self.world_size != 1
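Note: the new txt_row_thd field travels through Finetune4bConfig next to the existing checkpoint/skip flags. A minimal sketch of its semantics, with a hypothetical stub standing in for the real constructor (which takes many more arguments):

    # Illustrative stub only; not the real Finetune4bConfig constructor.
    class ConfigStub:
        def __init__(self, txt_row_thd: int = -1):
            # -1 keeps one dataset row per text line; a positive value is a
            # character threshold used to merge short rows in TrainTxt.prepare_data.
            self.txt_row_thd = txt_row_thd

    print(ConfigStub(txt_row_thd=256).txt_row_thd)  # 256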
@@ -52,6 +52,7 @@ def parse_commandline():
     parser_training.add_argument("--logging_steps", default=10, type=int, help="Default: %(default)s")
     parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s")
     parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s")
+    parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")

     # V2 model support
     parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")
@@ -85,5 +86,6 @@ def get_config() -> Finetune4bConfig:
         logging_steps=args["logging_steps"],
         checkpoint=args["checkpoint"],
         skip=args["skip"],
+        txt_row_thd=args["txt_row_thd"],
         groupsize=args["groupsize"]
     )
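As a sanity check on the new flag, a standalone sketch that mirrors only the --txt_row_thd argument (the rest of parse_commandline() is assumed and omitted):

    import argparse

    # Re-creates just the argument added above.
    p = argparse.ArgumentParser()
    p.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")
    print(p.parse_args([]).txt_row_thd)                        # -1: merging disabled
    print(p.parse_args(["--txt_row_thd", "256"]).txt_row_thd)  # 256: merge short rows

The default of -1 is forwarded through get_config() unchanged, so omitting the flag preserves the old one-row-per-line behavior.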
@@ -62,6 +62,7 @@ else:
     model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32) # ! Direct copy from inference.py
     print(ft_config.lora_apply_dir, 'loaded')

+
 # Scales to half
 print('Fitting 4bit scales and zeros to half')
 for n, m in model.named_modules():
@@ -84,7 +85,7 @@ if not ft_config.skip:
         data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
     else:
         raise NotImplementedError("ERROR: Unknown dataset format")
-    data.prepare_data()
+    data.prepare_data(thd=ft_config.txt_row_thd)
     ####

 # Use gradient checkpointing
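The call site now forwards the configured threshold into the txt loader. A minimal sketch of that dispatch, using hypothetical stubs for ft_config and the data object:

    class _FtConfigStub:
        txt_row_thd = 256  # would come from the --txt_row_thd flag

    class _TrainTxtStub:
        def prepare_data(self, thd=-1):
            if thd != -1:
                print("merging short rows up to", thd, "chars")
            else:
                print("one dataset row per text line (old behavior)")

    _TrainTxtStub().prepare_data(thd=_FtConfigStub.txt_row_thd)  # merging path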
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Any
 from datasets import load_dataset, Dataset
+import os


 # Abstract train data loader
@@ -66,11 +67,38 @@ class TrainTxt(ATrainData):
             self.exceed_count += 1
         return d

-    def prepare_data(self):
-        with open(self.dataset, 'r', encoding='utf8') as file:
-            txt = file.read()
-        txt = txt.replace('\r\n', '\n')
-        rows = [r for r in txt.split('\n') if r != '']
+    @classmethod
+    def format_new_rows(cls, rows, thd=128):
+        r_b = ''
+        new_rows = []
+        for row in rows:
+            if len(r_b) == 0:
+                r_b += row
+            else:
+                r_b += '\n' + row
+            if len(r_b) > thd:
+                new_rows.append(r_b)
+                r_b = ''
+        if len(r_b) > thd:
+            new_rows.append(r_b)
+            r_b = ''
+        return new_rows
+
+    def prepare_data(self, thd=-1):
+        if os.path.isdir(self.dataset):
+            rows = []
+            for filename in os.listdir(self.dataset):
+                with open(self.dataset + filename, 'r', encoding='utf8') as file:
+                    txt = file.read()
+                txt = txt.replace('\r\n', '\n').replace('\u3000', ' ')
+                rows += [r for r in txt.split('\n') if r != '']
+        else:
+            with open(self.dataset, 'r', encoding='utf8') as file:
+                txt = file.read()
+            txt = txt.replace('\r\n', '\n')
+            rows = [r for r in txt.split('\n') if r != '']
+        if thd != -1:
+            rows = self.format_new_rows(rows, thd=thd)
         data = Dataset.from_dict({"input": rows})
         data = data.shuffle().map(lambda x: self.tokenize(x))
         print('Train Data: {:.2f}%'.format(self.exceed_count / len(data) * 100), 'outliers')
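To make the new merging behavior concrete: format_new_rows greedily joins consecutive lines with '\n' until the buffer grows past thd characters, then emits the buffer as one dataset row. A self-contained replay of that logic on made-up sample lines:

    def format_new_rows(rows, thd=128):
        # Same greedy accumulation as TrainTxt.format_new_rows above.
        r_b = ''
        new_rows = []
        for row in rows:
            r_b = row if not r_b else r_b + '\n' + row
            if len(r_b) > thd:
                new_rows.append(r_b)
                r_b = ''
        # As in the diff: the buffer can never exceed thd after the loop,
        # so a short trailing remainder is dropped rather than emitted.
        if len(r_b) > thd:
            new_rows.append(r_b)
        return new_rows

    rows = ["alpha " * 10, "beta " * 10, "gamma " * 10]  # three ~60-char lines
    print(len(rows), "->", len(format_new_rows(rows, thd=100)))  # 3 -> 1

With thd=-1 (the --txt_row_thd default) prepare_data skips this path entirely, so existing txt datasets are tokenized exactly as before; directory datasets additionally get '\u3000' (ideographic space) normalized to a plain space.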