From 44978669cf824f5fcad948b808b19c339ac7fa61 Mon Sep 17 00:00:00 2001 From: John Smith Date: Thu, 23 Mar 2023 08:25:29 +0000 Subject: [PATCH 1/5] Add gradient checkpointing --- GPTQ-for-LLaMa/gradient_checkpointing.py | 61 ++++++++++++++++++++++++ README.md | 13 +++-- finetune.py | 8 ++++ 3 files changed, 79 insertions(+), 3 deletions(-) create mode 100644 GPTQ-for-LLaMa/gradient_checkpointing.py diff --git a/GPTQ-for-LLaMa/gradient_checkpointing.py b/GPTQ-for-LLaMa/gradient_checkpointing.py new file mode 100644 index 0000000..b75fd2c --- /dev/null +++ b/GPTQ-for-LLaMa/gradient_checkpointing.py @@ -0,0 +1,61 @@ +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from torch.utils.checkpoint import checkpoint +from torch.autograd import Variable +import torch +from torch import nn +import numpy as np + + +class NewForward: + + def __init__(self, layer): + self.layer = layer + self.apply_patch() + + def apply_patch(self): + self.layer.old_forward_for_cp = self.layer.forward + self.layer.forward = self.new_forward + + def new_forward(self, *args, **kwargs): + def func(*args): + return self.layer.old_forward_for_cp(*args, **kwargs) + output = checkpoint(func, *args) + return output + + +class VarWrapper: + + def __init__(self, model): + self.model = model + self.apply_patch() + print('Var Wrapper Patch Applied') + + def apply_patch(self): + self.model.old_forward_for_cp = self.model.forward + self.model.forward = self.new_forward + + def new_forward(self, *args, **kwargs): + out = self.model.old_forward_for_cp(*args, **kwargs) + out = Variable(out.data, requires_grad=True) + return out + + +def apply_gradient_checkpointing(model, checkpoint_ratio=1): + new_forwards = [] + modules = [] + for n, m in model.named_modules(): + if isinstance(m, LlamaDecoderLayer): + modules.append(m) + if checkpoint_ratio < 1 and checkpoint_ratio > 0: + checkpoint_locs = np.array((np.linspace(0, 1, int(len(modules) * checkpoint_ratio)) * (len(modules)-1)).round(), dtype=int) + else: + checkpoint_locs = np.arange(len(modules)) + for i in checkpoint_locs: + m = modules[i] + new_forwards.append(NewForward(m)) + print('Forward Patch Applied For Block {}'.format(i)) + for n, m in model.named_modules(): + if isinstance(m, torch.nn.Embedding): + wrapper = VarWrapper(m) + break + return new_forwards, wrapper diff --git a/README.md b/README.md index c926eac..9c3a46d 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,18 @@ # Alpaca Lora 4bit Made some adjust for the code in peft and gptq for llama, and make it possible for lora finetuning with a 4 bits base model. The same adjustment can be made for 2, 3 and 8 bits.
-~Still numerically unstable.~ Resolved. +* Install Manual by s4rduk4r: https://github.com/s4rduk4r/alpaca_lora_4bit_readme/blob/main/README.md + +# Update Logs +* Resolved numerically unstable issue
-Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed. +* Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
-Added install script for windows and linux. +* Added install script for windows and linux. +
+* Added Gradient Checkpointing. Now It can finetune 30b model 4bit on a single GPU with 24G VRAM. (finetune.py updated) +
+* Added install manual by s4rduk4r
# Requirements diff --git a/finetune.py b/finetune.py index 7047a5c..075ef68 100644 --- a/finetune.py +++ b/finetune.py @@ -42,6 +42,8 @@ TARGET_MODULES = [ "q_proj", "v_proj", ] +GRADIENT_CHECKPOINTING = False +GRADIENT_CHECKPOINTING_RATIO = 1 warmup_steps = 50 save_steps = 50 save_total_limit = 3 @@ -104,6 +106,12 @@ data = data.shuffle().map(lambda x: tokenize(x)) print('Train Data: {:.2f}%'.format(exceed_count / len(data) * 100), 'outliers') train_data = data +# Use gradient checkpointing +if GRADIENT_CHECKPOINTING: + print('Applying gradient checkpointing ...') + from gradient_checkpointing import apply_gradient_checkpointing + apply_gradient_checkpointing(model, checkpoint_ratio=GRADIENT_CHECKPOINTING_RATIO) + trainer = transformers.Trainer( model=model, train_dataset=train_data, From 619a177fbbdc00bee58628f869152f958cbe0457 Mon Sep 17 00:00:00 2001 From: John Smith Date: Thu, 23 Mar 2023 16:31:49 +0800 Subject: [PATCH 2/5] Update README.md --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9c3a46d..7f64534 100644 --- a/README.md +++ b/README.md @@ -6,12 +6,16 @@ Made some adjust for the code in peft and gptq for llama, and make it possible f # Update Logs * Resolved numerically unstable issue
+ * Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
+ * Added install script for windows and linux.
-* Added Gradient Checkpointing. Now It can finetune 30b model 4bit on a single GPU with 24G VRAM. (finetune.py updated)
+
+* Added Gradient Checkpointing. With it enabled, a 30B model can be finetuned in 4-bit on a single GPU with 24GB of VRAM (finetune.py updated). Note that checkpointing reduces training speed, so it is not needed if you have enough VRAM.<br>
+ * Added install manual by s4rduk4r
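The GRADIENT_CHECKPOINTING and GRADIENT_CHECKPOINTING_RATIO constants added to finetune.py in PATCH 1/5 are the only switches a user has to touch; everything else happens inside apply_gradient_checkpointing(). Below is a minimal usage sketch of that call path — the autograd_4bit import location and the model paths are assumptions for illustration, not taken from the patches:

```python
# Sketch of what finetune.py does when GRADIENT_CHECKPOINTING = True.
# The import location of load_llama_model_4bit_low_ram and the two paths
# below are assumptions; adjust them to your checkout.
from autograd_4bit import load_llama_model_4bit_low_ram
from gradient_checkpointing import apply_gradient_checkpointing

config_path = './llama-30b-4bit/'   # placeholder
model_path = './llama-30b-4bit.pt'  # placeholder

model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)

# Wraps the forward() of every LlamaDecoderLayer in torch.utils.checkpoint
# and re-wraps the embedding output so gradients can flow back into the
# checkpointed blocks. A checkpoint_ratio below 1 checkpoints only an
# evenly spaced subset of the decoder layers.
new_forwards, wrapper = apply_gradient_checkpointing(model, checkpoint_ratio=1)
```

As PATCH 3/5 below notes, LORA_DROPOUT should be 0 while gradient checkpointing is enabled; finetune.py now forces this automatically.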
From 60b227d0ba9ff78378afc8703028448be5bee2f3 Mon Sep 17 00:00:00 2001 From: John Smith Date: Thu, 23 Mar 2023 08:43:18 +0000 Subject: [PATCH 3/5] fix minor bug --- finetune.py | 6 +++++- install.bat | 3 ++- install.sh | 3 ++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/finetune.py b/finetune.py index 075ef68..72bfd6b 100644 --- a/finetune.py +++ b/finetune.py @@ -36,7 +36,7 @@ LEARNING_RATE = 2e-4 CUTOFF_LEN = 256 LORA_R = 8 LORA_ALPHA = 16 -LORA_DROPOUT = 0.05 +LORA_DROPOUT = 0.05 # should be 0 if gradient checkpointing is on VAL_SET_SIZE = 0 TARGET_MODULES = [ "q_proj", @@ -49,6 +49,10 @@ save_steps = 50 save_total_limit = 3 logging_steps = 10 +if LORA_DROPOUT > 0 and GRADIENT_CHECKPOINTING: + LORA_DROPOUT = 0 + print('Disable Dropout.') + # Load Basic Model model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path) diff --git a/install.bat b/install.bat index bf5e817..bd79a55 100644 --- a/install.bat +++ b/install.bat @@ -15,8 +15,9 @@ REM replace ./repository/GPTQ-for-LLaMa/quant_cuda.cpp and quant_cuda_kernel.cu copy .\GPTQ-for-LLaMa\quant_cuda.cpp .\repository\GPTQ-for-LLaMa\quant_cuda.cpp /Y copy .\GPTQ-for-LLaMa\quant_cuda_kernel.cu .\repository\GPTQ-for-LLaMa\quant_cuda_kernel.cu /Y -REM copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py +REM copy files into ./repository/GPTQ-for-LLaMa/ copy .\GPTQ-for-LLaMa\autograd_4bit.py .\repository\GPTQ-for-LLaMa\autograd_4bit.py /Y +copy .\GPTQ-for-LLaMa\gradient_checkpointing.py .\repository\GPTQ-for-LLaMa\gradient_checkpointing.py /Y REM install quant_cuda cd .\repository\GPTQ-for-LLaMa diff --git a/install.sh b/install.sh index be4c802..fd5a67c 100644 --- a/install.sh +++ b/install.sh @@ -19,8 +19,9 @@ cp ./peft/tuners/lora.py ./repository/peft/src/peft/tuners/lora.py cp ./GPTQ-for-LLaMa/quant_cuda.cpp ./repository/GPTQ-for-LLaMa/quant_cuda.cpp cp ./GPTQ-for-LLaMa/quant_cuda_kernel.cu ./repository/GPTQ-for-LLaMa/quant_cuda_kernel.cu -# Copy autograd_4bit.py into ./repository/GPTQ-for-LLaMa/autograd_4bit.py +# Copy files into ./repository/GPTQ-for-LLaMa/ cp ./GPTQ-for-LLaMa/autograd_4bit.py ./repository/GPTQ-for-LLaMa/autograd_4bit.py +cp ./GPTQ-for-LLaMa/gradient_checkpointing.py ./repository/GPTQ-for-LLaMa/gradient_checkpointing.py # Install quant_cuda and cd into ./repository/GPTQ-for-LLaMa cd ./repository/GPTQ-for-LLaMa From 58998acc9fd50b1e3e8068f4afb867467dc84727 Mon Sep 17 00:00:00 2001 From: Forkoz <59298527+Ph0rk0z@users.noreply.github.com> Date: Thu, 23 Mar 2023 07:33:57 -0500 Subject: [PATCH 4/5] Fix cuda kernel for Pascal & Cuda 6/6.1 When I left the other functions to use normal atomic add it seemed like a small speedup. 
4.79 it/s vs 5.23 it/s --- GPTQ-for-LLaMa/quant_cuda_kernel.cu | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/GPTQ-for-LLaMa/quant_cuda_kernel.cu b/GPTQ-for-LLaMa/quant_cuda_kernel.cu index 0077650..de0c0d6 100644 --- a/GPTQ-for-LLaMa/quant_cuda_kernel.cu +++ b/GPTQ-for-LLaMa/quant_cuda_kernel.cu @@ -4,8 +4,10 @@ #include #include +#ifdef __CUDA_ARCH__ +#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600 // adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh -__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { +__device__ __forceinline__ void atomicAddHalf(__half* address, c10::Half val) { unsigned int *address_as_ui = reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); unsigned int old = *address_as_ui; unsigned int assumed; @@ -22,6 +24,8 @@ __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) } while (assumed != old); } +#endif +#endif template __global__ void VecQuant2MatMulKernel( @@ -543,7 +547,14 @@ __global__ void VecQuant4MatMulHalfKernel( } __half* mul2 = (__half*)mul; +#ifdef __CUDA_ARCH__ +#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600 + atomicAddHalf(&mul2[b * width + w], res); +#else atomicAdd(&mul2[b * width + w], res); +#endif +#endif + } void vecquant4matmul_half_cuda( @@ -616,7 +627,13 @@ __global__ void VecQuant4TransposeMatMulHalfKernel( } __half* mul2 = (__half*)mul; - atomicAdd(&mul2[n_cols * height * 8 + n_rows], res); +#ifdef __CUDA_ARCH__ +#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600 + atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res); +#else + atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res); +#endif +#endif } void vecquant4transposematmul_half_cuda( From 4906961bf1fd22a5e44b27f926bf3b12775ef384 Mon Sep 17 00:00:00 2001 From: John Smith Date: Thu, 23 Mar 2023 23:37:39 +0800 Subject: [PATCH 5/5] fix bug --- GPTQ-for-LLaMa/quant_cuda_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GPTQ-for-LLaMa/quant_cuda_kernel.cu b/GPTQ-for-LLaMa/quant_cuda_kernel.cu index de0c0d6..ed9627a 100644 --- a/GPTQ-for-LLaMa/quant_cuda_kernel.cu +++ b/GPTQ-for-LLaMa/quant_cuda_kernel.cu @@ -631,7 +631,7 @@ __global__ void VecQuant4TransposeMatMulHalfKernel( #if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600 atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res); #else - atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res); + atomicAdd(&mul2[n_cols * height * 8 + n_rows], res); #endif #endif }
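Taken together, PATCH 4/5 and PATCH 5/5 leave both half-precision kernels with the same arch-gated dispatch: the CAS-based atomicAddHalf emulation is compiled only for compute capability 6.1/6.2, where CUDA has no native __half atomicAdd, and everything else falls through to the built-in overload that exists from compute capability 7.0 on. The sketch below restates that dispatch as one self-contained helper rather than the #if block the patches inline at each call site; the helper name is illustrative only:

```cuda
// Sketch of the arch-gated half-precision atomic add the two patches converge
// on. The real kernel file works with c10::Half and inlines the #if block at
// each call site; this standalone helper is only for illustration.
#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 600 && __CUDA_ARCH__ < 700
// Pascal (sm_61/sm_62): emulate the half atomicAdd with a 32-bit
// compare-and-swap on the word that contains the 16-bit value.
__device__ __forceinline__ void atomic_add_half(__half* address, __half val) {
  unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(
      reinterpret_cast<char*>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    unsigned short hsum =
        reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum = __half_as_ushort(__hadd(__ushort_as_half(hsum), val));
    old = reinterpret_cast<size_t>(address) & 2 ? (old & 0xffff) | (hsum << 16)
                                                : (old & 0xffff0000) | hsum;
    old = atomicCAS(address_as_ui, assumed, old);
    // Integer comparison avoids looping forever when the stored value is NaN.
  } while (assumed != old);
}
#else
// sm_70 and newer expose a hardware __half overload of atomicAdd.
__device__ __forceinline__ void atomic_add_half(__half* address, __half val) {
  atomicAdd(address, val);
}
#endif

// Call sites then mirror the ones in the patch, e.g.:
//   atomic_add_half(&mul2[b * width + w], res);
```

Keeping everything except the 6.1/6.2 path on the plain atomicAdd means newer GPUs never pay for the CAS loop.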