Reflect the latest changes from main

Reflect commits:
4906961bf1
60b227d0ba
Andrey Glushenkov 2023-03-24 15:46:03 +03:00 committed by GitHub
parent 50dbb101e9
commit 397f5041c3
3 changed files with 47 additions and 25 deletions

View File

@@ -1,16 +1,18 @@
 class Finetune4bConfig:
     """Config holder for LLaMA 4bit finetuning
     """
-    def __init__(self, dataset : str, ds_type : str,
-                 lora_out_dir : str, lora_apply_dir : str,
-                 llama_q4_config_dir : str, llama_q4_model : str,
-                 mbatch_size : int, batch_size : int,
-                 epochs : int, lr : float,
-                 cutoff_len : int,
-                 lora_r : int, lora_alpha : int, lora_dropout : float,
-                 val_set_size : float,
-                 warmup_steps : int, save_steps : int, save_total_limit : int, logging_steps : int,
-                 checkpoint : bool, skip : bool
+    def __init__(self, dataset: str, ds_type: str,
+                 lora_out_dir: str, lora_apply_dir : str,
+                 llama_q4_config_dir: str, llama_q4_model: str,
+                 mbatch_size: int, batch_size: int,
+                 epochs: int, lr: float,
+                 cutoff_len: int,
+                 lora_r: int, lora_alpha: int, lora_dropout: float,
+                 val_set_size: float,
+                 gradient_checkpointing: bool,
+                 gradient_checkpointing_ratio: float,
+                 warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
+                 checkpoint: bool, skip: bool
         ):
         """
         Args:
@@ -28,6 +30,8 @@ class Finetune4bConfig:
             lora_r (int): LoRA R
             lora_alpha (int): LoRA Alpha
             lora_dropout (float): LoRA Dropout
+            gradient_checkpointing (bool) : Use gradient checkpointing
+            gradient_checkpointing_ratio (float) : Gradient checkpoint ratio
             val_set_size (int): Validation set size
             warmup_steps (int): Warmup steps before training
             save_steps (int): Save steps
@@ -50,8 +54,10 @@
         self.cutoff_len = cutoff_len
         self.lora_r = lora_r
         self.lora_alpha = lora_alpha
-        self.lora_dropout = lora_dropout
+        self.lora_dropout = 0 if gradient_checkpointing else lora_dropout # should be 0 if gradient checkpointing is on
         self.val_set_size = int(val_set_size) if val_set_size > 1.0 else float(val_set_size)
+        self.gradient_checkpointing = gradient_checkpointing
+        self.gradient_checkpointing_ratio = gradient_checkpointing_ratio
         self.warmup_steps = warmup_steps
         self.save_steps = save_steps
         self.save_total_limit = save_total_limit
@@ -61,9 +67,12 @@
     def __str__(self) -> str:
-        return f"\nParameters:\n{'config':-^20}\n{self.dataset=}\n{self.ds_type=}\n{self.lora_out_dir=}\n{self.lora_apply_dir=}\n{self.llama_q4_config_dir=}\n{self.llama_q4_model=}\n\n" +\
+        s = f"\nParameters:\n{'config':-^20}\n{self.dataset=}\n{self.ds_type=}\n{self.lora_out_dir=}\n{self.lora_apply_dir=}\n{self.llama_q4_config_dir=}\n{self.llama_q4_model=}\n\n" +\
             f"{'training':-^20}\n" +\
             f"{self.mbatch_size=}\n{self.batch_size=}\n{self.gradient_accumulation_steps=}\n{self.epochs=}\n{self.lr=}\n{self.cutoff_len=}\n" +\
-            f"{self.lora_r=}\n{self.lora_alpha=}\n{self.lora_dropout=}\n{self.val_set_size=}\n{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
+            f"{self.lora_r=}\n{self.lora_alpha=}\n{self.lora_dropout=}\n{self.val_set_size=}\n" +\
+            f"{self.gradient_checkpointing=}\n{self.gradient_checkpointing_ratio=}\n" +\
+            f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
             f"{self.logging_steps=}\n" +\
             f"{self.checkpoint=}\n{self.skip=}"
+        return s.replace("self.", "")
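
The two derived fields touched above follow simple rules that the hunks make explicit: lora_dropout is forced to 0 whenever gradient checkpointing is enabled, and val_set_size is treated as an absolute sample count when it is greater than 1.0 and as a fraction of the dataset otherwise. A minimal standalone sketch of those two rules (the helper names here are illustrative, not part of the repository):

def resolve_lora_dropout(lora_dropout: float, gradient_checkpointing: bool) -> float:
    # Dropout is disabled when gradient checkpointing is on (see the comment in the hunk above).
    return 0 if gradient_checkpointing else lora_dropout

def resolve_val_set_size(val_set_size: float):
    # Values above 1.0 mean an absolute number of validation samples; otherwise a fraction.
    return int(val_set_size) if val_set_size > 1.0 else float(val_set_size)

assert resolve_lora_dropout(0.05, gradient_checkpointing=True) == 0
assert resolve_lora_dropout(0.05, gradient_checkpointing=False) == 0.05
assert resolve_val_set_size(200) == 200 and isinstance(resolve_val_set_size(200), int)
assert resolve_val_set_size(0.2) == 0.2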

View File

@@ -43,6 +43,8 @@ def parse_commandline():
     parser_training.add_argument("--lora_r", default=8, type=int, help="Default: %(default)s")
     parser_training.add_argument("--lora_alpha", default=16, type=int, help="Default: %(default)s")
     parser_training.add_argument("--lora_dropout", default=0.05, type=float, help="Default: %(default)s")
+    parser_training.add_argument("--grad_chckpt", action="store_true", required=False, help="Use gradient checkpoint. For 30B model. Default: %(default)s")
+    parser_training.add_argument("--grad_chckpt_ratio", default=1, type=float, help="Gradient checkpoint ratio. Default: %(default)s")
     parser_training.add_argument("--val_set_size", default=0.2, type=float, help="Validation set size. Default: %(default)s")
     parser_training.add_argument("--warmup_steps", default=50, type=int, help="Default: %(default)s")
     parser_training.add_argument("--save_steps", default=50, type=int, help="Default: %(default)s")
@@ -72,6 +74,8 @@ def get_config() -> Finetune4bConfig:
         lora_alpha=args["lora_alpha"],
         lora_dropout=args["lora_dropout"],
         val_set_size=args["val_set_size"],
+        gradient_checkpointing=args["grad_chckpt"],
+        gradient_checkpointing_ratio=args["grad_chckpt_ratio"],
         warmup_steps=args["warmup_steps"],
         save_steps=args["save_steps"],
         save_total_limit=args["save_total_limit"],
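
For reference, the two new flags behave like ordinary argparse switches: --grad_chckpt is a boolean store_true flag (so it defaults to False), and --grad_chckpt_ratio is a float defaulting to 1. A self-contained sketch using a throwaway parser rather than the project's parse_commandline(), with the flag names taken from the hunk above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--grad_chckpt", action="store_true", required=False,
                    help="Use gradient checkpoint. For 30B model. Default: %(default)s")
parser.add_argument("--grad_chckpt_ratio", default=1, type=float,
                    help="Gradient checkpoint ratio. Default: %(default)s")

# equivalent to passing `--grad_chckpt --grad_chckpt_ratio 0.5` on the command line:
args = vars(parser.parse_args(["--grad_chckpt", "--grad_chckpt_ratio", "0.5"]))
print(args["grad_chckpt"], args["grad_chckpt_ratio"])  # True 0.5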

View File

@@ -40,6 +40,8 @@ ft_config = get_config()
 # * Show loaded parameters
 print(f"{ft_config}\n")
+if ft_config.gradient_checkpointing:
+    print('Disable Dropout.')
 # Load Basic Model
 model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model)
@@ -70,7 +72,6 @@ for n, m in model.named_modules():
 tokenizer.pad_token_id = 0
 if not ft_config.skip:
-    # ! TODO: Refactor to load both SAD and LLAMA datasets
     # Load Data
     data = None
     match ft_config.ds_type:
@@ -85,6 +86,12 @@ if not ft_config.skip:
     data.prepare_data()
     ####
+    # Use gradient checkpointing
+    if ft_config.gradient_checkpointing:
+        print('Applying gradient checkpointing ...')
+        from gradient_checkpointing import apply_gradient_checkpointing
+        apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data.train_data,
@ -109,19 +116,21 @@ if not ft_config.skip:
)
model.config.use_cache = False
# Set Model dict
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
# Set Model dict
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
# Run Trainer
trainer.train()
# Run Trainer
trainer.train()
print('Train completed.')
print('Train completed.')
# Save Model
model.save_pretrained(ft_config.lora_out_dir)
if not ft_config.checkpoint:
# Save Model
model.save_pretrained(ft_config.lora_out_dir)
else:
raise NotImplemented("TODO: Merge model + LoRA and save the whole checkpoint")
print('Model Saved.')
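
The gradient_checkpointing helper imported in the last file is internal to the repository, and this commit only shows its entry point, apply_gradient_checkpointing(model, checkpoint_ratio=...). Conceptually, ratio-based checkpointing trades compute for memory by recomputing the activations of a fraction of the layers during the backward pass instead of storing them. A rough sketch of that general idea, assuming the helper wraps layers with torch.utils.checkpoint; this is an assumption about the mechanism, not the module's actual implementation:

# Illustrative sketch only: an assumed mechanism for ratio-based gradient
# checkpointing, not the repository's gradient_checkpointing module.
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(layers, hidden, ratio=1.0):
    """Run a stack of layers, checkpointing a fraction of them.

    ratio=1.0 checkpoints every layer (lowest memory, most recompute);
    smaller ratios checkpoint only every k-th layer; ratio=0 disables it.
    """
    step = max(1, round(1 / ratio)) if ratio > 0 else 0
    for i, layer in enumerate(layers):
        if step and i % step == 0:
            # Activations inside `layer` are recomputed during backward instead of stored.
            hidden = checkpoint(layer, hidden, use_reentrant=False)
        else:
            hidden = layer(hidden)
    return hidden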