Merge pull request #30 from winglian/features/python-fixes
backwards support for pre-py3.10, add datasets requirement used in train
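
The `match`/`case` statement used for dataset dispatch only parses on Python 3.10+; on older interpreters finetune.py fails with a SyntaxError before anything runs. The diff below swaps it for an equivalent if/elif chain. As a minimal, self-contained sketch of the same pattern (the function and names here are illustrative, not the PR's code):

    import sys

    def pick_loader(ds_type):
        # Illustrative stand-in for the dataset dispatch in finetune.py:
        # an if/elif chain parses on any Python 3.x, while match/case needs >= 3.10.
        if ds_type == "txt":
            return "TrainTxt"   # plain-text dataset loader
        elif ds_type == "alpaca":
            return "TrainSAD"   # Stanford Alpaca-like JSON dataset loader
        else:
            raise NotImplementedError("ERROR: Unknown dataset format")

    print(sys.version_info)       # runs on 3.8 / 3.9 as well as 3.10+
    print(pick_loader("alpaca"))  # -> TrainSAD

Behaviour is unchanged: the `case ... if not ft_config.skip` guards become `and not ft_config.skip` conditions, and the `case _` fallback becomes the `else` branch.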
commit 667e43cb5b

finetune.py (21 changed lines)
@@ -1,7 +1,7 @@
 """
 llama-4b trainer with support of Stanford Alpaca-like JSON datasets (short for SAD)
 Intended to use with https://github.com/johnsmith0031/alpaca_lora_4bit

 SAD structure:
 [
     {
@@ -72,15 +72,14 @@ tokenizer.pad_token_id = 0
 if not ft_config.skip:
     # Load Data
     data = None
-    match ft_config.ds_type:
-        case "txt" if not ft_config.skip:
-            #### LLaMa
-            data = train_data.TrainTxt(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
-        case "alpaca" if not ft_config.skip:
-            #### Stanford Alpaca-like Data
-            data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
-        case _:
-            raise NotImplementedError("ERROR: Unknown dataset format")
+    if ft_config.ds_type == "txt" and not ft_config.skip:
+        #### LLaMA
+        data = train_data.TrainTxt(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
+    elif ft_config.ds_type == "alpaca" and not ft_config.skip:
+        #### Stanford Alpaca-like Data
+        data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
+    else:
+        raise NotImplementedError("ERROR: Unknown dataset format")
     data.prepare_data()
     ####

@@ -136,5 +135,5 @@ model.save_pretrained(ft_config.lora_out_dir)

 if ft_config.checkpoint:
     print("Warning: Merge model + LoRA and save the whole checkpoint not implemented yet.")

 print('Model Saved.')
requirements.txt

@@ -1,6 +1,8 @@
 torch
 accelerate
 bitsandbytes
+datasets
+sentencepiece
 git+https://github.com/huggingface/transformers.git
 git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
 git+https://github.com/sterlind/peft.git
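
The two new entries round out the runtime dependencies: `datasets` is the Hugging Face datasets package that the training path relies on, and `sentencepiece` is needed by the LLaMA tokenizer in transformers. A minimal sketch of the kind of call `datasets` enables (the loader call and file name are illustrative placeholders, not taken from train_data.py):

    from datasets import load_dataset   # provided by the new `datasets` requirement

    # Illustrative only: read an Alpaca-style JSON file with Hugging Face `datasets`;
    # the file name is a placeholder, not a path from this repo.
    data = load_dataset("json", data_files="alpaca_data.json")
    print(data["train"][0])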