Merge branch 'main' into finetune-refactor
commit 0879580006
@@ -0,0 +1,61 @@
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from torch.utils.checkpoint import checkpoint
+from torch.autograd import Variable
+import torch
+from torch import nn
+import numpy as np
+
+
+class NewForward:
+
+    def __init__(self, layer):
+        self.layer = layer
+        self.apply_patch()
+
+    def apply_patch(self):
+        self.layer.old_forward_for_cp = self.layer.forward
+        self.layer.forward = self.new_forward
+
+    def new_forward(self, *args, **kwargs):
+        def func(*args):
+            return self.layer.old_forward_for_cp(*args, **kwargs)
+        output = checkpoint(func, *args)
+        return output
+
+
+class VarWrapper:
+
+    def __init__(self, model):
+        self.model = model
+        self.apply_patch()
+        print('Var Wrapper Patch Applied')
+
+    def apply_patch(self):
+        self.model.old_forward_for_cp = self.model.forward
+        self.model.forward = self.new_forward
+
+    def new_forward(self, *args, **kwargs):
+        out = self.model.old_forward_for_cp(*args, **kwargs)
+        out = Variable(out.data, requires_grad=True)
+        return out
+
+
+def apply_gradient_checkpointing(model, checkpoint_ratio=1):
+    new_forwards = []
+    modules = []
+    for n, m in model.named_modules():
+        if isinstance(m, LlamaDecoderLayer):
+            modules.append(m)
+    if checkpoint_ratio < 1 and checkpoint_ratio > 0:
+        checkpoint_locs = np.array((np.linspace(0, 1, int(len(modules) * checkpoint_ratio)) * (len(modules)-1)).round(), dtype=int)
+    else:
+        checkpoint_locs = np.arange(len(modules))
+    for i in checkpoint_locs:
+        m = modules[i]
+        new_forwards.append(NewForward(m))
+        print('Forward Patch Applied For Block {}'.format(i))
+    for n, m in model.named_modules():
+        if isinstance(m, torch.nn.Embedding):
+            wrapper = VarWrapper(m)
+            break
+    return new_forwards, wrapper
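
For context, here is a minimal usage sketch of the patch above (not part of the diff). apply_gradient_checkpointing is called once on a loaded LLaMA model before training; checkpoint_ratio controls what fraction of the decoder layers are checkpointed, and the returned objects only need to be kept alive so the patched forwards stay in place. The module name and the model-loading call below are assumptions for illustration.

import torch
from transformers import LlamaForCausalLM
from gradient_checkpointing import apply_gradient_checkpointing  # assumed module name for the new file

model = LlamaForCausalLM.from_pretrained('path/to/llama-hf')  # placeholder checkpoint path
model = model.half().cuda()

# Checkpoint roughly half of the LlamaDecoderLayer blocks. The VarWrapper returned
# alongside re-enables gradients on the embedding output, so checkpoint() still has
# a grad-requiring input even when the frozen 4-bit base weights carry no gradient.
new_forwards, wrapper = apply_gradient_checkpointing(model, checkpoint_ratio=0.5)
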
@@ -4,8 +4,10 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
 // adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
-__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
+__device__ __forceinline__ void atomicAddHalf(__half* address, c10::Half val) {
   unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
   unsigned int old = *address_as_ui;
   unsigned int assumed;
@@ -22,6 +24,8 @@ __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
     // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
   } while (assumed != old);
 }
+#endif
+#endif
 
 template <typename scalar_t>
 __global__ void VecQuant2MatMulKernel(
@@ -543,7 +547,14 @@ __global__ void VecQuant4MatMulHalfKernel(
   }
 
   __half* mul2 = (__half*)mul;
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
+  atomicAddHalf(&mul2[b * width + w], res);
+#else
   atomicAdd(&mul2[b * width + w], res);
+#endif
+#endif
+
 }
 
 void vecquant4matmul_half_cuda(
@@ -616,7 +627,13 @@ __global__ void VecQuant4TransposeMatMulHalfKernel(
   }
 
   __half* mul2 = (__half*)mul;
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
+  atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res);
+#else
   atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
+#endif
+#endif
 }
 
 void vecquant4transposematmul_half_cuda(
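
The added guards route compute capability 6.1/6.2 devices, which lack the native half-precision atomicAdd overload (available from 7.0), to the CAS-based atomicAddHalf fallback; newer devices keep the plain atomicAdd call. A small sketch for checking which path a given GPU would take; the arch value mirrors how __CUDA_ARCH__ is derived (major*100 + minor*10):

import torch

major, minor = torch.cuda.get_device_capability(0)
arch = major * 100 + minor * 10
if 600 < arch < 700:
    print(f'sm_{major}{minor}: CAS-based atomicAddHalf fallback')
else:
    print(f'sm_{major}{minor}: native half-precision atomicAdd')
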

README.md
@@ -1,11 +1,22 @@
 # Alpaca Lora 4bit
 Made some adjustments to the code in peft and GPTQ-for-LLaMa to make LoRA finetuning possible with a 4-bit base model. The same adjustment can be made for 2, 3 and 8 bits.
 <br>
-~Still numerically unstable.~ Resolved.
+* Install Manual by s4rduk4r: https://github.com/s4rduk4r/alpaca_lora_4bit_readme/blob/main/README.md
+
+# Update Logs
+* Resolved numerically unstable issue
 <br>
-Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
+
+* Reconstructing the fp16 matrix from 4-bit data and calling torch.matmul largely increased the inference speed.
 <br>
-Added install script for windows and linux.
+
+* Added install script for windows and linux.
+<br>
+
+* Added Gradient Checkpointing. A 30b 4-bit model can now be finetuned on a single GPU with 24G VRAM with Gradient Checkpointing enabled (finetune.py updated). Note that it reduces training speed, so the option is not needed if you have enough VRAM.
+<br>
+
+* Added install manual by s4rduk4r
 <br>
 
 # Requirements

finetune.py
@@ -127,10 +127,10 @@ if not ft_config.skip:
 
     print('Train completed.')
 
-    if not ft_config.checkpoint:
-        # Save Model
-        model.save_pretrained(ft_config.lora_out_dir)
-    else:
-        raise NotImplemented("TODO: Merge model + LoRA and save the whole checkpoint")
+    # Save Model
+    model.save_pretrained(ft_config.lora_out_dir)
+
+    if ft_config.checkpoint:
+        print("Warning: Merge model + LoRA and save the whole checkpoint not implemented yet.")
 
     print('Model Saved.')
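
As a side note (not shown in the diff), the adapter directory written by model.save_pretrained(ft_config.lora_out_dir) can later be reattached to a base model with peft. A minimal sketch, where base_model stands in for an already-loaded LLaMA base model:

from peft import PeftModel

# Attach the LoRA adapter saved by finetune.py to the base model.
model = PeftModel.from_pretrained(base_model, ft_config.lora_out_dir)
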
@@ -15,8 +15,9 @@ REM replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu
 copy .\GPTQ-for-LLaMa\quant_cuda.cpp .\repository\GPTQ-for-LLaMa\quant_cuda.cpp /Y
 copy .\GPTQ-for-LLaMa\quant_cuda_kernel.cu .\repository\GPTQ-for-LLaMa\quant_cuda_kernel.cu /Y
 
-REM copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
+REM copy files into ./repository/GPTQ-for-LLaMa/
 copy .\GPTQ-for-LLaMa\autograd_4bit.py .\repository\GPTQ-for-LLaMa\autograd_4bit.py /Y
+copy .\GPTQ-for-LLaMa\gradient_checkpointing.py .\repository\GPTQ-for-LLaMa\gradient_checkpointing.py /Y
 
 REM install quant_cuda
 cd .\repository\GPTQ-for-LLaMa
@@ -19,8 +19,9 @@ cp ./peft/tuners/lora.py ./repository/peft/src/peft/tuners/lora.py
 cp ./GPTQ-for-LLaMa/quant_cuda.cpp ./repository/GPTQ-for-LLaMa/quant_cuda.cpp
 cp ./GPTQ-for-LLaMa/quant_cuda_kernel.cu ./repository/GPTQ-for-LLaMa/quant_cuda_kernel.cu
 
-# Copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
+# Copy files into ./repository/GPTQ-for-LLaMa/
 cp ./GPTQ-for-LLaMa/autograd_4bit.py ./repository/GPTQ-for-LLaMa/autograd_4bit.py
+cp ./GPTQ-for-LLaMa/gradient_checkpointing.py ./repository/GPTQ-for-LLaMa/gradient_checkpointing.py
 
 # Install quant_cuda and cd into ./repository/GPTQ-for-LLaMa
 cd ./repository/GPTQ-for-LLaMa