From dc036373b2370092076e4bf9d122ba07e6f7f914 Mon Sep 17 00:00:00 2001 From: John Smith Date: Wed, 22 Mar 2023 04:09:04 +0000 Subject: [PATCH] add more scripts and adjust code for transformer branch --- GPTQ-for-LLaMa/autograd_4bit.py | 8 +- README.md | 2 + finetune.py | 144 ++++++++++++++++++++++++++++++++ inference.py | 39 +++++++++ install.bat | 26 ++++++ install.sh | 30 +++++++ requirements.txt | 2 +- 7 files changed, 246 insertions(+), 5 deletions(-) create mode 100644 finetune.py create mode 100644 inference.py create mode 100644 install.bat create mode 100644 install.sh diff --git a/GPTQ-for-LLaMa/autograd_4bit.py b/GPTQ-for-LLaMa/autograd_4bit.py index dc25d08..c55b7d3 100644 --- a/GPTQ-for-LLaMa/autograd_4bit.py +++ b/GPTQ-for-LLaMa/autograd_4bit.py @@ -198,18 +198,18 @@ def model_to_float(model): def load_llama_model_4bit_low_ram(config_path, model_path, half=False): import transformers import accelerate - from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer + from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer from modelutils import find_layers print("Loading Model ...") t0 = time.time() with accelerate.init_empty_weights(): - config = LLaMAConfig.from_pretrained(config_path) + config = LlamaConfig.from_pretrained(config_path) torch.set_default_dtype(torch.half) transformers.modeling_utils._init_weights = False torch.set_default_dtype(torch.half) - model = LLaMAForCausalLM(config) + model = LlamaForCausalLM(config) torch.set_default_dtype(torch.float) model = model.eval() layers = find_layers(model) @@ -224,7 +224,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False): if half: model_to_half(model) - tokenizer = LLaMATokenizer.from_pretrained(config_path) + tokenizer = LlamaTokenizer.from_pretrained(config_path) tokenizer.truncation_side = 'left' print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") diff --git a/README.md b/README.md index 8174a4f..77f0868 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,8 @@ Made some adjust for the code in peft and gptq for llama, and make it possible f
Reconstructing the fp16 matrix from the 4-bit data and calling torch.matmul largely increased the inference speed (see the sketch below this diff hunk).
+Added install scripts for Windows and Linux. +
# Requirements gptq-for-llama: https://github.com/qwopqwop200/GPTQ-for-LLaMa
peft: https://github.com/huggingface/peft.git
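The inference speed-up claimed in the README hunk above comes from dequantizing the packed 4-bit weights back into an fp16 matrix and handing the multiplication to torch.matmul, rather than running a custom low-bit kernel per layer. Below is a minimal sketch of that idea; the nibble layout and the shapes of `scales`/`zeros` are assumptions for illustration, not the exact packing used by autograd_4bit.py.

```python
import torch

def dequantize_4bit(qweight, scales, zeros, out_features, in_features):
    # Unpack 8 nibbles out of every int32 (layout assumed here for illustration).
    shifts = torch.arange(0, 32, 4, device=qweight.device)
    nibbles = (qweight.unsqueeze(-1) >> shifts) & 0xF
    w = nibbles.reshape(out_features, in_features).half()
    # Apply zero points and scales to recover an fp16 weight matrix.
    return (w - zeros) * scales

def linear_4bit(x, qweight, scales, zeros, out_features, in_features):
    # One dense fp16 GEMM on the reconstructed matrix instead of a custom 4-bit kernel.
    w = dequantize_4bit(qweight, scales, zeros, out_features, in_features)
    return torch.matmul(x.half(), w.t())
```

The quantized Linear modules in autograd_4bit.py implement this reconstruct-then-matmul path with their own packing; the sketch only illustrates the arithmetic.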
diff --git a/finetune.py b/finetune.py new file mode 100644 index 0000000..7047a5c --- /dev/null +++ b/finetune.py @@ -0,0 +1,144 @@ +import os +import sys +sys.path.insert(0, './repository/transformers/src') +sys.path.insert(0, './repository/GPTQ-for-LLaMa') +sys.path.insert(0, './repository/peft/src') + +import peft +import peft.tuners.lora +assert peft.tuners.lora.is_gptq_available() + +import time +import torch +import transformers +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer +import accelerate +from modelutils import find_layers +from autograd_4bit import make_quant_for_4bit_autograd +from autograd_4bit import load_llama_model_4bit_low_ram +from datasets import load_dataset, Dataset +import json +from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel + + +# Parameters +DATA_PATH = "./data.txt" +OUTPUT_DIR = "alpaca_lora" +lora_path_old = '' +config_path = './llama-13b-4bit/' +model_path = './llama-13b-4bit.pt' + +MICRO_BATCH_SIZE = 1 +BATCH_SIZE = 2 +GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE +EPOCHS = 3 +LEARNING_RATE = 2e-4 +CUTOFF_LEN = 256 +LORA_R = 8 +LORA_ALPHA = 16 +LORA_DROPOUT = 0.05 +VAL_SET_SIZE = 0 +TARGET_MODULES = [ + "q_proj", + "v_proj", +] +warmup_steps = 50 +save_steps = 50 +save_total_limit = 3 +logging_steps = 10 + +# Load Basic Model +model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path) + +# Config Lora +config = LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules=["q_proj", "v_proj"], + lora_dropout=LORA_DROPOUT, + bias="none", + task_type="CAUSAL_LM", +) +if lora_path_old == '': + model = get_peft_model(model, config) +else: + model = PeftModel.from_pretrained(model, lora_path_old) + print(lora_path_old, 'loaded') + +# Scales to half +print('Fitting 4bit scales and zeros to half') +for n, m in model.named_modules(): + if '4bit' in str(type(m)): + m.zeros = m.zeros.half() + m.scales = m.scales.half() + +# Set tokenizer +tokenizer.pad_token_id = 0 + +# Load Data +with open(DATA_PATH, 'r', encoding='utf8') as file: + txt = file.read() +txt = txt.replace('\r\n', '\n') +rows = [r for r in txt.split('\n') if r != ''] +data = Dataset.from_dict({"input": rows}) +exceed_count = 0 +def tokenize(prompt): + # there's probably a way to do this with the tokenizer settings + # but again, gotta move fast + global exceed_count + prompt = prompt['input'] + result = tokenizer( + prompt, + truncation=True, + max_length=CUTOFF_LEN + 1, + padding="max_length", + ) + d = { + "input_ids": result["input_ids"][:-1], + "attention_mask": result["attention_mask"][:-1], + } + if sum(d['attention_mask']) >= CUTOFF_LEN: + exceed_count += 1 + return d +data = data.shuffle().map(lambda x: tokenize(x)) +print('Train Data: {:.2f}%'.format(exceed_count / len(data) * 100), 'outliers') +train_data = data + +trainer = transformers.Trainer( + model=model, + train_dataset=train_data, + args=transformers.TrainingArguments( + per_device_train_batch_size=MICRO_BATCH_SIZE, + gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, + warmup_steps=warmup_steps, + num_train_epochs=EPOCHS, + learning_rate=LEARNING_RATE, + fp16=True, + logging_steps=logging_steps, + evaluation_strategy="no", + save_strategy="steps", + eval_steps=None, + save_steps=save_steps, + output_dir=OUTPUT_DIR, + save_total_limit=save_total_limit, + load_best_model_at_end=False + ), + data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), +) +model.config.use_cache = False + 
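+# use_cache is disabled above because cached key/values only help generation, not training.
+# The state_dict override just below swaps in get_peft_model_state_dict, so that
+# save_pretrained() at the end writes only the small LoRA adapter weights rather
+# than the full 4-bit base model.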
+# Set Model dict +old_state_dict = model.state_dict +model.state_dict = ( + lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict()) +).__get__(model, type(model)) + +# Run Trainer +trainer.train() + +print('Train completed.') + +# Save Model +model.save_pretrained(OUTPUT_DIR) + +print('Model Saved.') diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..0a25272 --- /dev/null +++ b/inference.py @@ -0,0 +1,39 @@ +import os +import sys +sys.path.insert(0, './repository/transformers/src') +sys.path.insert(0, './repository/GPTQ-for-LLaMa') +sys.path.insert(0, './repository/peft/src') +import time +import torch +from autograd_4bit import load_llama_model_4bit_low_ram +config_path = './llama-13b-4bit/' +model_path = './llama-13b-4bit.pt' +model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path) + +print('Fitting 4bit scales and zeros to half') +for n, m in model.named_modules(): + if '4bit' in str(type(m)): + m.zeros = m.zeros.half() + m.scales = m.scales.half() + +prompt = '''I think the meaning of life is''' +batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) +batch = {k: v.cuda() for k, v in batch.items()} + +start = time.time() +with torch.no_grad(): + generated = model.generate(inputs=batch["input_ids"], + do_sample=True, use_cache=True, + repetition_penalty=1.1, + max_new_tokens=20, + temperature=0.9, + top_p=0.95, + top_k=40, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False) +result_text = tokenizer.decode(generated['sequences'].cpu().tolist()[0]) +end = time.time() +print(result_text) +print(end - start) diff --git a/install.bat b/install.bat new file mode 100644 index 0000000..bf5e817 --- /dev/null +++ b/install.bat @@ -0,0 +1,26 @@ +REM This is a install script for Alpaca_LoRA_4bit + +REM makedir ./repository/ if not exists +if not exist .\repository mkdir .\repository + +REM Clone repos into current repository into ./repository/ +git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git ./repository/GPTQ-for-LLaMa +git clone https://github.com/huggingface/peft.git ./repository/peft +git clone https://github.com/huggingface/transformers.git ./repository/transformers + +REM replace ./repository/peft/src/peft/tuners/lora.py with ./peft/tuners/lora.py +copy .\peft\tuners\lora.py .\repository\peft\src\peft\tuners\lora.py /Y + +REM replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu with ./GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu +copy .\GPTQ-for-LLaMa\quant_cuda.cpp .\repository\GPTQ-for-LLaMa\quant_cuda.cpp /Y +copy .\GPTQ-for-LLaMa\quant_cuda_kernel.cu .\repository\GPTQ-for-LLaMa\quant_cuda_kernel.cu /Y + +REM copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py +copy .\GPTQ-for-LLaMa\autograd_4bit.py .\repository\GPTQ-for-LLaMa\autograd_4bit.py /Y + +REM install quant_cuda +cd .\repository\GPTQ-for-LLaMa +python setup_cuda.py install + +echo "Install finished" +@pause diff --git a/install.sh b/install.sh new file mode 100644 index 0000000..262f951 --- /dev/null +++ b/install.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +# This is an install script for Alpaca_LoRA_4bit + +# makedir ./repository/ if not exists +if [ ! 
-d "./repository" ]; then + mkdir ./repository +fi + +# Clone repos into current repository into ./repository/ +git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git ./repository/GPTQ-for-LLaMa +git clone https://github.com/huggingface/peft.git ./repository/peft +git clone https://github.com/huggingface/transformers.git ./repository/transformers + +# Replace ./repository/peft/src/peft/tuners/lora.py with ./peft/tuners/lora.py +cp ./peft/tuners/lora.py ./repository/peft/src/peft/tuners/lora.py + +# Replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu with ./GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu +cp ./GPTQ-for-LLaMa/quant_cuda.cpp ./repository/GPTQ-for-LLaMa/quant_cuda.cpp +cp ./GPTQ-for-LLaMa/quant_cuda_kernel.cu ./repository/GPTQ-for-LLaMa/quant_cuda_kernel.cu + +# Copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py +cp ./autograd_4bit.py ./repository/GPTQ-for-LLaMa/autograd_4bit.py + +# Install quant_cuda and cd into ./repository/GPTQ-for-LLaMa +cd ./repository/GPTQ-for-LLaMa +python setup_cuda.py install + +echo "Install finished" +read -p "Press [Enter] to continue..." diff --git a/requirements.txt b/requirements.txt index 1f7dcbe..b0fca51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch accelerate bitsandbytes -git+https://github.com/zphang/transformers@llama_push +git+https://github.com/huggingface/transformers.git git+https://github.com/qwopqwop200/GPTQ-for-LLaMa.git git+https://github.com/huggingface/peft.git
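For completeness, the pieces above can be combined to run the trained adapter: load the 4-bit base model exactly as inference.py does, then attach the adapter that finetune.py saved using PeftModel.from_pretrained, the same call finetune.py uses to resume from lora_path_old. This is a minimal sketch under those assumptions; the adapter directory name `alpaca_lora` is simply the OUTPUT_DIR from finetune.py.

```python
import sys
sys.path.insert(0, './repository/transformers/src')
sys.path.insert(0, './repository/GPTQ-for-LLaMa')
sys.path.insert(0, './repository/peft/src')

import torch
from peft import PeftModel
from autograd_4bit import load_llama_model_4bit_low_ram

# Load the 4-bit base model as inference.py does.
model, tokenizer = load_llama_model_4bit_low_ram('./llama-13b-4bit/', './llama-13b-4bit.pt')

# Attach the LoRA adapter written by finetune.py (OUTPUT_DIR = "alpaca_lora").
model = PeftModel.from_pretrained(model, 'alpaca_lora')

# Cast the 4-bit scales and zeros to half, as both scripts do after loading.
for n, m in model.named_modules():
    if '4bit' in str(type(m)):
        m.zeros = m.zeros.half()
        m.scales = m.scales.half()

batch = tokenizer("I think the meaning of life is", return_tensors="pt", add_special_tokens=False)
batch = {k: v.cuda() for k, v in batch.items()}
with torch.no_grad():
    out = model.generate(inputs=batch["input_ids"], do_sample=True, use_cache=True,
                         max_new_tokens=20, temperature=0.9, top_p=0.95, top_k=40)
print(tokenizer.decode(out[0].cpu().tolist()))
```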