add resume checkpoint to continue a training run

This commit is contained in:
John Smith 2023-03-29 14:35:39 +08:00
parent 2a1cb42966
commit 1c02d4262d
3 changed files with 22 additions and 3 deletions

Changed file 1 of 3

@@ -3,7 +3,7 @@ class Finetune4bConfig:
"""Config holder for LLaMA 4bit finetuning
"""
def __init__(self, dataset: str, ds_type: str,
lora_out_dir: str, lora_apply_dir : str,
lora_out_dir: str, lora_apply_dir: str, resume_checkpoint: str,
llama_q4_config_dir: str, llama_q4_model: str,
mbatch_size: int, batch_size: int,
epochs: int, lr: float,
@@ -13,7 +13,8 @@ class Finetune4bConfig:
gradient_checkpointing: bool,
gradient_checkpointing_ratio: float,
warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
checkpoint: bool, skip: bool, txt_row_thd: int, use_eos_token: bool, groupsize: int
checkpoint: bool, skip: bool, verbose: bool,
txt_row_thd: int, use_eos_token: bool, groupsize: int
):
"""
Args:
@@ -21,6 +22,7 @@ class Finetune4bConfig:
ds_type (str): Dataset structure format
lora_out_dir (str): Directory to place new LoRA
lora_apply_dir (str): Path to directory from which LoRA has to be applied before training
resume_checkpoint (str): Path to the checkpoint to resume training from
llama_q4_config_dir (str): Path to the config.json, tokenizer_config.json, etc
llama_q4_model (str): Path to the quantized model in huggingface format
mbatch_size (int): Micro-batch size
@@ -40,6 +42,7 @@ class Finetune4bConfig:
logging_steps (int): Logging steps
checkpoint (bool): Produce checkpoint instead of LoRA
skip (bool): Don't train model
verbose (bool): Whether to output the training log
txt_row_thd (int): Custom row thd for txt file
use_eos_token (bool): Use Eos token instead of padding with 0
groupsize (int): Group size of V2 model, use -1 to load V1 model
@@ -48,6 +51,7 @@ class Finetune4bConfig:
self.ds_type = ds_type
self.lora_out_dir = lora_out_dir
self.lora_apply_dir = lora_apply_dir
self.resume_checkpoint = resume_checkpoint
self.llama_q4_config_dir = llama_q4_config_dir
self.llama_q4_model = llama_q4_model
self.mbatch_size = mbatch_size
@@ -68,6 +72,7 @@ class Finetune4bConfig:
self.logging_steps = logging_steps
self.checkpoint = checkpoint
self.skip = skip
self.verbose = verbose
self.txt_row_thd = txt_row_thd
self.use_eos_token = use_eos_token
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
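For orientation, here is a minimal stand-in sketch of what the two new fields are meant to hold; ResumeFields, the sample checkpoint path, and the usage below are illustrative assumptions, not part of the commit (the real Finetune4bConfig takes many more arguments):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ResumeFields:
    # Hypothetical stand-in for the two fields added to Finetune4bConfig.
    resume_checkpoint: Optional[str] = None  # checkpoint directory to resume from; None starts fresh
    verbose: bool = False                    # emit transformers training logs at INFO verbosity

cfg = ResumeFields(resume_checkpoint="./lora-out/checkpoint-500", verbose=True)
if cfg.resume_checkpoint:
    print("resuming from", cfg.resume_checkpoint)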

Changed file 2 of 3

@@ -27,6 +27,9 @@ def parse_commandline():
parser_config.add_argument("--lora_apply_dir", default=None, required=False,
help="Path to directory from which LoRA has to be applied before training. Default: %(default)s"
)
parser_training.add_argument("--resume_checkpoint", default=None, required=False,
help="Resume training from specified checkpoint. Default: %(default)s"
)
parser_config.add_argument("--llama_q4_config_dir", default="./llama-13b-4bit/", required=False,
help="Path to the config.json, tokenizer_config.json, etc. Default: %(default)s"
)
@@ -52,6 +55,7 @@ def parse_commandline():
parser_training.add_argument("--logging_steps", default=10, type=int, help="Default: %(default)s")
parser_training.add_argument("-c", "--checkpoint", action="store_true", help="Produce checkpoint instead of LoRA. Default: %(default)s")
parser_training.add_argument("--skip", action="store_true", help="Don't train model. Can be useful to produce checkpoint from existing LoRA. Default: %(default)s")
parser_training.add_argument("--verbose", action="store_true", help="If output log of training. Default: %(default)s")
# Data args
parser_training.add_argument("--txt_row_thd", default=-1, type=int, help="Custom thd for txt rows.")
@@ -70,6 +74,7 @@ def get_config() -> Finetune4bConfig:
ds_type=args["ds_type"],
lora_out_dir=args["lora_out_dir"],
lora_apply_dir=args["lora_apply_dir"],
resume_checkpoint=args["resume_checkpoint"],
llama_q4_config_dir=args["llama_q4_config_dir"],
llama_q4_model=args["llama_q4_model"],
mbatch_size=args["mbatch_size"],
@@ -89,6 +94,7 @@ def get_config() -> Finetune4bConfig:
logging_steps=args["logging_steps"],
checkpoint=args["checkpoint"],
skip=args["skip"],
verbose=args["verbose"],
txt_row_thd=args["txt_row_thd"],
use_eos_token=args["use_eos_token"]!=0,
groupsize=args["groupsize"]
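As a sanity check, a self-contained argparse sketch mirroring how the two new flags are parsed above; the simulated argv list and the checkpoint path are made up for illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--resume_checkpoint", default=None,
                    help="Resume training from the specified checkpoint. Default: %(default)s")
parser.add_argument("--verbose", action="store_true",
                    help="Output the training log. Default: %(default)s")

# Simulated command line; in the repo these flags go to the finetune entry point.
args = vars(parser.parse_args(["--resume_checkpoint", "./lora-out/checkpoint-500", "--verbose"]))
assert args["resume_checkpoint"] == "./lora-out/checkpoint-500"
assert args["verbose"] is True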

Changed file 3 of 3

@@ -130,8 +130,16 @@ if not ft_config.skip:
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
# Set Verbose
if ft_config.verbose:
    transformers.logging.set_verbosity_info()
# Run Trainer
trainer.train()
if ft_config.resume_checkpoint:
    print('Resuming from {} ...'.format(ft_config.resume_checkpoint))
    trainer.train(ft_config.resume_checkpoint)
else:
    trainer.train()
print('Train completed.')
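trainer.train() only needs the path of a checkpoint-&lt;step&gt; folder saved by an earlier run, so a small helper like the hedged sketch below (not part of the commit; latest_checkpoint is a hypothetical name) can pick the newest one under the LoRA output directory to pass as --resume_checkpoint:

import os
import re
from typing import Optional

def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the newest checkpoint-<step> folder under output_dir, or None."""
    if not os.path.isdir(output_dir):
        return None
    steps = []
    for name in os.listdir(output_dir):
        m = re.fullmatch(r"checkpoint-(\d+)", name)
        if m and os.path.isdir(os.path.join(output_dir, name)):
            steps.append((int(m.group(1)), name))
    if not steps:
        return None
    return os.path.join(output_dir, max(steps)[1])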