add more scripts and adjust code for transformer branch
This commit is contained in:
parent
a955a1c2a5
commit
dc036373b2
|
|
@ -198,18 +198,18 @@ def model_to_float(model):
|
||||||
def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
||||||
import transformers
|
import transformers
|
||||||
import accelerate
|
import accelerate
|
||||||
from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer
|
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||||
from modelutils import find_layers
|
from modelutils import find_layers
|
||||||
|
|
||||||
print("Loading Model ...")
|
print("Loading Model ...")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
config = LLaMAConfig.from_pretrained(config_path)
|
config = LlamaConfig.from_pretrained(config_path)
|
||||||
torch.set_default_dtype(torch.half)
|
torch.set_default_dtype(torch.half)
|
||||||
transformers.modeling_utils._init_weights = False
|
transformers.modeling_utils._init_weights = False
|
||||||
torch.set_default_dtype(torch.half)
|
torch.set_default_dtype(torch.half)
|
||||||
model = LLaMAForCausalLM(config)
|
model = LlamaForCausalLM(config)
|
||||||
torch.set_default_dtype(torch.float)
|
torch.set_default_dtype(torch.float)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
layers = find_layers(model)
|
layers = find_layers(model)
|
||||||
|
|
@ -224,7 +224,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
||||||
if half:
|
if half:
|
||||||
model_to_half(model)
|
model_to_half(model)
|
||||||
|
|
||||||
tokenizer = LLaMATokenizer.from_pretrained(config_path)
|
tokenizer = LlamaTokenizer.from_pretrained(config_path)
|
||||||
tokenizer.truncation_side = 'left'
|
tokenizer.truncation_side = 'left'
|
||||||
|
|
||||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ Made some adjust for the code in peft and gptq for llama, and make it possible f
|
||||||
<br>
|
<br>
|
||||||
Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
|
Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
|
||||||
<br>
|
<br>
|
||||||
|
Added install script for windows and linux.
|
||||||
|
<br>
|
||||||
# Requirements
|
# Requirements
|
||||||
gptq-for-llama: https://github.com/qwopqwop200/GPTQ-for-LLaMa<br>
|
gptq-for-llama: https://github.com/qwopqwop200/GPTQ-for-LLaMa<br>
|
||||||
peft: https://github.com/huggingface/peft.git<br>
|
peft: https://github.com/huggingface/peft.git<br>
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,144 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, './repository/transformers/src')
|
||||||
|
sys.path.insert(0, './repository/GPTQ-for-LLaMa')
|
||||||
|
sys.path.insert(0, './repository/peft/src')
|
||||||
|
|
||||||
|
import peft
|
||||||
|
import peft.tuners.lora
|
||||||
|
assert peft.tuners.lora.is_gptq_available()
|
||||||
|
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||||
|
import accelerate
|
||||||
|
from modelutils import find_layers
|
||||||
|
from autograd_4bit import make_quant_for_4bit_autograd
|
||||||
|
from autograd_4bit import load_llama_model_4bit_low_ram
|
||||||
|
from datasets import load_dataset, Dataset
|
||||||
|
import json
|
||||||
|
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
|
||||||
|
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
DATA_PATH = "./data.txt"
|
||||||
|
OUTPUT_DIR = "alpaca_lora"
|
||||||
|
lora_path_old = ''
|
||||||
|
config_path = './llama-13b-4bit/'
|
||||||
|
model_path = './llama-13b-4bit.pt'
|
||||||
|
|
||||||
|
MICRO_BATCH_SIZE = 1
|
||||||
|
BATCH_SIZE = 2
|
||||||
|
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
|
||||||
|
EPOCHS = 3
|
||||||
|
LEARNING_RATE = 2e-4
|
||||||
|
CUTOFF_LEN = 256
|
||||||
|
LORA_R = 8
|
||||||
|
LORA_ALPHA = 16
|
||||||
|
LORA_DROPOUT = 0.05
|
||||||
|
VAL_SET_SIZE = 0
|
||||||
|
TARGET_MODULES = [
|
||||||
|
"q_proj",
|
||||||
|
"v_proj",
|
||||||
|
]
|
||||||
|
warmup_steps = 50
|
||||||
|
save_steps = 50
|
||||||
|
save_total_limit = 3
|
||||||
|
logging_steps = 10
|
||||||
|
|
||||||
|
# Load Basic Model
|
||||||
|
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
|
||||||
|
|
||||||
|
# Config Lora
|
||||||
|
config = LoraConfig(
|
||||||
|
r=LORA_R,
|
||||||
|
lora_alpha=LORA_ALPHA,
|
||||||
|
target_modules=["q_proj", "v_proj"],
|
||||||
|
lora_dropout=LORA_DROPOUT,
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM",
|
||||||
|
)
|
||||||
|
if lora_path_old == '':
|
||||||
|
model = get_peft_model(model, config)
|
||||||
|
else:
|
||||||
|
model = PeftModel.from_pretrained(model, lora_path_old)
|
||||||
|
print(lora_path_old, 'loaded')
|
||||||
|
|
||||||
|
# Scales to half
|
||||||
|
print('Fitting 4bit scales and zeros to half')
|
||||||
|
for n, m in model.named_modules():
|
||||||
|
if '4bit' in str(type(m)):
|
||||||
|
m.zeros = m.zeros.half()
|
||||||
|
m.scales = m.scales.half()
|
||||||
|
|
||||||
|
# Set tokenizer
|
||||||
|
tokenizer.pad_token_id = 0
|
||||||
|
|
||||||
|
# Load Data
|
||||||
|
with open(DATA_PATH, 'r', encoding='utf8') as file:
|
||||||
|
txt = file.read()
|
||||||
|
txt = txt.replace('\r\n', '\n')
|
||||||
|
rows = [r for r in txt.split('\n') if r != '']
|
||||||
|
data = Dataset.from_dict({"input": rows})
|
||||||
|
exceed_count = 0
|
||||||
|
def tokenize(prompt):
|
||||||
|
# there's probably a way to do this with the tokenizer settings
|
||||||
|
# but again, gotta move fast
|
||||||
|
global exceed_count
|
||||||
|
prompt = prompt['input']
|
||||||
|
result = tokenizer(
|
||||||
|
prompt,
|
||||||
|
truncation=True,
|
||||||
|
max_length=CUTOFF_LEN + 1,
|
||||||
|
padding="max_length",
|
||||||
|
)
|
||||||
|
d = {
|
||||||
|
"input_ids": result["input_ids"][:-1],
|
||||||
|
"attention_mask": result["attention_mask"][:-1],
|
||||||
|
}
|
||||||
|
if sum(d['attention_mask']) >= CUTOFF_LEN:
|
||||||
|
exceed_count += 1
|
||||||
|
return d
|
||||||
|
data = data.shuffle().map(lambda x: tokenize(x))
|
||||||
|
print('Train Data: {:.2f}%'.format(exceed_count / len(data) * 100), 'outliers')
|
||||||
|
train_data = data
|
||||||
|
|
||||||
|
trainer = transformers.Trainer(
|
||||||
|
model=model,
|
||||||
|
train_dataset=train_data,
|
||||||
|
args=transformers.TrainingArguments(
|
||||||
|
per_device_train_batch_size=MICRO_BATCH_SIZE,
|
||||||
|
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
num_train_epochs=EPOCHS,
|
||||||
|
learning_rate=LEARNING_RATE,
|
||||||
|
fp16=True,
|
||||||
|
logging_steps=logging_steps,
|
||||||
|
evaluation_strategy="no",
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_steps=None,
|
||||||
|
save_steps=save_steps,
|
||||||
|
output_dir=OUTPUT_DIR,
|
||||||
|
save_total_limit=save_total_limit,
|
||||||
|
load_best_model_at_end=False
|
||||||
|
),
|
||||||
|
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||||
|
)
|
||||||
|
model.config.use_cache = False
|
||||||
|
|
||||||
|
# Set Model dict
|
||||||
|
old_state_dict = model.state_dict
|
||||||
|
model.state_dict = (
|
||||||
|
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
||||||
|
).__get__(model, type(model))
|
||||||
|
|
||||||
|
# Run Trainer
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
print('Train completed.')
|
||||||
|
|
||||||
|
# Save Model
|
||||||
|
model.save_pretrained(OUTPUT_DIR)
|
||||||
|
|
||||||
|
print('Model Saved.')
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, './repository/transformers/src')
|
||||||
|
sys.path.insert(0, './repository/GPTQ-for-LLaMa')
|
||||||
|
sys.path.insert(0, './repository/peft/src')
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from autograd_4bit import load_llama_model_4bit_low_ram
|
||||||
|
config_path = './llama-13b-4bit/'
|
||||||
|
model_path = './llama-13b-4bit.pt'
|
||||||
|
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
|
||||||
|
|
||||||
|
print('Fitting 4bit scales and zeros to half')
|
||||||
|
for n, m in model.named_modules():
|
||||||
|
if '4bit' in str(type(m)):
|
||||||
|
m.zeros = m.zeros.half()
|
||||||
|
m.scales = m.scales.half()
|
||||||
|
|
||||||
|
prompt = '''I think the meaning of life is'''
|
||||||
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
|
||||||
|
batch = {k: v.cuda() for k, v in batch.items()}
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
generated = model.generate(inputs=batch["input_ids"],
|
||||||
|
do_sample=True, use_cache=True,
|
||||||
|
repetition_penalty=1.1,
|
||||||
|
max_new_tokens=20,
|
||||||
|
temperature=0.9,
|
||||||
|
top_p=0.95,
|
||||||
|
top_k=40,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
output_scores=False)
|
||||||
|
result_text = tokenizer.decode(generated['sequences'].cpu().tolist()[0])
|
||||||
|
end = time.time()
|
||||||
|
print(result_text)
|
||||||
|
print(end - start)
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
REM This is a install script for Alpaca_LoRA_4bit
|
||||||
|
|
||||||
|
REM makedir ./repository/ if not exists
|
||||||
|
if not exist .\repository mkdir .\repository
|
||||||
|
|
||||||
|
REM Clone repos into current repository into ./repository/
|
||||||
|
git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git ./repository/GPTQ-for-LLaMa
|
||||||
|
git clone https://github.com/huggingface/peft.git ./repository/peft
|
||||||
|
git clone https://github.com/huggingface/transformers.git ./repository/transformers
|
||||||
|
|
||||||
|
REM replace ./repository/peft/src/peft/tuners/lora.py with ./peft/tuners/lora.py
|
||||||
|
copy .\peft\tuners\lora.py .\repository\peft\src\peft\tuners\lora.py /Y
|
||||||
|
|
||||||
|
REM replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu with ./GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu
|
||||||
|
copy .\GPTQ-for-LLaMa\quant_cuda.cpp .\repository\GPTQ-for-LLaMa\quant_cuda.cpp /Y
|
||||||
|
copy .\GPTQ-for-LLaMa\quant_cuda_kernel.cu .\repository\GPTQ-for-LLaMa\quant_cuda_kernel.cu /Y
|
||||||
|
|
||||||
|
REM copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
|
||||||
|
copy .\GPTQ-for-LLaMa\autograd_4bit.py .\repository\GPTQ-for-LLaMa\autograd_4bit.py /Y
|
||||||
|
|
||||||
|
REM install quant_cuda
|
||||||
|
cd .\repository\GPTQ-for-LLaMa
|
||||||
|
python setup_cuda.py install
|
||||||
|
|
||||||
|
echo "Install finished"
|
||||||
|
@pause
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# This is an install script for Alpaca_LoRA_4bit
|
||||||
|
|
||||||
|
# makedir ./repository/ if not exists
|
||||||
|
if [ ! -d "./repository" ]; then
|
||||||
|
mkdir ./repository
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Clone repos into current repository into ./repository/
|
||||||
|
git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git ./repository/GPTQ-for-LLaMa
|
||||||
|
git clone https://github.com/huggingface/peft.git ./repository/peft
|
||||||
|
git clone https://github.com/huggingface/transformers.git ./repository/transformers
|
||||||
|
|
||||||
|
# Replace ./repository/peft/src/peft/tuners/lora.py with ./peft/tuners/lora.py
|
||||||
|
cp ./peft/tuners/lora.py ./repository/peft/src/peft/tuners/lora.py
|
||||||
|
|
||||||
|
# Replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu with ./GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu
|
||||||
|
cp ./GPTQ-for-LLaMa/quant_cuda.cpp ./repository/GPTQ-for-LLaMa/quant_cuda.cpp
|
||||||
|
cp ./GPTQ-for-LLaMa/quant_cuda_kernel.cu ./repository/GPTQ-for-LLaMa/quant_cuda_kernel.cu
|
||||||
|
|
||||||
|
# Copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
|
||||||
|
cp ./autograd_4bit.py ./repository/GPTQ-for-LLaMa/autograd_4bit.py
|
||||||
|
|
||||||
|
# Install quant_cuda and cd into ./repository/GPTQ-for-LLaMa
|
||||||
|
cd ./repository/GPTQ-for-LLaMa
|
||||||
|
python setup_cuda.py install
|
||||||
|
|
||||||
|
echo "Install finished"
|
||||||
|
read -p "Press [Enter] to continue..."
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
torch
|
torch
|
||||||
accelerate
|
accelerate
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
git+https://github.com/zphang/transformers@llama_push
|
git+https://github.com/huggingface/transformers.git
|
||||||
git+https://github.com/qwopqwop200/GPTQ-for-LLaMa.git
|
git+https://github.com/qwopqwop200/GPTQ-for-LLaMa.git
|
||||||
git+https://github.com/huggingface/peft.git
|
git+https://github.com/huggingface/peft.git
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue