diff --git a/GPTQ-for-LLaMa/autograd_4bit.py b/GPTQ-for-LLaMa/autograd_4bit.py
index dc25d08..c55b7d3 100644
--- a/GPTQ-for-LLaMa/autograd_4bit.py
+++ b/GPTQ-for-LLaMa/autograd_4bit.py
@@ -198,18 +198,18 @@ def model_to_float(model):
def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
import transformers
import accelerate
- from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from modelutils import find_layers
print("Loading Model ...")
t0 = time.time()
with accelerate.init_empty_weights():
- config = LLaMAConfig.from_pretrained(config_path)
+ config = LlamaConfig.from_pretrained(config_path)
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
- model = LLaMAForCausalLM(config)
+ model = LlamaForCausalLM(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
@@ -224,7 +224,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
if half:
model_to_half(model)
- tokenizer = LLaMATokenizer.from_pretrained(config_path)
+ tokenizer = LlamaTokenizer.from_pretrained(config_path)
tokenizer.truncation_side = 'left'
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
diff --git a/README.md b/README.md
index 8174a4f..77f0868 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,8 @@ Made some adjust for the code in peft and gptq for llama, and make it possible f
Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
+Added install scripts for Windows and Linux.
+
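+A minimal sketch of the fp16-reconstruction trick (illustrative only; the real implementation lives in quant_cuda and autograd_4bit.py, and the packing layout below is an assumption):
+
+```python
+import torch
+
+def dequant_matmul(x, qweight, scales, zeros):
+    # Assumed layout: qweight is uint8 with two 4-bit values packed per byte,
+    # scales and zeros are per-output-column fp16 tensors.
+    low = (qweight & 0x0F).to(torch.half)
+    high = (qweight >> 4).to(torch.half)
+    w4 = torch.cat((low, high), dim=0)  # unpack to one 4-bit value per element
+    w = (w4 - zeros) * scales           # reconstruct the fp16 weight matrix
+    return torch.matmul(x.half(), w)    # reuse the fast fp16 GEMM
+```
+
+The fp16 matrix is rebuilt on the fly for each forward pass, so the weights stay stored in 4 bits while the matmul itself runs as a regular fp16 GEMM.
+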
# Requirements
gptq-for-llama: https://github.com/qwopqwop200/GPTQ-for-LLaMa
peft: https://github.com/huggingface/peft.git
diff --git a/finetune.py b/finetune.py
new file mode 100644
index 0000000..7047a5c
--- /dev/null
+++ b/finetune.py
@@ -0,0 +1,144 @@
+import os
+import sys
+sys.path.insert(0, './repository/transformers/src')
+sys.path.insert(0, './repository/GPTQ-for-LLaMa')
+sys.path.insert(0, './repository/peft/src')
+
+import peft
+import peft.tuners.lora
+assert peft.tuners.lora.is_gptq_available()
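+# is_gptq_available() is provided by the patched lora.py that install.bat / install.sh copy into peft; this assert fails with stock peft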
+
+import time
+import torch
+import transformers
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+import accelerate
+from modelutils import find_layers
+from autograd_4bit import make_quant_for_4bit_autograd
+from autograd_4bit import load_llama_model_4bit_low_ram
+from datasets import load_dataset, Dataset
+import json
+from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel
+
+
+# Parameters
+DATA_PATH = "./data.txt"
+OUTPUT_DIR = "alpaca_lora"
+lora_path_old = ''
+config_path = './llama-13b-4bit/'
+model_path = './llama-13b-4bit.pt'
+
+MICRO_BATCH_SIZE = 1
+BATCH_SIZE = 2
+GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
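+# effective batch size = MICRO_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE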
+EPOCHS = 3
+LEARNING_RATE = 2e-4
+CUTOFF_LEN = 256
+LORA_R = 8
+LORA_ALPHA = 16
+LORA_DROPOUT = 0.05
+VAL_SET_SIZE = 0
+TARGET_MODULES = [
+ "q_proj",
+ "v_proj",
+]
+warmup_steps = 50
+save_steps = 50
+save_total_limit = 3
+logging_steps = 10
+
+# Load the 4-bit base model and tokenizer
+model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
+
+# Configure LoRA
+config = LoraConfig(
+ r=LORA_R,
+ lora_alpha=LORA_ALPHA,
+    target_modules=TARGET_MODULES,
+ lora_dropout=LORA_DROPOUT,
+ bias="none",
+ task_type="CAUSAL_LM",
+)
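+# LoRA scales its update by LORA_ALPHA / LORA_R (16 / 8 = 2 here)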
+if lora_path_old == '':
+ model = get_peft_model(model, config)
+else:
+ model = PeftModel.from_pretrained(model, lora_path_old)
+ print(lora_path_old, 'loaded')
+
+# Cast the 4-bit quantization scales and zero points to fp16 to match the half-precision activations
+print('Fitting 4bit scales and zeros to half')
+for n, m in model.named_modules():
+ if '4bit' in str(type(m)):
+ m.zeros = m.zeros.half()
+ m.scales = m.scales.half()
+
+# Set tokenizer
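+# id 0 is <unk> in the LLaMA vocabulary; it is reused here as the padding token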
+tokenizer.pad_token_id = 0
+
+# Load Data: each non-empty line of DATA_PATH becomes one training example
+with open(DATA_PATH, 'r', encoding='utf8') as file:
+ txt = file.read()
+txt = txt.replace('\r\n', '\n')
+rows = [r for r in txt.split('\n') if r != '']
+data = Dataset.from_dict({"input": rows})
+exceed_count = 0
+def tokenize(prompt):
+    # Tokenize to CUTOFF_LEN + 1 ids and drop the last one below, so every example
+    # ends up with exactly CUTOFF_LEN ids; prompts that fill the window are counted.
+ global exceed_count
+ prompt = prompt['input']
+ result = tokenizer(
+ prompt,
+ truncation=True,
+ max_length=CUTOFF_LEN + 1,
+ padding="max_length",
+ )
+ d = {
+ "input_ids": result["input_ids"][:-1],
+ "attention_mask": result["attention_mask"][:-1],
+ }
+ if sum(d['attention_mask']) >= CUTOFF_LEN:
+ exceed_count += 1
+ return d
+data = data.shuffle().map(lambda x: tokenize(x))
+print('Train data: {:.2f}% of examples reach the CUTOFF_LEN token limit'.format(exceed_count / len(data) * 100))
+train_data = data
+
+trainer = transformers.Trainer(
+ model=model,
+ train_dataset=train_data,
+ args=transformers.TrainingArguments(
+ per_device_train_batch_size=MICRO_BATCH_SIZE,
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
+ warmup_steps=warmup_steps,
+ num_train_epochs=EPOCHS,
+ learning_rate=LEARNING_RATE,
+ fp16=True,
+ logging_steps=logging_steps,
+ evaluation_strategy="no",
+ save_strategy="steps",
+ eval_steps=None,
+ save_steps=save_steps,
+ output_dir=OUTPUT_DIR,
+ save_total_limit=save_total_limit,
+ load_best_model_at_end=False
+ ),
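+    # with mlm=False the collator builds causal-LM labels (labels = input_ids, padding masked out)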
+ data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
+)
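+# the KV cache is only useful for generation; disable it during training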
+model.config.use_cache = False
+
+# Override state_dict so Trainer checkpoints save only the LoRA adapter weights
+old_state_dict = model.state_dict
+model.state_dict = (
+ lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
+).__get__(model, type(model))
+
+# Run Trainer
+trainer.train()
+
+print('Train completed.')
+
+# Save the LoRA adapter (adapter weights and config only, not the 4-bit base model)
+model.save_pretrained(OUTPUT_DIR)
+
+print('Model Saved.')
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..0a25272
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,39 @@
+import os
+import sys
+sys.path.insert(0, './repository/transformers/src')
+sys.path.insert(0, './repository/GPTQ-for-LLaMa')
+sys.path.insert(0, './repository/peft/src')
+import time
+import torch
+from autograd_4bit import load_llama_model_4bit_low_ram
+config_path = './llama-13b-4bit/'
+model_path = './llama-13b-4bit.pt'
+model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
+
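+# cast the 4-bit quantization scales and zero points to fp16 to match the half-precision activations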
+print('Fitting 4bit scales and zeros to half')
+for n, m in model.named_modules():
+ if '4bit' in str(type(m)):
+ m.zeros = m.zeros.half()
+ m.scales = m.scales.half()
+
+prompt = '''I think the meaning of life is'''
+batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
+batch = {k: v.cuda() for k, v in batch.items()}
+
+start = time.time()
+with torch.no_grad():
+ generated = model.generate(inputs=batch["input_ids"],
+ do_sample=True, use_cache=True,
+ repetition_penalty=1.1,
+ max_new_tokens=20,
+ temperature=0.9,
+ top_p=0.95,
+ top_k=40,
+ return_dict_in_generate=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ output_scores=False)
+result_text = tokenizer.decode(generated['sequences'].cpu().tolist()[0])
+end = time.time()
+print(result_text)
+print(end - start)
diff --git a/install.bat b/install.bat
new file mode 100644
index 0000000..bf5e817
--- /dev/null
+++ b/install.bat
@@ -0,0 +1,26 @@
+REM This is an install script for Alpaca_LoRA_4bit
+
+REM Create ./repository/ if it does not exist
+if not exist .\repository mkdir .\repository
+
+REM Clone the required repositories into ./repository/
+git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git ./repository/GPTQ-for-LLaMa
+git clone https://github.com/huggingface/peft.git ./repository/peft
+git clone https://github.com/huggingface/transformers.git ./repository/transformers
+
+REM replace ./repository/peft/src/peft/tuners/lora.py with ./peft/tuners/lora.py
+copy .\peft\tuners\lora.py .\repository\peft\src\peft\tuners\lora.py /Y
+
+REM replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu with ./GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu
+copy .\GPTQ-for-LLaMa\quant_cuda.cpp .\repository\GPTQ-for-LLaMa\quant_cuda.cpp /Y
+copy .\GPTQ-for-LLaMa\quant_cuda_kernel.cu .\repository\GPTQ-for-LLaMa\quant_cuda_kernel.cu /Y
+
+REM copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
+copy .\GPTQ-for-LLaMa\autograd_4bit.py .\repository\GPTQ-for-LLaMa\autograd_4bit.py /Y
+
+REM install quant_cuda
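+REM Note: building quant_cuda needs a CUDA toolkit (nvcc) compatible with the installed PyTorch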
+cd .\repository\GPTQ-for-LLaMa
+python setup_cuda.py install
+
+echo Install finished
+@pause
diff --git a/install.sh b/install.sh
new file mode 100644
index 0000000..262f951
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+# This is an install script for Alpaca_LoRA_4bit
+
+# Create ./repository/ if it does not exist
+if [ ! -d "./repository" ]; then
+ mkdir ./repository
+fi
+
+# Clone the required repositories into ./repository/
+git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git ./repository/GPTQ-for-LLaMa
+git clone https://github.com/huggingface/peft.git ./repository/peft
+git clone https://github.com/huggingface/transformers.git ./repository/transformers
+
+# Replace ./repository/peft/src/peft/tuners/lora.py with ./peft/tuners/lora.py
+cp ./peft/tuners/lora.py ./repository/peft/src/peft/tuners/lora.py
+
+# Replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu with ./GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu
+cp ./GPTQ-for-LLaMa/quant_cuda.cpp ./repository/GPTQ-for-LLaMa/quant_cuda.cpp
+cp ./GPTQ-for-LLaMa/quant_cuda_kernel.cu ./repository/GPTQ-for-LLaMa/quant_cuda_kernel.cu
+
+# Copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
+cp ./GPTQ-for-LLaMa/autograd_4bit.py ./repository/GPTQ-for-LLaMa/autograd_4bit.py
+
+# cd into ./repository/GPTQ-for-LLaMa and build/install the quant_cuda extension
+cd ./repository/GPTQ-for-LLaMa
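+# Note: building quant_cuda needs a CUDA toolkit (nvcc) compatible with the installed PyTorch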
+python setup_cuda.py install
+
+echo "Install finished"
+read -p "Press [Enter] to continue..."
diff --git a/requirements.txt b/requirements.txt
index 1f7dcbe..b0fca51 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
torch
accelerate
bitsandbytes
-git+https://github.com/zphang/transformers@llama_push
+git+https://github.com/huggingface/transformers.git
git+https://github.com/qwopqwop200/GPTQ-for-LLaMa.git
git+https://github.com/huggingface/peft.git