diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py
index b583e4a..fa01c1f 100644
--- a/Finetune4bConfig.py
+++ b/Finetune4bConfig.py
@@ -13,7 +13,7 @@ class Finetune4bConfig:
                  gradient_checkpointing: bool, gradient_checkpointing_ratio: float,
                  warmup_steps: int, save_steps: int, save_total_limit: int,
                  logging_steps: int,
-                 checkpoint: bool, skip: bool, groupsize: int
+                 checkpoint: bool, skip: bool, txt_row_thd: int, groupsize: int
     ):
         """
         Args:
@@ -40,6 +40,7 @@ class Finetune4bConfig:
             logging_steps (int): Logging steps
             checkpoint (bool): Produce checkpoint instead of LoRA
             skip (bool): Don't train model
+            txt_row_thd (int): Custom row thd for txt file
             groupsize (int): Group size of V2 model, use -1 to load V1 model
         """
         self.dataset = dataset
@@ -66,6 +67,7 @@ class Finetune4bConfig:
         self.logging_steps = logging_steps
         self.checkpoint = checkpoint
         self.skip = skip
+        self.txt_row_thd = txt_row_thd
         self.world_size = int(os.environ.get("WORLD_SIZE", 1))
         self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
         self.ddp = self.world_size != 1
diff --git a/arg_parser.py b/arg_parser.py
index 0f80eb4..f77456e 100644
--- a/arg_parser.py
+++ b/arg_parser.py
@@ -52,6 +52,7 @@ def parse_commandline():
     parser_training.add_argument("--logging_steps", default=10, type=int, help="Default: %(default)s")
     parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s")
     parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s")
+    parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")
 
     # V2 model support
     parser_training.add_argument("--groupsize", type=int, default=-1, help="Groupsize of v2 model, use -1 to load v1 model")
@@ -85,5 +86,6 @@ def get_config() -> Finetune4bConfig:
         logging_steps=args["logging_steps"],
         checkpoint=args["checkpoint"],
         skip=args["skip"],
+        txt_row_thd=args["txt_row_thd"],
         groupsize=args["groupsize"]
     )
diff --git a/finetune.py b/finetune.py
index 36cfaea..f4c8073 100644
--- a/finetune.py
+++ b/finetune.py
@@ -62,6 +62,7 @@ else:
     model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32) # ! Direct copy from inference.py
     print(ft_config.lora_apply_dir, 'loaded')
 
 
+# Scales to half
 print('Fitting 4bit scales and zeros to half')
 for n, m in model.named_modules():
@@ -84,7 +85,7 @@ if not ft_config.skip:
         data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
     else:
         raise NotImplementedError("ERROR: Unknown dataset format")
-    data.prepare_data()
+    data.prepare_data(thd=ft_config.txt_row_thd)
     ####
 
     # Use gradient checkpointing
diff --git a/train_data.py b/train_data.py
index d2d30a8..62dee6d 100644
--- a/train_data.py
+++ b/train_data.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Any
 from datasets import load_dataset, Dataset
+import os
 
 
 # Abstract train data loader
@@ -66,11 +67,38 @@ class TrainTxt(ATrainData):
             self.exceed_count += 1
         return d
 
-    def prepare_data(self):
-        with open(self.dataset, 'r', encoding='utf8') as file:
-            txt = file.read()
-        txt = txt.replace('\r\n', '\n')
-        rows = [r for r in txt.split('\n') if r != '']
+    @classmethod
+    def format_new_rows(cls, rows, thd=128):
+        r_b = ''
+        new_rows = []
+        for row in rows:
+            if len(r_b) == 0:
+                r_b += row
+            else:
+                r_b += '\n' + row
+            if len(r_b) > thd:
+                new_rows.append(r_b)
+                r_b = ''
+        if len(r_b) > thd:
+            new_rows.append(r_b)
+            r_b = ''
+        return new_rows
+
+    def prepare_data(self, thd=-1):
+        if os.path.isdir(self.dataset):
+            rows = []
+            for filename in os.listdir(self.dataset):
+                with open(self.dataset + filename, 'r', encoding='utf8') as file:
+                    txt = file.read()
+                txt = txt.replace('\r\n', '\n').replace('\u3000', ' ')
+                rows += [r for r in txt.split('\n') if r != '']
+        else:
+            with open(self.dataset, 'r', encoding='utf8') as file:
+                txt = file.read()
+            txt = txt.replace('\r\n', '\n')
+            rows = [r for r in txt.split('\n') if r != '']
+        if thd != -1:
+            rows = self.format_new_rows(rows, thd=thd)
         data = Dataset.from_dict({"input": rows})
         data = data.shuffle().map(lambda x: self.tokenize(x))
         print('Train Data: {:.2f}%'.format(self.exceed_count / len(data) * 100), 'outliers')
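
For illustration, a minimal usage sketch (not part of the patch) of the new text-chunking path. It assumes train_data.py from this repo is importable from the working directory; the sample rows and the threshold of 30 are made-up values:

    # format_new_rows is a classmethod, so no tokenizer or dataset path is needed here.
    from train_data import TrainTxt

    rows = ["first line", "second line", "a somewhat longer third line", "tail"]
    # Consecutive rows are joined with '\n' until the buffer exceeds thd characters,
    # then the buffer is emitted as one training row; a leftover buffer no longer
    # than thd is dropped.
    print(TrainTxt.format_new_rows(rows, thd=30))
    # -> ['first line\nsecond line\na somewhat longer third line']

During training this path is enabled by passing, e.g., --txt_row_thd 256 to finetune.py, which forwards it via data.prepare_data(thd=ft_config.txt_row_thd); the default of -1 keeps the previous one-row-per-line behavior.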