# alpaca_lora_4bit/finetune.py
#
# 193 lines
# 6.7 KiB
# Python

"""
LLaMA 4-bit LoRA trainer with support for Stanford Alpaca-like JSON datasets (SAD for short).
Intended to be used with https://github.com/johnsmith0031/alpaca_lora_4bit
SAD structure:
[
{
"instruction": "Give null hypothesis",
"input": "6 subjects were given a drug (treatment group) and an additional 6 subjects a placebo (control group).",
"output": "Drug is equivalent of placebo"
},
{
"instruction": "What does RNA stand for?",
"input": "",
"output": "RNA stands for ribonucleic acid."
}
]
"""
# Parse the config first: it decides which monkeypatches get installed, and
# they are applied here, before the base model is loaded further down.
from arg_parser import get_config
ft_config = get_config()

# The GPTQ-aware LoRA model replacement is installed unconditionally.
from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model
replace_peft_model_with_gptq_lora_model()

# At most one attention patch: flash attention wins over xformers when both are set.
if ft_config.flash_attention:
    from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
    replace_llama_attn_with_flash_attn()
elif ft_config.xformers:
    from monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
    hijack_llama_attention()

# 4-bit kernel backend: 'triton' (any case) selects Triton; every other value selects CUDA.
import autograd_4bit
autograd_4bit.switch_backend_to('triton' if ft_config.backend.lower() == 'triton' else 'cuda')
import sys
import os
import peft
import peft.tuners.lora
import torch
import transformers
from autograd_4bit import load_llama_model_4bit_low_ram
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel, set_peft_model_state_dict
# ! Config
import train_data
# * Show loaded parameters (printed only on local rank 0).
if ft_config.local_rank == 0:
    print(f"{ft_config}\n")
# NOTE(review): this only prints; any actual dropout change presumably happens
# elsewhere (e.g. in the config parser) — confirm.
if ft_config.gradient_checkpointing:
    print('Disable Dropout.')
# Reject a micro-batch larger than the full batch; equal sizes are allowed.
# ValueError subclasses Exception, so existing `except Exception` handlers still work.
if ft_config.mbatch_size > ft_config.batch_size:
    raise ValueError('batch_size needs to be greater than or equal to mbatch_size.')
# Load the quantized base LLaMA model (4-bit, low-RAM loader) together with its tokenizer.
model, tokenizer = load_llama_model_4bit_low_ram(
    ft_config.llama_q4_config_dir,
    ft_config.llama_q4_model,
    groupsize=ft_config.groupsize,
    is_v1_model=ft_config.v1,
    device_map=ft_config.device_map,
)
# LoRA adapter configuration for causal LM training. Rank, alpha and dropout
# come from the config. Adapters target the modules named q_proj / v_proj,
# and no bias terms are trained.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=ft_config.lora_r,
    lora_alpha=ft_config.lora_alpha,
    lora_dropout=ft_config.lora_dropout,
    bias="none",
)
# Either wrap the base model with a fresh LoRA adapter, or load an existing
# adapter from lora_apply_dir and keep it trainable.
if ft_config.lora_apply_dir is None:
    model = get_peft_model(model, lora_config)
else:
    # The adapter's device map depends only on the run mode.
    # ft_config.device_map (used for the base model) is not consulted here,
    # so the previous `device_map = ft_config.device_map` store was dead.
    # NOTE(review): under DDP every rank requests device 0 for the adapter —
    # confirm this matches where the base model lives on ranks other than 0.
    if not ft_config.ddp and torch.cuda.device_count() > 1:
        device_map = "auto"
    else:
        device_map = {'': 0}
    print('Device map for lora:', device_map)
    model = PeftModel.from_pretrained(
        model,
        ft_config.lora_apply_dir,
        device_map=device_map,
        torch_dtype=torch.float32,
        is_trainable=True,
    )
    print(ft_config.lora_apply_dir, 'loaded')
# Scales to half: convert the quantization parameters of every module whose
# type string contains '4bit' to fp16. `scales` is always converted; `zeros`
# only for v1-format models.
print('Fitting 4bit scales and zeros to half')
for module in model.modules():
    if '4bit' not in str(type(module)):
        continue
    if module.is_v1_model:
        module.zeros = module.zeros.half()
    module.scales = module.scales.half()
# Set tokenizer: use token id 0 for padding.
tokenizer.pad_token_id = 0

if not ft_config.skip:
    # Load Data. ds_type selects the dataset reader. The enclosing `skip`
    # check already guards this whole section, so the branches need not repeat it.
    if ft_config.ds_type == "txt":
        #### LLaMa
        data = train_data.TrainTxt(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    elif ft_config.ds_type == "alpaca":
        #### Stanford Alpaca-like Data
        data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    elif ft_config.ds_type == "gpt4all":
        #### GPT4All Data
        data = train_data.TrainGPT4All(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    else:
        raise NotImplementedError(f"ERROR: Unknown dataset format: {ft_config.ds_type!r}")
    data.prepare_data(thd=ft_config.txt_row_thd, use_eos_token=ft_config.use_eos_token)

    # Use gradient checkpointing
    if ft_config.gradient_checkpointing:
        print('Applying gradient checkpointing ...')
        from gradient_checkpointing import apply_gradient_checkpointing
        apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)

    # Disable Trainer's DataParallel for multigpu: when several GPUs are visible
    # and DDP is not used, mark the model as model-parallel instead.
    if not ft_config.ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

    training_arguments = transformers.TrainingArguments(
        per_device_train_batch_size=ft_config.mbatch_size,
        gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
        warmup_steps=ft_config.warmup_steps,
        optim="adamw_torch",
        num_train_epochs=ft_config.epochs,
        learning_rate=ft_config.lr,
        fp16=True,
        logging_steps=ft_config.logging_steps,
        evaluation_strategy="no",
        save_strategy="steps",
        eval_steps=None,
        save_steps=ft_config.save_steps,
        output_dir=ft_config.lora_out_dir,
        save_total_limit=ft_config.save_total_limit,
        load_best_model_at_end=False,
        ddp_find_unused_parameters=False if ft_config.ddp else None,
    )

    trainer = transformers.Trainer(
        model=model,
        train_dataset=data.train_data,
        eval_dataset=data.val_data,
        args=training_arguments,
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    # Turn off the model's use_cache flag for training.
    model.config.use_cache = False

    # Set Model dict: while training, model.state_dict() returns only the PEFT
    # state dict (any arguments passed to it are discarded). Presumably this
    # makes Trainer checkpoints hold only the adapter; the resume path below
    # reads such a file back as pytorch_model.bin.
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    # Set Verbose
    if ft_config.verbose:
        transformers.logging.set_verbosity_info()

    # Run Trainer. The original state_dict is restored even if training raises.
    try:
        if ft_config.resume_checkpoint:
            print(f'Resuming from {ft_config.resume_checkpoint} ...')
            # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
            state_dict_peft = torch.load(os.path.join(ft_config.resume_checkpoint, 'pytorch_model.bin'), map_location='cpu')
            set_peft_model_state_dict(model, state_dict_peft)
            trainer.train(resume_from_checkpoint=ft_config.resume_checkpoint)
        else:
            trainer.train()
    finally:
        # Restore old model state dict
        model.state_dict = old_state_dict

    print('Train completed.')
# Save Model: write the (PEFT-wrapped) model to lora_out_dir.
model.save_pretrained(ft_config.lora_out_dir)

# Merging the LoRA weights into a full checkpoint is not supported; only warn.
if ft_config.checkpoint:
    print("Warning: Merge model + LoRA and save the whole checkpoint not implemented yet.")

print('Model Saved.')