diff --git a/GPTQ-for-LLaMa/gradient_checkpointing.py b/GPTQ-for-LLaMa/gradient_checkpointing.py
new file mode 100644
index 0000000..b75fd2c
--- /dev/null
+++ b/GPTQ-for-LLaMa/gradient_checkpointing.py
@@ -0,0 +1,61 @@
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from torch.utils.checkpoint import checkpoint
+from torch.autograd import Variable
+import torch
+from torch import nn
+import numpy as np
+
+
+class NewForward:
+
+    def __init__(self, layer):
+        self.layer = layer
+        self.apply_patch()
+
+    def apply_patch(self):
+        self.layer.old_forward_for_cp = self.layer.forward
+        self.layer.forward = self.new_forward
+
+    def new_forward(self, *args, **kwargs):
+        def func(*args):
+            return self.layer.old_forward_for_cp(*args, **kwargs)
+        output = checkpoint(func, *args)
+        return output
+
+
+class VarWrapper:
+
+    def __init__(self, model):
+        self.model = model
+        self.apply_patch()
+        print('Var Wrapper Patch Applied')
+
+    def apply_patch(self):
+        self.model.old_forward_for_cp = self.model.forward
+        self.model.forward = self.new_forward
+
+    def new_forward(self, *args, **kwargs):
+        out = self.model.old_forward_for_cp(*args, **kwargs)
+        out = Variable(out.data, requires_grad=True)
+        return out
+
+
+def apply_gradient_checkpointing(model, checkpoint_ratio=1):
+    new_forwards = []
+    modules = []
+    for n, m in model.named_modules():
+        if isinstance(m, LlamaDecoderLayer):
+            modules.append(m)
+    if checkpoint_ratio < 1 and checkpoint_ratio > 0:
+        checkpoint_locs = np.array((np.linspace(0, 1, int(len(modules) * checkpoint_ratio)) * (len(modules)-1)).round(), dtype=int)
+    else:
+        checkpoint_locs = np.arange(len(modules))
+    for i in checkpoint_locs:
+        m = modules[i]
+        new_forwards.append(NewForward(m))
+        print('Forward Patch Applied For Block {}'.format(i))
+    for n, m in model.named_modules():
+        if isinstance(m, torch.nn.Embedding):
+            wrapper = VarWrapper(m)
+            break
+    return new_forwards, wrapper
diff --git a/GPTQ-for-LLaMa/quant_cuda_kernel.cu b/GPTQ-for-LLaMa/quant_cuda_kernel.cu
index 0077650..ed9627a 100644
--- a/GPTQ-for-LLaMa/quant_cuda_kernel.cu
+++ b/GPTQ-for-LLaMa/quant_cuda_kernel.cu
@@ -4,8 +4,10 @@
 #include 
 #include 
 
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
 // adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
-__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
+__device__ __forceinline__ void atomicAddHalf(__half* address, c10::Half val) {
   unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
   unsigned int old = *address_as_ui;
   unsigned int assumed;
@@ -22,6 +24,8 @@ __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
     // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
   } while (assumed != old);
 }
+#endif
+#endif
 
 template <typename scalar_t>
 __global__ void VecQuant2MatMulKernel(
@@ -543,7 +547,14 @@ __global__ void VecQuant4MatMulHalfKernel(
   }
 
   __half* mul2 = (__half*)mul;
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
+  atomicAddHalf(&mul2[b * width + w], res);
+#else
   atomicAdd(&mul2[b * width + w], res);
+#endif
+#endif
+
 }
 
 void vecquant4matmul_half_cuda(
@@ -616,7 +627,13 @@ __global__ void VecQuant4TransposeMatMulHalfKernel(
   }
 
   __half* mul2 = (__half*)mul;
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
+  atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res);
+#else
   atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
+#endif
+#endif
 }
 
 void vecquant4transposematmul_half_cuda(
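The new gradient_checkpointing.py above routes each LlamaDecoderLayer.forward through torch.utils.checkpoint and wraps the embedding output in a grad-requiring Variable. For orientation, here is a minimal usage sketch, not the literal finetune.py code from this patch; the model path is a placeholder, and a plain fp16 HF LLaMA is used only to keep the snippet self-contained (the real training script builds its 4-bit base model through the repo's own loader):

```python
# Usage sketch only -- not the exact code in finetune.py.
import torch
from transformers import LlamaForCausalLM

from gradient_checkpointing import apply_gradient_checkpointing  # the new module above

MODEL_DIR = './llama-7b-hf'  # placeholder path

model = LlamaForCausalLM.from_pretrained(MODEL_DIR, torch_dtype=torch.float16)

# Patches every LlamaDecoderLayer.forward with torch.utils.checkpoint and wraps the
# embedding output in a Variable with requires_grad=True (presumably so the
# checkpointed blocks receive a grad-requiring input even when all base weights
# are frozen).
new_forwards, wrapper = apply_gradient_checkpointing(model, checkpoint_ratio=1)

# checkpoint_ratio < 1 patches only an evenly spaced subset of decoder layers,
# trading part of the memory saving back for speed.
```

With checkpoint_ratio=1 every decoder layer recomputes its activations during the backward pass, which is what allows the 30B-on-24GB setup mentioned in the README below, at the cost of training speed.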
diff --git a/README.md b/README.md
index c926eac..7f64534 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,22 @@
 # Alpaca Lora 4bit
 Made some adjust for the code in peft and gptq for llama, and make it possible for lora finetuning with a 4 bits base model. The same adjustment can be made for 2, 3 and 8 bits.
-~Still numerically unstable.~ Resolved.
+* Install Manual by s4rduk4r: https://github.com/s4rduk4r/alpaca_lora_4bit_readme/blob/main/README.md
+
+# Update Logs
+* Resolved the numerical instability issue
 
-Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
+
+* Reconstructing the fp16 matrix from 4bit data and calling torch.matmul largely increased the inference speed.
 
-Added install script for windows and linux.
+
+* Added install scripts for Windows and Linux.
+
 
+
+* Added gradient checkpointing. With it enabled, a 30B model can now be finetuned in 4bit on a single GPU with 24GB of VRAM (finetune.py updated). Note that gradient checkpointing reduces training speed, so it is not needed if you have enough VRAM.
+
 
+
+* Added install manual by s4rduk4r
 
 # Requirements
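As a side note on the checkpoint_ratio option behind the gradient checkpointing entry in the update log above, the following self-contained sketch mirrors the layer-selection arithmetic from apply_gradient_checkpointing; the 60-layer count is only an illustrative stand-in for a 30B LLaMA:

```python
# Mirrors the checkpoint_locs computation in gradient_checkpointing.py.
import numpy as np

def checkpoint_locations(n_layers, checkpoint_ratio):
    """Indices of the decoder layers that receive the checkpoint patch."""
    if 0 < checkpoint_ratio < 1:
        # int(n_layers * ratio) roughly evenly spaced indices across [0, n_layers - 1]
        return np.array((np.linspace(0, 1, int(n_layers * checkpoint_ratio)) * (n_layers - 1)).round(), dtype=int)
    return np.arange(n_layers)

print(checkpoint_locations(60, 0.5))  # 30 roughly evenly spaced indices: 0, 2, 4, ..., 57, 59
print(checkpoint_locations(60, 1))    # all 60 decoder layers
```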
diff --git a/finetune.py b/finetune.py
index 41dc376..6224f2e 100644
--- a/finetune.py
+++ b/finetune.py
@@ -127,10 +127,10 @@ if not ft_config.skip:
     print('Train completed.')
 
 
-if not ft_config.checkpoint:
-    # Save Model
-    model.save_pretrained(ft_config.lora_out_dir)
-else:
-    raise NotImplemented("TODO: Merge model + LoRA and save the whole checkpoint")
+# Save Model
+model.save_pretrained(ft_config.lora_out_dir)
 
-print('Model Saved.')
+if ft_config.checkpoint:
+    print("Warning: Merge model + LoRA and save the whole checkpoint not implemented yet.")
+
+print('Model Saved.')
\ No newline at end of file
diff --git a/install.bat b/install.bat
index bf5e817..bd79a55 100644
--- a/install.bat
+++ b/install.bat
@@ -15,8 +15,9 @@ REM replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu
 copy .\GPTQ-for-LLaMa\quant_cuda.cpp .\repository\GPTQ-for-LLaMa\quant_cuda.cpp /Y
 copy .\GPTQ-for-LLaMa\quant_cuda_kernel.cu .\repository\GPTQ-for-LLaMa\quant_cuda_kernel.cu /Y
 
-REM copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
+REM copy files into ./repository/GPTQ-for-LLaMa/
 copy .\GPTQ-for-LLaMa\autograd_4bit.py .\repository\GPTQ-for-LLaMa\autograd_4bit.py /Y
+copy .\GPTQ-for-LLaMa\gradient_checkpointing.py .\repository\GPTQ-for-LLaMa\gradient_checkpointing.py /Y
 
 REM install quant_cuda
 cd .\repository\GPTQ-for-LLaMa
diff --git a/install.sh b/install.sh
index be4c802..fd5a67c 100644
--- a/install.sh
+++ b/install.sh
@@ -19,8 +19,9 @@ cp ./peft/tuners/lora.py ./repository/peft/src/peft/tuners/lora.py
 cp ./GPTQ-for-LLaMa/quant_cuda.cpp ./repository/GPTQ-for-LLaMa/quant_cuda.cpp
 cp ./GPTQ-for-LLaMa/quant_cuda_kernel.cu ./repository/GPTQ-for-LLaMa/quant_cuda_kernel.cu
 
-# Copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py
+# Copy files into ./repository/GPTQ-for-LLaMa/
 cp ./GPTQ-for-LLaMa/autograd_4bit.py ./repository/GPTQ-for-LLaMa/autograd_4bit.py
+cp ./GPTQ-for-LLaMa/gradient_checkpointing.py ./repository/GPTQ-for-LLaMa/gradient_checkpointing.py
 
 # Install quant_cuda and cd into ./repository/GPTQ-for-LLaMa
 cd ./repository/GPTQ-for-LLaMa
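Finally, since finetune.py now always writes the adapter with model.save_pretrained, here is a hedged sketch of loading the result back for inference. PeftModel.from_pretrained is standard peft API, but the directory names are placeholders, and using a plain fp16 base model here is an assumption; the repo itself loads a 4-bit base through its patched code, and this diff does not guarantee the adapter also loads cleanly onto an fp16 base:

```python
# Sketch under assumptions: MODEL_DIR and LORA_OUT_DIR stand in for the base
# LLaMA checkpoint and for ft_config.lora_out_dir from finetune.py.
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

MODEL_DIR = './llama-7b-hf'      # placeholder base model path
LORA_OUT_DIR = './alpaca_lora'   # placeholder; whatever lora_out_dir was set to

base = LlamaForCausalLM.from_pretrained(MODEL_DIR, torch_dtype=torch.float16)
tokenizer = LlamaTokenizer.from_pretrained(MODEL_DIR)

# Attach the saved LoRA weights. Merging them into the base and saving a single
# checkpoint is still unimplemented, which is what the new warning in finetune.py flags.
model = PeftModel.from_pretrained(base, LORA_OUT_DIR)
model.eval()
```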