diff --git a/amp_wrapper.py b/amp_wrapper.py
new file mode 100644
index 0000000..f7d01f5
--- /dev/null
+++ b/amp_wrapper.py
@@ -0,0 +1,26 @@
+import torch
+
+
+class AMPWrapper:
+
+    def __init__(self, model, options=None):
+        self.model = model
+        self.options = options
+        if self.options is None:
+            self.options = {'enabled': True, 'device_type': 'cuda'}
+
+    def autocast_forward(self, *args, **kwargs):
+        with torch.amp.autocast(**self.options):
+            return self.model.non_autocast_forward(*args, **kwargs)
+
+    def autocast_generate(self, *args, **kwargs):
+        with torch.amp.autocast(**self.options):
+            return self.model.non_autocast_generate(*args, **kwargs)
+
+    def apply_forward(self):
+        self.model.non_autocast_forward = self.model.forward
+        self.model.forward = self.autocast_forward
+
+    def apply_generate(self):
+        self.model.non_autocast_generate = self.model.generate
+        self.model.generate = self.autocast_generate
diff --git a/inference.py b/inference.py
index 84ade3d..c0f4599 100644
--- a/inference.py
+++ b/inference.py
@@ -5,9 +5,10 @@ import torch
 from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
 config_path = './llama-13b-4bit/'
 model_path = './llama-13b-4bit.pt'
-model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path)
+model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1)
 
 print('Fitting 4bit scales and zeros to half')
+model.half()
 for n, m in model.named_modules():
     if isinstance(m, Autograd4bitQuantLinear):
         if m.groupsize == -1:
@@ -15,6 +16,11 @@ for n, m in model.named_modules():
         m.scales = m.scales.half()
         m.bias = m.bias.half()
 
+print('Apply AMP Wrapper ...')
+from amp_wrapper import AMPWrapper
+wrapper = AMPWrapper(model)
+wrapper.apply_generate()
+
 prompt = '''I think the meaning of life is'''
 batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
 batch = {k: v.cuda() for k, v in batch.items()}
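
Note: AMPWrapper just swaps the model's forward/generate for versions that run inside torch.amp.autocast, keeping the originals reachable as non_autocast_forward/non_autocast_generate. The snippet below is an illustrative sketch, not part of the patch: it assumes the model, tokenizer, and batch set up in inference.py above, and the explicit dtype and the generate arguments (do_sample, max_new_tokens) are example choices only. The patch itself only calls apply_generate() with the default options.

    import torch
    from amp_wrapper import AMPWrapper

    # Run both forward() and generate() under CUDA autocast; dtype passed explicitly for illustration.
    wrapper = AMPWrapper(model, options={'enabled': True, 'device_type': 'cuda', 'dtype': torch.float16})
    wrapper.apply_forward()
    wrapper.apply_generate()

    # The un-wrapped methods remain available as model.non_autocast_forward / model.non_autocast_generate
    # if autocast ever needs to be bypassed.
    with torch.no_grad():
        out = model.generate(**batch, do_sample=True, max_new_tokens=64)
    print(tokenizer.decode(out[0]))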