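"""Monkey-patched gradient (activation) checkpointing for Llama models.

Patches the forward pass of LlamaDecoderLayer blocks to run under
torch.utils.checkpoint, and wraps the embedding module so its output requires
grad, which the checkpointed blocks need for backpropagation.
"""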
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.utils.checkpoint import checkpoint
from torch.autograd import Variable
import torch
from torch import nn
import numpy as np


class NewForward:
    """Monkey-patches a single decoder layer so its forward pass runs under
    torch.utils.checkpoint, recomputing activations during backward instead
    of storing them."""

    def __init__(self, layer):
        self.layer = layer
        self.apply_patch()

    def apply_patch(self):
        # Keep a handle to the original forward and route calls through the
        # checkpointed wrapper instead.
        self.layer.old_forward_for_cp = self.layer.forward
        self.layer.forward = self.new_forward

    def new_forward(self, *args, **kwargs):
        # checkpoint() only passes positional args through, so close over kwargs.
        def func(*args):
            return self.layer.old_forward_for_cp(*args, **kwargs)

        output = checkpoint(func, *args)
        return output


class VarWrapper:
    """Monkey-patches the embedding module so its output requires grad.
    The checkpointed blocks need at least one input with requires_grad=True
    for gradients to flow, even when the embedding weights themselves are
    frozen."""

    def __init__(self, model):
        self.model = model
        self.apply_patch()
        print('Var Wrapper Patch Applied')

    def apply_patch(self):
        self.model.old_forward_for_cp = self.model.forward
        self.model.forward = self.new_forward

    def new_forward(self, *args, **kwargs):
        out = self.model.old_forward_for_cp(*args, **kwargs)
        # Detach the embedding output from its weights and re-enable grad on it.
        out = Variable(out.data, requires_grad=True)
        return out


def apply_gradient_checkpointing(model, checkpoint_ratio=1):
    """Applies activation checkpointing to a fraction (checkpoint_ratio) of the
    model's LlamaDecoderLayer blocks and patches the embedding so gradients can
    flow through the checkpointed blocks."""
    new_forwards = []
    modules = []
    for _, m in model.named_modules():
        if isinstance(m, LlamaDecoderLayer):
            modules.append(m)
    if 0 < checkpoint_ratio < 1:
        # Spread the checkpointed blocks evenly across the decoder stack.
        checkpoint_locs = np.array(
            (np.linspace(0, 1, int(len(modules) * checkpoint_ratio)) * (len(modules) - 1)).round(),
            dtype=int,
        )
    else:
        checkpoint_locs = np.arange(len(modules))
    for i in checkpoint_locs:
        m = modules[i]
        new_forwards.append(NewForward(m))
        print('Forward Patch Applied For Block {}'.format(i))
    wrapper = None
    for _, m in model.named_modules():
        if isinstance(m, torch.nn.Embedding):
            wrapper = VarWrapper(m)
            break
    return new_forwards, wrapper
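

# --- Usage sketch -----------------------------------------------------------
# A minimal, hedged example of how this module might be used. The model id is
# a placeholder, and the checkpoint_ratio and the dummy forward/backward pass
# are illustrative only.
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = 'path-or-hub-id-of-a-llama-checkpoint'  # placeholder
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Checkpoint roughly half of the decoder blocks to trade compute for memory.
    new_forwards, wrapper = apply_gradient_checkpointing(model, checkpoint_ratio=0.5)

    # One forward/backward pass to confirm gradients flow through the
    # checkpointed blocks.
    batch = tokenizer('Hello, world!', return_tensors='pt')
    out = model(**batch, labels=batch['input_ids'])
    out.loss.backward()
    print('loss:', out.loss.item())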