Add gradient checkpointing
This commit is contained in:
parent eb8ce878d4
commit 44978669cf
gradient_checkpointing.py
@@ -0,0 +1,61 @@
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.utils.checkpoint import checkpoint
from torch.autograd import Variable
import torch
from torch import nn
import numpy as np


class NewForward:
    # Wraps a single decoder layer so its forward pass runs under
    # torch.utils.checkpoint: activations are recomputed during backward
    # instead of being kept in memory.

    def __init__(self, layer):
        self.layer = layer
        self.apply_patch()

    def apply_patch(self):
        self.layer.old_forward_for_cp = self.layer.forward
        self.layer.forward = self.new_forward

    def new_forward(self, *args, **kwargs):
        def func(*args):
            return self.layer.old_forward_for_cp(*args, **kwargs)
        output = checkpoint(func, *args)
        return output


class VarWrapper:
    # Patches a module (here the embedding) so its output requires grad.
    # torch.utils.checkpoint needs at least one input with requires_grad=True
    # for gradients to flow back through the checkpointed blocks when the
    # base model weights are frozen.

    def __init__(self, model):
        self.model = model
        self.apply_patch()
        print('Var Wrapper Patch Applied')

    def apply_patch(self):
        self.model.old_forward_for_cp = self.model.forward
        self.model.forward = self.new_forward

    def new_forward(self, *args, **kwargs):
        out = self.model.old_forward_for_cp(*args, **kwargs)
        out = Variable(out.data, requires_grad=True)
        return out


def apply_gradient_checkpointing(model, checkpoint_ratio=1):
    new_forwards = []
    modules = []
    # Collect all LLaMA decoder blocks.
    for n, m in model.named_modules():
        if isinstance(m, LlamaDecoderLayer):
            modules.append(m)
    # Checkpoint every block, or an evenly spaced subset when 0 < ratio < 1.
    if checkpoint_ratio < 1 and checkpoint_ratio > 0:
        checkpoint_locs = np.array((np.linspace(0, 1, int(len(modules) * checkpoint_ratio)) * (len(modules) - 1)).round(), dtype=int)
    else:
        checkpoint_locs = np.arange(len(modules))
    for i in checkpoint_locs:
        m = modules[i]
        new_forwards.append(NewForward(m))
        print('Forward Patch Applied For Block {}'.format(i))
    # Patch the first embedding module found so gradients reach the
    # checkpointed blocks, then return the patch objects to keep them alive.
    for n, m in model.named_modules():
        if isinstance(m, torch.nn.Embedding):
            wrapper = VarWrapper(m)
            break
    return new_forwards, wrapper
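One point worth spelling out is how checkpoint_ratio is interpreted: for a value strictly between 0 and 1, apply_gradient_checkpointing patches only an evenly spaced subset of the decoder blocks rather than all of them, trading part of the memory saving for less recomputation. A small runnable sketch of that selection formula (the block count of 32 is an illustrative figure, roughly a LLaMA-7B-sized model, and is not taken from this commit):

```python
import numpy as np

# Illustration of the index-selection formula used above, with example values:
n_blocks = 32            # e.g. number of LlamaDecoderLayer modules in the model
checkpoint_ratio = 0.5   # checkpoint roughly half of the blocks

locs = np.array(
    (np.linspace(0, 1, int(n_blocks * checkpoint_ratio)) * (n_blocks - 1)).round(),
    dtype=int,
)
print(locs)  # 16 roughly evenly spaced block indices between 0 and 31
```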
README.md (13 lines changed)
@@ -1,11 +1,18 @@
 # Alpaca Lora 4bit
 Made some adjust for the code in peft and gptq for llama, and make it possible for lora finetuning with a 4 bits base model. The same adjustment can be made for 2, 3 and 8 bits.
 <br>
-~Still numerically unstable.~ Resolved.
+* Install Manual by s4rduk4r: https://github.com/s4rduk4r/alpaca_lora_4bit_readme/blob/main/README.md
+
+# Update Logs
+* Resolved numerically unstable issue
 <br>
-Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
+* Reconstruct fp16 matrix from 4bit data and call torch.matmul largely increased the inference speed.
 <br>
-Added install script for windows and linux.
+* Added install script for windows and linux.
+<br>
+* Added Gradient Checkpointing. Now It can finetune 30b model 4bit on a single GPU with 24G VRAM. (finetune.py updated)
+<br>
+* Added install manual by s4rduk4r
 <br>

 # Requirements
finetune.py
@@ -42,6 +42,8 @@ TARGET_MODULES = [
     "q_proj",
     "v_proj",
 ]
+GRADIENT_CHECKPOINTING = False
+GRADIENT_CHECKPOINTING_RATIO = 1
 warmup_steps = 50
 save_steps = 50
 save_total_limit = 3
@@ -104,6 +106,12 @@ data = data.shuffle().map(lambda x: tokenize(x))
 print('Train Data: {:.2f}%'.format(exceed_count / len(data) * 100), 'outliers')
 train_data = data

+# Use gradient checkpointing
+if GRADIENT_CHECKPOINTING:
+    print('Applying gradient checkpointing ...')
+    from gradient_checkpointing import apply_gradient_checkpointing
+    apply_gradient_checkpointing(model, checkpoint_ratio=GRADIENT_CHECKPOINTING_RATIO)
+
 trainer = transformers.Trainer(
     model=model,
     train_dataset=train_data,
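For completeness, enabling the feature only requires flipping the two new flags near the top of finetune.py; the values below are an example, not what the commit ships (its default leaves checkpointing disabled):

```python
# Example settings in finetune.py (the commit's defaults leave this disabled):
GRADIENT_CHECKPOINTING = True       # run the patching block added above
GRADIENT_CHECKPOINTING_RATIO = 1    # 1 = checkpoint every decoder block; e.g. 0.5 = half of them
```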