add more scripts and adjust code for transformer branch

This commit is contained in:
John Smith 2023-03-22 04:09:04 +00:00
parent a955a1c2a5
commit dc036373b2
7 changed files with 246 additions and 5 deletions

View File

@ -198,18 +198,18 @@ def model_to_float(model):
def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
import transformers
import accelerate
from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from modelutils import find_layers
print("Loading Model ...")
t0 = time.time()
with accelerate.init_empty_weights():
config = LLaMAConfig.from_pretrained(config_path)
config = LlamaConfig.from_pretrained(config_path)
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = LLaMAForCausalLM(config)
model = LlamaForCausalLM(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
@ -224,7 +224,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
if half:
model_to_half(model)
tokenizer = LLaMATokenizer.from_pretrained(config_path)
tokenizer = LlamaTokenizer.from_pretrained(config_path)
tokenizer.truncation_side = 'left'
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")

View File

@ -5,6 +5,8 @@ Made some adjust for the code in peft and gptq for llama, and make it possible f
<br>
Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
<br>
Added install script for windows and linux.
<br>
# Requirements
gptq-for-llama: https://github.com/qwopqwop200/GPTQ-for-LLaMa<br>
peft: https://github.com/huggingface/peft.git<br>

144
finetune.py Normal file
View File

@ -0,0 +1,144 @@
import os
import sys
sys.path.insert(0, './repository/transformers/src')
sys.path.insert(0, './repository/GPTQ-for-LLaMa')
sys.path.insert(0, './repository/peft/src')
import peft
import peft.tuners.lora
assert peft.tuners.lora.is_gptq_available()
import time
import torch
import transformers
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
import accelerate
from modelutils import find_layers
from autograd_4bit import make_quant_for_4bit_autograd
from autograd_4bit import load_llama_model_4bit_low_ram
from datasets import load_dataset, Dataset
import json
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
# Parameters
DATA_PATH = "./data.txt"
OUTPUT_DIR = "alpaca_lora"
lora_path_old = ''
config_path = './llama-13b-4bit/'
model_path = './llama-13b-4bit.pt'
MICRO_BATCH_SIZE = 1
BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 3
LEARNING_RATE = 2e-4
CUTOFF_LEN = 256
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
VAL_SET_SIZE = 0
TARGET_MODULES = [
"q_proj",
"v_proj",
]
warmup_steps = 50
save_steps = 50
save_total_limit = 3
logging_steps = 10
# Load Basic Model
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
# Config Lora
config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=["q_proj", "v_proj"],
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM",
)
if lora_path_old == '':
model = get_peft_model(model, config)
else:
model = PeftModel.from_pretrained(model, lora_path_old)
print(lora_path_old, 'loaded')
# Scales to half
print('Fitting 4bit scales and zeros to half')
for n, m in model.named_modules():
if '4bit' in str(type(m)):
m.zeros = m.zeros.half()
m.scales = m.scales.half()
# Set tokenizer
tokenizer.pad_token_id = 0
# Load Data
with open(DATA_PATH, 'r', encoding='utf8') as file:
txt = file.read()
txt = txt.replace('\r\n', '\n')
rows = [r for r in txt.split('\n') if r != '']
data = Dataset.from_dict({"input": rows})
exceed_count = 0
def tokenize(prompt):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
global exceed_count
prompt = prompt['input']
result = tokenizer(
prompt,
truncation=True,
max_length=CUTOFF_LEN + 1,
padding="max_length",
)
d = {
"input_ids": result["input_ids"][:-1],
"attention_mask": result["attention_mask"][:-1],
}
if sum(d['attention_mask']) >= CUTOFF_LEN:
exceed_count += 1
return d
data = data.shuffle().map(lambda x: tokenize(x))
print('Train Data: {:.2f}%'.format(exceed_count / len(data) * 100), 'outliers')
train_data = data
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=MICRO_BATCH_SIZE,
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
warmup_steps=warmup_steps,
num_train_epochs=EPOCHS,
learning_rate=LEARNING_RATE,
fp16=True,
logging_steps=logging_steps,
evaluation_strategy="no",
save_strategy="steps",
eval_steps=None,
save_steps=save_steps,
output_dir=OUTPUT_DIR,
save_total_limit=save_total_limit,
load_best_model_at_end=False
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
# Set Model dict
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
# Run Trainer
trainer.train()
print('Train completed.')
# Save Model
model.save_pretrained(OUTPUT_DIR)
print('Model Saved.')

