Merge pull request #30 from winglian/features/python-fixes

backwards support for pre-py3.10, add datasets requirement used in train
commit 667e43cb5b
John Smith 2023-03-28 09:34:50 +08:00, committed by GitHub
2 changed files with 12 additions and 11 deletions


@@ -1,7 +1,7 @@
 """
 llama-4b trainer with support of Stanford Alpaca-like JSON datasets (short for SAD)
 Intended to use with https://github.com/johnsmith0031/alpaca_lora_4bit
 SAD structure:
 [
     {
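(The hunk display truncates the docstring here. For reference, a Stanford Alpaca-style file is a JSON array of instruction records; the field names below follow the public Stanford Alpaca release and are an illustration, not the text of this repo's docstring:)

```python
# Illustrative record shape for an Alpaca-style ("SAD") dataset. Field names
# follow the public Stanford Alpaca release, not this repo's truncated docstring.
sad_records = [
    {
        "instruction": "Translate the sentence to French.",
        "input": "The weather is nice today.",
        "output": "Il fait beau aujourd'hui.",
    },
]
```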
@@ -72,15 +72,14 @@ tokenizer.pad_token_id = 0
 if not ft_config.skip:
     # Load Data
     data = None
-    match ft_config.ds_type:
-        case "txt" if not ft_config.skip:
-            #### LLaMA
-            data = train_data.TrainTxt(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
-        case "alpaca" if not ft_config.skip:
-            #### Stanford Alpaca-like Data
-            data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
-        case _:
-            raise NotImplementedError("ERROR: Unknown dataset format")
+    if ft_config.ds_type == "txt" and not ft_config.skip:
+        #### LLaMa
+        data = train_data.TrainTxt(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
+    elif ft_config.ds_type == "alpaca" and not ft_config.skip:
+        #### Stanford Alpaca-like Data
+        data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
+    else:
+        raise NotImplementedError("ERROR: Unknown dataset format")
     data.prepare_data()
     ####
@@ -136,5 +135,5 @@ model.save_pretrained(ft_config.lora_out_dir)
 if ft_config.checkpoint:
     print("Warning: Merge model + LoRA and save the whole checkpoint not implemented yet.")
 print('Model Saved.')
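The substantive change in this file swaps a `match`/`case` statement for plain `if`/`elif`. Structural pattern matching was added in Python 3.10, so on 3.8/3.9 the old code is a SyntaxError; since syntax errors surface at parse time, the script failed on import even if the block was never reached. A minimal sketch of the equivalence (the function and variable names here are illustrative, not the repo's):

```python
# Sketch: why the rewrite is needed. Names are illustrative, not the repo's.
# The commented-out match form is a SyntaxError on Python < 3.10, and syntax
# errors are raised at parse time, so the old script failed on import:
#
#   match ds_type:
#       case "txt" if not skip:          # a guarded case
#           loader = "TrainTxt"
#       case "alpaca" if not skip:
#           loader = "TrainSAD"
#       case _:
#           raise NotImplementedError("ERROR: Unknown dataset format")

def pick_loader(ds_type: str, skip: bool = False) -> str:
    # Pre-3.10 equivalent: each guarded case becomes an `and` in the test,
    # and `case _` becomes the trailing `else`.
    if ds_type == "txt" and not skip:
        return "TrainTxt"
    elif ds_type == "alpaca" and not skip:
        return "TrainSAD"
    else:
        raise NotImplementedError("ERROR: Unknown dataset format")

assert pick_loader("alpaca") == "TrainSAD"  # runs on any modern Python
```

Note that the guard in `case "txt" if not ft_config.skip:` translates to an extra `and` clause in the test, which is why the rewritten conditions in the diff read `== "txt" and not ft_config.skip`.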

requirements.txt
@@ -1,6 +1,8 @@
 torch
 accelerate
 bitsandbytes
+datasets
 sentencepiece
 git+https://github.com/huggingface/transformers.git
 git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
 git+https://github.com/sterlind/peft.git
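The other half of the PR adds `datasets` to requirements.txt, since the trainer pulls in the Hugging Face `datasets` library to load and split the training data. A hedged sketch of the kind of call that creates the dependency; the file name and split size below are hypothetical, not this repo's actual code:

```python
# Sketch only: the sort of `datasets` usage that makes it a hard requirement.
# "alpaca_data.json" and the 10% validation split are hypothetical values.
from datasets import load_dataset

# An Alpaca-style file is a flat JSON array of records, so the generic
# "json" builder can load it directly.
ds = load_dataset("json", data_files="alpaca_data.json")["train"]
ds = ds.train_test_split(test_size=0.1)  # train/validation split
print(ds["train"][0]["instruction"])
```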