add amp_wrapper for autocast support.

John Smith 2023-03-30 19:57:19 +08:00
parent b3c91a5af5
commit 878eada8dd
2 changed files with 33 additions and 1 deletions

amp_wrapper.py Normal file

@@ -0,0 +1,26 @@
import torch


class AMPWrapper:
    def __init__(self, model, options=None):
        self.model = model
        self.options = options
        if self.options is None:
            # Default: mixed-precision autocast on CUDA.
            self.options = {'enabled': True, 'device_type': 'cuda'}

    def autocast_forward(self, *args, **kwargs):
        # Run the original forward pass inside torch.amp.autocast.
        with torch.amp.autocast(**self.options):
            return self.model.non_autocast_forward(*args, **kwargs)

    def autocast_generate(self, *args, **kwargs):
        # Run the original generate call inside torch.amp.autocast.
        with torch.amp.autocast(**self.options):
            return self.model.non_autocast_generate(*args, **kwargs)

    def apply_forward(self):
        # Save the original forward, then patch model.forward so every
        # call goes through the autocast context.
        self.model.non_autocast_forward = self.model.forward
        self.model.forward = self.autocast_forward

    def apply_generate(self):
        # Save the original generate, then patch model.generate so every
        # call goes through the autocast context.
        self.model.non_autocast_generate = self.model.generate
        self.model.generate = self.autocast_generate
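
For reference, a minimal usage sketch of the wrapper (not part of this commit; the toy Linear model and tensor shapes are assumptions). It exercises the apply_forward path, which the example script below does not use:

import torch
from amp_wrapper import AMPWrapper

model = torch.nn.Linear(16, 16).cuda()   # hypothetical stand-in for a real model
wrapper = AMPWrapper(model)              # defaults to {'enabled': True, 'device_type': 'cuda'}
wrapper.apply_forward()                  # model.forward now runs under torch.amp.autocast

x = torch.randn(1, 16, device='cuda')
y = model(x)
print(y.dtype)                           # typically torch.float16 under CUDA autocast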


@@ -5,9 +5,10 @@ import torch
from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
config_path = './llama-13b-4bit/'
model_path = './llama-13b-4bit.pt'
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1)
print('Fitting 4bit scales and zeros to half')
model.half()
for n, m in model.named_modules():
    if isinstance(m, Autograd4bitQuantLinear):
        if m.groupsize == -1:
@@ -15,6 +16,11 @@ for n, m in model.named_modules():
        m.scales = m.scales.half()
        m.bias = m.bias.half()
print('Apply AMP Wrapper ...')
from amp_wrapper import AMPWrapper
wrapper = AMPWrapper(model)
wrapper.apply_generate()
prompt = '''I think the meaning of life is'''
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
batch = {k: v.cuda() for k, v in batch.items()}
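
The diff ends before the actual generation call. A plausible continuation (sampling arguments here are illustrative assumptions, not part of the shown diff) would be:

# With wrapper.apply_generate() applied above, this call runs inside
# torch.amp.autocast automatically.
with torch.no_grad():
    generated = model.generate(
        **batch,
        do_sample=True,
        max_new_tokens=50,   # illustrative value
        temperature=0.7,     # illustrative value
    )
print(tokenizer.decode(generated[0]))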