alpaca_lora_4bit/finetune.py


"""
llama-4b trainer with support of Stanford Alpaca-like JSON datasets (short for SAD)
Intended to use with https://github.com/johnsmith0031/alpaca_lora_4bit
SAD structure:
[
{
"instruction": "Give null hypothesis",
"input": "6 subjects were given a drug (treatment group) and an additional 6 subjects a placebo (control group).",
"output": "Drug is equivalent of placebo"
},
{
"instruction": "What does RNA stand for?",
"input": "",
"output": "RNA stands for ribonucleic acid."
}
]
"""
import sys
sys.path.insert(0, './repository/transformers/src')
sys.path.insert(0, './repository/GPTQ-for-LLaMa')
sys.path.insert(0, './repository/peft/src')
import peft
import peft.tuners.lora
assert peft.tuners.lora.is_gptq_available()
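# The assert above verifies that the peft build on the path exposes GPTQ (4-bit)
# LoRA support; a stock peft release would presumably fail this check.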
import torch
import transformers
from autograd_4bit import load_llama_model_4bit_low_ram
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
# ! Config
from arg_parser import get_config
import train_data
ft_config = get_config()
# * Show loaded parameters
print(f"{ft_config}\n")
# Load Basic Model
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir, ft_config.llama_q4_model)
# Config Lora
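# LoRA adapters are attached only to the attention query/value projections
# (q_proj, v_proj); r and lora_alpha control the adapter rank and scaling.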
lora_config = LoraConfig(
    r=ft_config.lora_r,
    lora_alpha=ft_config.lora_alpha,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=ft_config.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)
if ft_config.lora_apply_dir is None:
    model = get_peft_model(model, lora_config)
else:
    model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32)  # ! Direct copy from inference.py
    print(ft_config.lora_apply_dir, 'loaded')
# Cast the 4-bit quantization scales and zero points to half precision
print('Fitting 4bit scales and zeros to half')
for n, m in model.named_modules():
    if '4bit' in str(type(m)):
        m.zeros = m.zeros.half()
        m.scales = m.scales.half()
# Set tokenizer
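# LLaMA's tokenizer defines no pad token, so token id 0 (<unk>) is reused for padding.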
tokenizer.pad_token_id = 0
if not ft_config.skip:
    # ! TODO: Refactor to load both SAD and LLAMA datasets
    # Load Data
    data = None
    match ft_config.ds_type:
        case "txt" if not ft_config.skip:
            #### LLaMA
            data = train_data.TrainTxt(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
        case "alpaca" if not ft_config.skip:
            #### Stanford Alpaca-like Data
            data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
        case _:
            raise NotImplementedError("ERROR: Unknown dataset format")
    data.prepare_data()
    ####
    trainer = transformers.Trainer(
        model=model,
        train_dataset=data.train_data,
        eval_dataset=data.val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=ft_config.mbatch_size,
            gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
            warmup_steps=ft_config.warmup_steps,
            num_train_epochs=ft_config.epochs,
            learning_rate=ft_config.lr,
            fp16=True,
            logging_steps=ft_config.logging_steps,
            evaluation_strategy="no",
            save_strategy="steps",
            eval_steps=None,
            save_steps=ft_config.save_steps,
            output_dir=ft_config.lora_out_dir,
            save_total_limit=ft_config.save_total_limit,
            load_best_model_at_end=False,
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
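    # DataCollatorForLanguageModeling with mlm=False pads each batch and builds
    # causal-LM labels from the input ids; KV caching is only useful at generation
    # time, so it is turned off for training below.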
    model.config.use_cache = False
    # Override state_dict so that saving the model emits only the LoRA adapter weights
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))
    # Run Trainer
    trainer.train()
    print('Train completed.')
# Save Model
model.save_pretrained(ft_config.lora_out_dir)
print('Model Saved.')