Add gradient checkpointing

This commit is contained in:
John Smith 2023-03-23 08:25:29 +00:00
parent eb8ce878d4
commit 44978669cf
3 changed files with 79 additions and 3 deletions

View File

@ -0,0 +1,61 @@
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.utils.checkpoint import checkpoint
from torch.autograd import Variable
import torch
from torch import nn
import numpy as np
class NewForward:
def __init__(self, layer):
self.layer = layer
self.apply_patch()
def apply_patch(self):
self.layer.old_forward_for_cp = self.layer.forward
self.layer.forward = self.new_forward
def new_forward(self, *args, **kwargs):
def func(*args):
return self.layer.old_forward_for_cp(*args, **kwargs)
output = checkpoint(func, *args)
return output
class VarWrapper:
def __init__(self, model):
self.model = model
self.apply_patch()
print('Var Wrapper Patch Applied')
def apply_patch(self):
self.model.old_forward_for_cp = self.model.forward
self.model.forward = self.new_forward
def new_forward(self, *args, **kwargs):
out = self.model.old_forward_for_cp(*args, **kwargs)
out = Variable(out.data, requires_grad=True)
return out
def apply_gradient_checkpointing(model, checkpoint_ratio=1):
new_forwards = []
modules = []
for n, m in model.named_modules():
if isinstance(m, LlamaDecoderLayer):
modules.append(m)
if checkpoint_ratio < 1 and checkpoint_ratio > 0:
checkpoint_locs = np.array((np.linspace(0, 1, int(len(modules) * checkpoint_ratio)) * (len(modules)-1)).round(), dtype=int)
else:
checkpoint_locs = np.arange(len(modules))
for i in checkpoint_locs:
m = modules[i]
new_forwards.append(NewForward(m))
print('Forward Patch Applied For Block {}'.format(i))
for n, m in model.named_modules():
if isinstance(m, torch.nn.Embedding):
wrapper = VarWrapper(m)
break
return new_forwards, wrapper

View File

@ -1,11 +1,18 @@
# Alpaca Lora 4bit
Made some adjust for the code in peft and gptq for llama, and make it possible for lora finetuning with a 4 bits base model. The same adjustment can be made for 2, 3 and 8 bits.
<br>
~Still numerically unstable.~ Resolved.
* Install Manual by s4rduk4r: https://github.com/s4rduk4r/alpaca_lora_4bit_readme/blob/main/README.md
# Update Logs
* Resolved numerically unstable issue
<br>
Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
* Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
<br>
Added install script for windows and linux.
* Added install script for windows and linux.
<br>
* Added Gradient Checkpointing. Now It can finetune 30b model 4bit on a single GPU with 24G VRAM. (finetune.py updated)
<br>
* Added install manual by s4rduk4r
<br>
# Requirements

View File

@ -42,6 +42,8 @@ TARGET_MODULES = [
"q_proj",
"v_proj",
]
GRADIENT_CHECKPOINTING = False
GRADIENT_CHECKPOINTING_RATIO = 1
warmup_steps = 50
save_steps = 50
save_total_limit = 3
@ -104,6 +106,12 @@ data = data.shuffle().map(lambda x: tokenize(x))
print('Train Data: {:.2f}%'.format(exceed_count / len(data) * 100), 'outliers')
train_data = data
# Use gradient checkpointing
if GRADIENT_CHECKPOINTING:
print('Applying gradient checkpointing ...')
from gradient_checkpointing import apply_gradient_checkpointing
apply_gradient_checkpointing(model, checkpoint_ratio=GRADIENT_CHECKPOINTING_RATIO)
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,