diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py
index 62f8091..c074028 100644
--- a/Finetune4bConfig.py
+++ b/Finetune4bConfig.py
@@ -1,16 +1,18 @@
 class Finetune4bConfig:
     """Config holder for LLaMA 4bit finetuning
     """
-    def __init__(self, dataset : str, ds_type : str,
-                 lora_out_dir : str, lora_apply_dir : str,
-                 llama_q4_config_dir : str, llama_q4_model : str,
-                 mbatch_size : int, batch_size : int,
-                 epochs : int, lr : float,
-                 cutoff_len : int,
-                 lora_r : int, lora_alpha : int, lora_dropout : float,
-                 val_set_size : float,
-                 warmup_steps : int, save_steps : int, save_total_limit : int, logging_steps : int,
-                 checkpoint : bool, skip : bool
+    def __init__(self, dataset: str, ds_type: str,
+                 lora_out_dir: str, lora_apply_dir : str,
+                 llama_q4_config_dir: str, llama_q4_model: str,
+                 mbatch_size: int, batch_size: int,
+                 epochs: int, lr: float,
+                 cutoff_len: int,
+                 lora_r: int, lora_alpha: int, lora_dropout: float,
+                 val_set_size: float,
+                 gradient_checkpointing: bool,
+                 gradient_checkpointing_ratio: float,
+                 warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
+                 checkpoint: bool, skip: bool
     ):
         """
         Args:
@@ -28,6 +30,8 @@ class Finetune4bConfig:
             lora_r (int): LoRA R
            lora_alpha (int): LoRA Alpha
            lora_dropout (float): LoRA Dropout
+            gradient_checkpointing (bool) : Use gradient checkpointing
+            gradient_checkpointing_ratio (float) : Gradient checkpoint ratio
             val_set_size (int): Validation set size
             warmup_steps (int): Warmup steps before training
             save_steps (int): Save steps
@@ -50,8 +54,10 @@ class Finetune4bConfig:
         self.cutoff_len = cutoff_len
         self.lora_r = lora_r
         self.lora_alpha = lora_alpha
-        self.lora_dropout = lora_dropout
+        self.lora_dropout = 0 if gradient_checkpointing else lora_dropout # should be 0 if gradient checkpointing is on
         self.val_set_size = int(val_set_size) if val_set_size > 1.0 else float(val_set_size)
+        self.gradient_checkpointing = gradient_checkpointing
+        self.gradient_checkpointing_ratio = gradient_checkpointing_ratio
         self.warmup_steps = warmup_steps
         self.save_steps = save_steps
         self.save_total_limit = save_total_limit
@@ -61,9 +67,12 @@
 
 
     def __str__(self) -> str:
-        return f"\nParameters:\n{'config':-^20}\n{self.dataset=}\n{self.ds_type=}\n{self.lora_out_dir=}\n{self.lora_apply_dir=}\n{self.llama_q4_config_dir=}\n{self.llama_q4_model=}\n\n" +\
+        s = f"\nParameters:\n{'config':-^20}\n{self.dataset=}\n{self.ds_type=}\n{self.lora_out_dir=}\n{self.lora_apply_dir=}\n{self.llama_q4_config_dir=}\n{self.llama_q4_model=}\n\n" +\
             f"{'training':-^20}\n" +\
             f"{self.mbatch_size=}\n{self.batch_size=}\n{self.gradient_accumulation_steps=}\n{self.epochs=}\n{self.lr=}\n{self.cutoff_len=}\n" +\
-            f"{self.lora_r=}\n{self.lora_alpha=}\n{self.lora_dropout=}\n{self.val_set_size=}\n{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
+            f"{self.lora_r=}\n{self.lora_alpha=}\n{self.lora_dropout=}\n{self.val_set_size=}\n" +\
+            f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
+            f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
             f"{self.logging_steps=}\n" +\
             f"{self.checkpoint=}\n{self.skip=}"
+        return s.replace("self.", "")
diff --git a/arg_parser.py b/arg_parser.py
index 9794e57..bb25086 100644
--- a/arg_parser.py
+++ b/arg_parser.py
@@ -43,6 +43,8 @@ def parse_commandline():
     parser_training.add_argument("--lora_r", default=8, type=int, help="Default: %(default)s")
     parser_training.add_argument("--lora_alpha", default=16, type=int, help="Default: %(default)s")
     parser_training.add_argument("--lora_dropout", default=0.05, type=float, help="Default: %(default)s")
+    parser_training.add_argument("--grad_chckpt", action="store_true", required=False, help="Use gradient checkpoint. For 30B model. Default: %(default)s")
+    parser_training.add_argument("--grad_chckpt_ratio", default=1, type=float, help="Gradient checkpoint ratio. Default: %(default)s")
     parser_training.add_argument("--val_set_size", default=0.2, type=float, help="Validation set size. Default: %(default)s")
     parser_training.add_argument("--warmup_steps", default=50, type=int, help="Default: %(default)s")
     parser_training.add_argument("--save_steps", default=50, type=int, help="Default: %(default)s")
@@ -72,6 +74,8 @@ def get_config() -> Finetune4bConfig:
         lora_alpha=args["lora_alpha"],
         lora_dropout=args["lora_dropout"],
         val_set_size=args["val_set_size"],
+        gradient_checkpointing=args["grad_chckpt"],
+        gradient_checkpointing_ratio=args["grad_chckpt_ratio"],
         warmup_steps=args["warmup_steps"],
         save_steps=args["save_steps"],
         save_total_limit=args["save_total_limit"],
diff --git a/finetune.py b/finetune.py
index b8809f6..41dc376 100644
--- a/finetune.py
+++ b/finetune.py
@@ -40,6 +40,8 @@ ft_config = get_config()
 
 # * Show loaded parameters
 print(f"{ft_config}\n")
+if ft_config.gradient_checkpointing:
+    print('Disable Dropout.')
 
 # Load Basic Model
 model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model)
@@ -70,7 +72,6 @@ for n, m in model.named_modules():
 tokenizer.pad_token_id = 0
 
 if not ft_config.skip:
-    # ! TODO: Refactor to load both SAD and LLAMA datasets
     # Load Data
     data = None
     match ft_config.ds_type:
@@ -85,6 +86,12 @@ if not ft_config.skip:
         data.prepare_data()
     ####
 
+    # Use gradient checkpointing
+    if ft_config.gradient_checkpointing:
+        print('Applying gradient checkpointing ...')
+        from gradient_checkpointing import apply_gradient_checkpointing
+        apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data.train_data,
@@ -109,19 +116,21 @@ if not ft_config.skip:
     )
     model.config.use_cache = False
 
+    # Set Model dict
+    old_state_dict = model.state_dict
+    model.state_dict = (
+        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
+    ).__get__(model, type(model))
 
-# Set Model dict
-old_state_dict = model.state_dict
-model.state_dict = (
-    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
-).__get__(model, type(model))
+    # Run Trainer
+    trainer.train()
 
-# Run Trainer
-trainer.train()
+    print('Train completed.')
 
-print('Train completed.')
-
-# Save Model
-model.save_pretrained(ft_config.lora_out_dir)
+if not ft_config.checkpoint:
+    # Save Model
+    model.save_pretrained(ft_config.lora_out_dir)
+else:
+    raise NotImplementedError("TODO: Merge model + LoRA and save the whole checkpoint")
 print('Model Saved.')