add amp_wrapper for autocast support.
This commit is contained in:
parent b3c91a5af5
commit 878eada8dd
amp_wrapper.py
@@ -0,0 +1,26 @@
+import torch
+
+
+class AMPWrapper:
+
+    def __init__(self, model, options=None):
+        self.model = model
+        self.options = options
+        if self.options is None:
+            self.options = {'enabled': True, 'device_type': 'cuda'}
+
+    def autocast_forward(self, *args, **kwargs):
+        with torch.amp.autocast(**self.options):
+            return self.model.non_autocast_forward(*args, **kwargs)
+
+    def autocast_generate(self, *args, **kwargs):
+        with torch.amp.autocast(**self.options):
+            return self.model.non_autocast_generate(*args, **kwargs)
+
+    def apply_forward(self):
+        self.model.non_autocast_forward = self.model.forward
+        self.model.forward = self.autocast_forward
+
+    def apply_generate(self):
+        self.model.non_autocast_generate = self.model.generate
+        self.model.generate = self.autocast_generate
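
The options dict is unpacked straight into torch.amp.autocast, so a caller can override the autocast settings when constructing the wrapper. A minimal usage sketch, assuming an already-loaded model with .generate (the bfloat16 override below is an illustrative assumption, not part of this commit):

import torch
from amp_wrapper import AMPWrapper

# 'model' is a placeholder for any torch module with a .generate method,
# e.g. the 4-bit llama loaded in the script patched below.
# Default options are {'enabled': True, 'device_type': 'cuda'}.
# Hypothetical override: run generation under bfloat16 autocast instead.
wrapper = AMPWrapper(model, options={'enabled': True, 'device_type': 'cuda', 'dtype': torch.bfloat16})
wrapper.apply_generate()   # model.generate now runs inside torch.amp.autocast
# wrapper.apply_forward()  # optionally wrap model.forward the same way

apply_generate() stashes the original bound method as model.non_autocast_generate and swaps in autocast_generate, so the wrapped call enters the autocast context and then delegates to the original generate unchanged. The hunks below wire this wrapper into the example inference script.
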
@@ -5,9 +5,10 @@ import torch
 from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
 config_path = './llama-13b-4bit/'
 model_path = './llama-13b-4bit.pt'
-model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
+model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1)
 
 print('Fitting 4bit scales and zeros to half')
+model.half()
 for n, m in model.named_modules():
     if isinstance(m, Autograd4bitQuantLinear):
         if m.groupsize == -1:
@@ -15,6 +16,11 @@ for n, m in model.named_modules():
         m.scales = m.scales.half()
         m.bias = m.bias.half()
 
+print('Apply AMP Wrapper ...')
+from amp_wrapper import AMPWrapper
+wrapper = AMPWrapper(model)
+wrapper.apply_generate()
+
 prompt = '''I think the meaning of life is'''
 batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
 batch = {k: v.cuda() for k, v in batch.items()}
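
After the batch is moved to CUDA, generation goes through the wrapped model.generate, which now runs under autocast. A hedged sketch of that call, continuing from the hunk above (the sampling parameters are assumptions and are not shown in this diff):

with torch.no_grad():
    # model.generate is now AMPWrapper.autocast_generate, which enters
    # torch.amp.autocast and delegates to the original generate.
    generated = model.generate(input_ids=batch['input_ids'],
                               attention_mask=batch['attention_mask'],
                               do_sample=True,
                               max_new_tokens=64,
                               temperature=0.7,
                               top_p=0.95)
print(tokenizer.decode(generated[0]))

Because autocast_generate only wraps the original call in the autocast context, the generation call itself needs no other changes.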