update finetune data format

parent 8a6c8661df
commit 0768d0fdff

@@ -13,7 +13,7 @@ class Finetune4bConfig:
         gradient_checkpointing: bool,
         gradient_checkpointing_ratio: float,
         warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
-        checkpoint: bool, skip: bool, groupsize: int
+        checkpoint: bool, skip: bool, txt_row_thd: int, groupsize: int
     ):
         """
         Args:

@@ -40,6 +40,7 @@ class Finetune4bConfig:
             logging_steps (int): Logging steps
             checkpoint (bool): Produce checkpoint instead of LoRA
             skip (bool): Don't train model
+            txt_row_thd (int): Custom row thd for txt file
             groupsize (int): Group size of V2 model, use -1 to load V1 model
         """
         self.dataset = dataset

@@ -66,6 +67,7 @@ class Finetune4bConfig:
         self.logging_steps = logging_steps
         self.checkpoint = checkpoint
         self.skip = skip
+        self.txt_row_thd = txt_row_thd
         self.world_size = int(os.environ.get("WORLD_SIZE", 1))
         self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
         self.ddp = self.world_size != 1

@@ -52,6 +52,7 @@ def parse_commandline():
     parser_training.add_argument("--logging_steps", default=10, type=int, help="Default: %(default)s")
     parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s")
     parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s")
+    parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")
 
     # V2 model support
     parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")

@@ -85,5 +86,6 @@ def get_config() -> Finetune4bConfig:
         logging_steps=args["logging_steps"],
         checkpoint=args["checkpoint"],
         skip=args["skip"],
+        txt_row_thd=args["txt_row_thd"],
         groupsize=args["groupsize"]
     )

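For context, a minimal sketch of how the new option is meant to flow from the command line into the trainer, assuming the argparse wiring shown in the hunks above (only the flag name and its default come from this diff; everything else here is illustrative, not part of the commit):

    import argparse

    # Stand-in for parse_commandline(): only the newly added flag is shown.
    parser = argparse.ArgumentParser()
    parser.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")
    args = vars(parser.parse_args(["--txt_row_thd", "256"]))

    # get_config() then forwards it as Finetune4bConfig(..., txt_row_thd=args["txt_row_thd"], ...),
    # and the training script passes it on via data.prepare_data(thd=ft_config.txt_row_thd).
    print(args["txt_row_thd"])  # 256
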
@@ -62,6 +62,7 @@ else:
     model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32) # ! Direct copy from inference.py
     print(ft_config.lora_apply_dir, 'loaded')
 
+
 # Scales to half
 print('Fitting 4bit scales and zeros to half')
 for n, m in model.named_modules():

@@ -84,7 +85,7 @@ if not ft_config.skip:
         data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
     else:
         raise NotImplementedError("ERROR: Unknown dataset format")
-    data.prepare_data()
+    data.prepare_data(thd=ft_config.txt_row_thd)
     ####
 
     # Use gradient checkpointing

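A usage note on the changed call above (illustrative, not part of the commit): with the flag left at its default of -1, prepare_data receives thd=-1 and the thd != -1 guard added in train_data below skips row merging, so existing txt datasets are processed exactly as before; an explicit value such as --txt_row_thd 256 turns the merging on.

    # default (-1): keep raw txt rows, same behaviour as before this commit
    data.prepare_data(thd=ft_config.txt_row_thd)
    # explicit threshold: merge short rows via TrainTxt.format_new_rows before tokenizing
    # data.prepare_data(thd=256)
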
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Any
 from datasets import load_dataset, Dataset
+import os
 
 
 # Abstract train data loader

@@ -66,11 +67,38 @@ class TrainTxt(ATrainData):
             self.exceed_count += 1
         return d
 
-    def prepare_data(self):
-        with open(self.dataset, 'r', encoding='utf8') as file:
-            txt = file.read()
-        txt = txt.replace('\r\n', '\n')
-        rows = [r for r in txt.split('\n') if r != '']
+    @classmethod
+    def format_new_rows(cls, rows, thd=128):
+        r_b = ''
+        new_rows = []
+        for row in rows:
+            if len(r_b) == 0:
+                r_b += row
+            else:
+                r_b += '\n' + row
+            if len(r_b) > thd:
+                new_rows.append(r_b)
+                r_b = ''
+        if len(r_b) > thd:
+            new_rows.append(r_b)
+            r_b = ''
+        return new_rows
+
+    def prepare_data(self, thd=-1):
+        if os.path.isdir(self.dataset):
+            rows = []
+            for filename in os.listdir(self.dataset):
+                with open(self.dataset + filename, 'r', encoding='utf8') as file:
+                    txt = file.read()
+                txt = txt.replace('\r\n', '\n').replace('\u3000', ' ')
+                rows += [r for r in txt.split('\n') if r != '']
+        else:
+            with open(self.dataset, 'r', encoding='utf8') as file:
+                txt = file.read()
+            txt = txt.replace('\r\n', '\n')
+            rows = [r for r in txt.split('\n') if r != '']
+        if thd != -1:
+            rows = self.format_new_rows(rows, thd=thd)
         data = Dataset.from_dict({"input": rows})
         data = data.shuffle().map(lambda x: self.tokenize(x))
         print('Train Data: {:.2f}%'.format(self.exceed_count / len(data) * 100), 'outliers')

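To make the new row-merging behaviour concrete, here is a small self-contained sketch of the same logic as a standalone function with made-up sample rows (the body mirrors format_new_rows above; only the helper name and the example data are invented):

    def merge_rows(rows, thd=128):
        # Concatenate consecutive rows with '\n' until the buffer grows past thd characters.
        r_b = ''
        new_rows = []
        for row in rows:
            r_b = row if len(r_b) == 0 else r_b + '\n' + row
            if len(r_b) > thd:
                new_rows.append(r_b)
                r_b = ''
        # As in the commit, a trailing buffer is kept only if it also exceeds thd.
        if len(r_b) > thd:
            new_rows.append(r_b)
        return new_rows

    rows = ['short line', 'another short line', 'x' * 40]
    print(merge_rows(rows, thd=20))
    # ['short line\nanother short line', 'xxxx...x']  (each merged chunk is longer than 20 chars)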