"""
|
|
llama-4b trainer with support of Stanford Alpaca-like JSON datasets (short for SAD)
|
|
Intended to use with https://github.com/johnsmith0031/alpaca_lora_4bit
|
|
|
|
SAD structure:
|
|
[
|
|
{
|
|
"instruction": "Give null hypothesis",
|
|
"input": "6 subjects were given a drug (treatment group) and an additional 6 subjects a placebo (control group).",
|
|
"output": "Drug is equivalent of placebo"
|
|
},
|
|
{
|
|
"instruction": "What does RNA stand for?",
|
|
"input": "",
|
|
"output": "RNA stands for ribonucleic acid."
|
|
}
|
|
]
|
|
"""

# Load the config first so the attention / PEFT patches below can be applied before
# anything else imports the model code.
from arg_parser import get_config
ft_config = get_config()

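# Patch peft so that LoRA layers can wrap the GPTQ 4-bit quantized linear modules used
# by this repo (stock peft does not know about these quantized layers).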
from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model
replace_peft_model_with_gptq_lora_model()

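# Optionally swap LLaMA's attention implementation for FlashAttention or xformers
# memory-efficient attention, depending on the config flags.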
if ft_config.flash_attention:
    from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
    replace_llama_attn_with_flash_attn()
elif ft_config.xformers:
    from monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
    hijack_llama_attention()

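# Select the kernel backend for the 4-bit autograd matmul (Triton kernels or the CUDA extension).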
import autograd_4bit
if ft_config.backend.lower() == 'triton':
    autograd_4bit.switch_backend_to('triton')
else:
    autograd_4bit.switch_backend_to('cuda')

import sys
import os

import peft
import peft.tuners.lora

import torch
import transformers
from autograd_4bit import load_llama_model_4bit_low_ram
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel, set_peft_model_state_dict

# ! Config
import train_data

# * Show loaded parameters
if ft_config.local_rank == 0:
    print(f"{ft_config}\n")

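# Note: the custom gradient checkpointing used further below is presumably meant to run
# with dropout disabled, hence the notice printed here.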
if ft_config.gradient_checkpointing:
    print('Disable Dropout.')

if ft_config.mbatch_size > ft_config.batch_size:
    raise Exception('batch_size must not be smaller than mbatch_size.')

# Load Basic Model
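# load_llama_model_4bit_low_ram loads the GPTQ-quantized LLaMA checkpoint together with
# its tokenizer; groupsize and is_v1_model describe how the weights were quantized.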
model, tokenizer = load_llama_model_4bit_low_ram(ft_config.llama_q4_config_dir,
                                                 ft_config.llama_q4_model,
                                                 device_map=ft_config.device_map,
                                                 groupsize=ft_config.groupsize,
                                                 is_v1_model=ft_config.v1)

# Config Lora
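# LoRA adapters of rank lora_r (scaled by lora_alpha) are attached only to the attention
# q_proj and v_proj projections.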
lora_config = LoraConfig(
    r=ft_config.lora_r,
    lora_alpha=ft_config.lora_alpha,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=ft_config.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)
if ft_config.lora_apply_dir is None:
    model = get_peft_model(model, lora_config)
else:
    device_map = ft_config.device_map
    if ft_config.ddp:
        device_map = {'': 0}
    elif torch.cuda.device_count() > 1:
        device_map = "auto"
    else:
        device_map = {'': 0}
    print('Device map for lora:', device_map)
    model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map=device_map, torch_dtype=torch.float32, is_trainable=True)
    print(ft_config.lora_apply_dir, 'loaded')


# Scales to half
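# v1 GPTQ checkpoints store their zero points as float tensors, so both zeros and scales
# are cast to fp16 to match fp16 training; newer checkpoints pack zeros as integers and
# only need their scales cast.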
print('Fitting 4bit scales and zeros to half')
for n, m in model.named_modules():
    if '4bit' in str(type(m)):
        if m.is_v1_model:
            m.zeros = m.zeros.half()
        m.scales = m.scales.half()

# Set tokenizer
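# The LLaMA tokenizer ships without a pad token; token id 0 (<unk>) is used for padding here.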
tokenizer.pad_token_id = 0

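# When ft_config.skip is set, dataset loading and training are skipped entirely and only
# the (untrained or pre-loaded) LoRA adapter is saved at the end of the script.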
if not ft_config.skip:
    # Load Data
    data = None
    if ft_config.ds_type == "txt" and not ft_config.skip:
        #### Plain text data
        data = train_data.TrainTxt(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    elif ft_config.ds_type == "alpaca" and not ft_config.skip:
        #### Stanford Alpaca-like Data
        data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    elif ft_config.ds_type == "gpt4all" and not ft_config.skip:
        #### GPT4All Data
        data = train_data.TrainGPT4All(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    else:
        raise NotImplementedError("ERROR: Unknown dataset format")
    data.prepare_data(thd=ft_config.txt_row_thd, use_eos_token=ft_config.use_eos_token)
    ####

    # Use gradient checkpointing
    if ft_config.gradient_checkpointing:
        print('Applying gradient checkpointing ...')
        from gradient_checkpointing import apply_gradient_checkpointing
        apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)

    # Tell the Trainer the model is already spread across GPUs so it does not wrap it in DataParallel
    if not ft_config.ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

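    # The per-device micro-batch is mbatch_size; gradients are accumulated over
    # gradient_accumulation_steps micro-batches before each optimizer step. Evaluation is
    # disabled ("no") even though an eval_dataset is passed to the Trainer below.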
    training_arguments = transformers.TrainingArguments(
        per_device_train_batch_size=ft_config.mbatch_size,
        gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
        warmup_steps=ft_config.warmup_steps,
        optim="adamw_torch",
        num_train_epochs=ft_config.epochs,
        learning_rate=ft_config.lr,
        fp16=True,
        logging_steps=ft_config.logging_steps,
        evaluation_strategy="no",
        save_strategy="steps",
        eval_steps=None,
        save_steps=ft_config.save_steps,
        output_dir=ft_config.lora_out_dir,
        save_total_limit=ft_config.save_total_limit,
        load_best_model_at_end=False,
        ddp_find_unused_parameters=False if ft_config.ddp else None,
    )

    trainer = transformers.Trainer(
        model=model,
        train_dataset=data.train_data,
        eval_dataset=data.val_data,
        args=training_arguments,
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
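    # The KV cache is only useful for generation; disable it for training (it would also
    # conflict with gradient checkpointing).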
    model.config.use_cache = False

    # Set Model dict
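    # Override state_dict so that the Trainer's checkpoints contain only the LoRA (PEFT)
    # weights instead of the full model.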
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    # Set Verbose
    if ft_config.verbose:
        transformers.logging.set_verbosity_info()

    # Run Trainer
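    # When resuming, the adapter weights stored in the checkpoint's pytorch_model.bin are
    # loaded back into the PEFT model before the checkpoint path is handed to trainer.train().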
    if ft_config.resume_checkpoint:
        print('Resuming from {} ...'.format(ft_config.resume_checkpoint))
        state_dict_peft = torch.load(os.path.join(ft_config.resume_checkpoint, 'pytorch_model.bin'), map_location='cpu')
        set_peft_model_state_dict(model, state_dict_peft)
        trainer.train(resume_from_checkpoint=ft_config.resume_checkpoint)
    else:
        trainer.train()

    # Restore old model state dict
    model.state_dict = old_state_dict

    print('Training completed.')

# Save Model
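# save_pretrained on a PeftModel writes only the adapter weights and adapter config to
# lora_out_dir, not the full base model.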
model.save_pretrained(ft_config.lora_out_dir)

if ft_config.checkpoint:
    print("Warning: merging the base model with the LoRA and saving a full checkpoint is not implemented yet.")

print('Model Saved.')