39
inference.py Normal file
View File

@ -0,0 +1,39 @@
import os
import sys
sys.path.insert(0, './repository/transformers/src')
sys.path.insert(0, './repository/GPTQ-for-LLaMa')
sys.path.insert(0, './repository/peft/src')
import time
import torch
from autograd_4bit import load_llama_model_4bit_low_ram
config_path = './llama-13b-4bit/'
model_path = './llama-13b-4bit.pt'
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
print('Fitting 4bit scales and zeros to half')
for n, m in model.named_modules():
if '4bit' in str(type(m)):
m.zeros = m.zeros.half()
m.scales = m.scales.half()
prompt = '''I think the meaning of life is'''
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
batch = {k: v.cuda() for k, v in batch.items()}
start = time.time()
with torch.no_grad():
generated = model.generate(inputs=batch["input_ids"],
do_sample=True, use_cache=True,
repetition_penalty=1.1,
max_new_tokens=20,
temperature=0.9,
top_p=0.95,
top_k=40,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False)
result_text = tokenizer.decode(generated['sequences'].cpu().tolist()[0])
end = time.time()
print(result_text)
print(end - start)

26
install.bat Normal file
View File

@ -0,0 +1,26 @@
REM This is a install script for Alpaca_LoRA_4bit
REM makedir ./repository/ if not exists
if not exist .\repository mkdir .\repository
REM Clone repos into current repository into ./repository/
git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git ./repository/GPTQ-for-LLaMa
git clone https://github.com/huggingface/peft.git ./repository/peft
git clone https://github.com/huggingface/transformers.git ./repository/transformers
REM replace ./repository/peft/src/peft/tuners/lora.py with ./peft/tuners/lora.py
copy .\peft\tuners\lora.py .\repository\peft\src\peft\tuners\lora.py /Y
REM replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu with ./GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu
copy .\GPTQ-for-LLaMa\quant_cuda.cpp .\repository\GPTQ-for-LLaMa\quant_cuda.cpp /Y
copy .\GPTQ-for-LLaMa\quant_cuda_kernel.cu .\repository\GPTQ-for-LLaMa\quant_cuda_kernel.cu /Y
REM copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
copy .\GPTQ-for-LLaMa\autograd_4bit.py .\repository\GPTQ-for-LLaMa\autograd_4bit.py /Y
REM install quant_cuda
cd .\repository\GPTQ-for-LLaMa
python setup_cuda.py install
echo "Install finished"
@pause

30
install.sh Normal file
View File

@ -0,0 +1,30 @@
#!/bin/bash
# This is an install script for Alpaca_LoRA_4bit
# makedir ./repository/ if not exists
if [ ! -d "./repository" ]; then
mkdir ./repository
fi
# Clone repos into current repository into ./repository/
git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git ./repository/GPTQ-for-LLaMa
git clone https://github.com/huggingface/peft.git ./repository/peft
git clone https://github.com/huggingface/transformers.git ./repository/transformers
# Replace ./repository/peft/src/peft/tuners/lora.py with ./peft/tuners/lora.py
cp ./peft/tuners/lora.py ./repository/peft/src/peft/tuners/lora.py
# Replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu with ./GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu
cp ./GPTQ-for-LLaMa/quant_cuda.cpp ./repository/GPTQ-for-LLaMa/quant_cuda.cpp
cp ./GPTQ-for-LLaMa/quant_cuda_kernel.cu ./repository/GPTQ-for-LLaMa/quant_cuda_kernel.cu
# Copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
cp ./autograd_4bit.py ./repository/GPTQ-for-LLaMa/autograd_4bit.py
# Install quant_cuda and cd into ./repository/GPTQ-for-LLaMa
cd ./repository/GPTQ-for-LLaMa
python setup_cuda.py install
echo "Install finished"
read -p "Press [Enter] to continue..."

View File

@ -1,6 +1,6 @@
torch
accelerate
bitsandbytes
git+https://github.com/zphang/transformers@llama_push
git+https://github.com/huggingface/transformers.git
git+https://github.com/qwopqwop200/GPTQ-for-LLaMa.git
git+https://github.com/huggingface/peft.git