add amp_wrapper for autocast support.
This commit is contained in:
parent
b3c91a5af5
commit
878eada8dd
|
|
@ -0,0 +1,26 @@
|
|||
import torch
|
||||
|
||||
|
||||
class AMPWrapper:
|
||||
|
||||
def __init__(self, model, options=None):
|
||||
self.model = model
|
||||
self.options = options
|
||||
if self.options is None:
|
||||
self.options = {'enabled': True, 'device_type': 'cuda'}
|
||||
|
||||
def autocast_forward(self, *args, **kwargs):
|
||||
with torch.amp.autocast(**self.options):
|
||||
return self.model.non_autocast_forward(*args, **kwargs)
|
||||
|
||||
def autocast_generate(self, *args, **kwargs):
|
||||
with torch.amp.autocast(**self.options):
|
||||
return self.model.non_autocast_generate(*args, **kwargs)
|
||||
|
||||
def apply_forward(self):
|
||||
self.model.non_autocast_forward = self.model.forward
|
||||
self.model.forward = self.autocast_forward
|
||||
|
||||
def apply_generate(self):
|
||||
self.model.non_autocast_generate = self.model.generate
|
||||
self.model.generate = self.autocast_generate
|
||||
|
|
@ -5,9 +5,10 @@ import torch
|
|||
from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
|
||||
config_path = './llama-13b-4bit/'
|
||||
model_path = './llama-13b-4bit.pt'
|
||||
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
|
||||
model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1)
|
||||
|
||||
print('Fitting 4bit scales and zeros to half')
|
||||
model.half()
|
||||
for n, m in model.named_modules():
|
||||
if isinstance(m, Autograd4bitQuantLinear):
|
||||
if m.groupsize == -1:
|
||||
|
|
@ -15,6 +16,11 @@ for n, m in model.named_modules():
|
|||
m.scales = m.scales.half()
|
||||
m.bias = m.bias.half()
|
||||
|
||||
print('Apply AMP Wrapper ...')
|
||||
from amp_wrapper import AMPWrapper
|
||||
wrapper = AMPWrapper(model)
|
||||
wrapper.apply_generate()
|
||||
|
||||
prompt = '''I think the meaning of life is'''
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
|
||||
batch = {k: v.cuda() for k, v in batch.items()}
|
||||
|
|
|
|||
Loading…
Reference in New Issue