diff --git a/GPTQ-for-LLaMa/gradient_checkpointing.py b/GPTQ-for-LLaMa/gradient_checkpointing.py
new file mode 100644
index 0000000..b75fd2c
--- /dev/null
+++ b/GPTQ-for-LLaMa/gradient_checkpointing.py
@@ -0,0 +1,68 @@
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from torch.utils.checkpoint import checkpoint
+from torch.autograd import Variable
+import torch
+from torch import nn
+import numpy as np
+
+
+# Patches a module's forward so it runs under torch.utils.checkpoint:
+# activations are recomputed during backward instead of being stored.
+class NewForward:
+
+    def __init__(self, layer):
+        self.layer = layer
+        self.apply_patch()
+
+    def apply_patch(self):
+        self.layer.old_forward_for_cp = self.layer.forward
+        self.layer.forward = self.new_forward
+
+    def new_forward(self, *args, **kwargs):
+        def func(*args):
+            return self.layer.old_forward_for_cp(*args, **kwargs)
+        output = checkpoint(func, *args)
+        return output
+
+
+# Patches a module's forward so its output requires grad. Applied to the
+# embedding layer so the checkpointed blocks get an input that requires
+# grad; with frozen base weights, checkpoint would otherwise skip their backward.
+class VarWrapper:
+
+    def __init__(self, model):
+        self.model = model
+        self.apply_patch()
+        print('Var Wrapper Patch Applied')
+
+    def apply_patch(self):
+        self.model.old_forward_for_cp = self.model.forward
+        self.model.forward = self.new_forward
+
+    def new_forward(self, *args, **kwargs):
+        out = self.model.old_forward_for_cp(*args, **kwargs)
+        out = Variable(out.data, requires_grad=True)
+        return out
+
+
+def apply_gradient_checkpointing(model, checkpoint_ratio=1):
+    new_forwards = []
+    modules = []
+    for n, m in model.named_modules():
+        if isinstance(m, LlamaDecoderLayer):
+            modules.append(m)
+    # Checkpoint an evenly spaced subset of the decoder layers when
+    # 0 < checkpoint_ratio < 1; otherwise checkpoint every layer.
+    if checkpoint_ratio < 1 and checkpoint_ratio > 0:
+        checkpoint_locs = np.array((np.linspace(0, 1, int(len(modules) * checkpoint_ratio)) * (len(modules)-1)).round(), dtype=int)
+    else:
+        checkpoint_locs = np.arange(len(modules))
+    for i in checkpoint_locs:
+        m = modules[i]
+        new_forwards.append(NewForward(m))
+        print('Forward Patch Applied For Block {}'.format(i))
+    for n, m in model.named_modules():
+        if isinstance(m, torch.nn.Embedding):
+            wrapper = VarWrapper(m)
+            break
+    return new_forwards, wrapper
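The checkpoint_ratio argument above selects an evenly spaced subset of decoder layers to wrap. A minimal sketch of what that selection evaluates to, assuming a hypothetical 60-layer model at a ratio of 0.5 (both values are illustrative, not taken from this diff):

```python
import numpy as np

# Reproduce the checkpoint_locs formula from apply_gradient_checkpointing
# for an assumed 60-layer model at checkpoint_ratio=0.5.
num_layers = 60
checkpoint_ratio = 0.5
checkpoint_locs = np.array(
    (np.linspace(0, 1, int(num_layers * checkpoint_ratio)) * (num_layers - 1)).round(),
    dtype=int,
)
print(len(checkpoint_locs))  # 30 layers get checkpointed
print(checkpoint_locs[:5])   # [0 2 4 6 8], roughly evenly spaced indices
```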
diff --git a/README.md b/README.md
index c926eac..9c3a46d 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,18 @@
# Alpaca Lora 4bit
Made some adjustments to the code in peft and GPTQ-for-LLaMa so that LoRA finetuning is possible with a 4-bit base model. The same adjustment can be made for 2, 3 and 8 bits.
-~Still numerically unstable.~ Resolved.
+* Install Manual by s4rduk4r: https://github.com/s4rduk4r/alpaca_lora_4bit_readme/blob/main/README.md
+
+# Update Logs
+* Resolved the numerical instability issue.
-Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
+* Reconstructing the fp16 matrix from 4-bit data and calling torch.matmul largely increased the inference speed.
-Added install script for windows and linux.
+* Added install scripts for Windows and Linux.
+
+* Added gradient checkpointing. A 30B model can now be finetuned in 4-bit on a single GPU with 24GB of VRAM. (finetune.py updated)
+
+* Added install manual by s4rduk4r
# Requirements
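The gradient-checkpointing entry in the update log is controlled from finetune.py by the two new module-level flags added in the diff below. A minimal sketch of how they would be set (the values shown are illustrative):

```python
# In finetune.py: patch the decoder layers before building the Trainer.
GRADIENT_CHECKPOINTING = True
# 1 checkpoints every LlamaDecoderLayer; a value between 0 and 1 checkpoints
# only an evenly spaced subset, saving less memory but recomputing less.
GRADIENT_CHECKPOINTING_RATIO = 1
```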
diff --git a/finetune.py b/finetune.py
index 7047a5c..075ef68 100644
--- a/finetune.py
+++ b/finetune.py
@@ -42,6 +42,8 @@ TARGET_MODULES = [
"q_proj",
"v_proj",
]
+GRADIENT_CHECKPOINTING = False
+GRADIENT_CHECKPOINTING_RATIO = 1
warmup_steps = 50
save_steps = 50
save_total_limit = 3
@@ -104,6 +106,12 @@ data = data.shuffle().map(lambda x: tokenize(x))
print('Train Data: {:.2f}%'.format(exceed_count / len(data) * 100), 'outliers')
train_data = data
+# Use gradient checkpointing
+if GRADIENT_CHECKPOINTING:
+    print('Applying gradient checkpointing ...')
+    from gradient_checkpointing import apply_gradient_checkpointing
+    apply_gradient_checkpointing(model, checkpoint_ratio=GRADIENT_CHECKPOINTING_RATIO)
+
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
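Outside finetune.py, the same patch can be applied to an already-loaded model. A minimal sketch, assuming a LLaMA model object named `model` has already been loaded (the loading step itself is not part of this diff):

```python
from gradient_checkpointing import apply_gradient_checkpointing

# Wraps every LlamaDecoderLayer's forward in torch.utils.checkpoint and
# patches the embedding so its output requires grad, which keeps the
# checkpointed blocks differentiable even with frozen 4-bit base weights.
new_forwards, wrapper = apply_gradient_checkpointing(model, checkpoint_ratio=1)
```

The returned handles only record which modules were patched; the patches themselves live on the modules, which is why finetune.py can ignore the return value.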