From f20570343f7895c61d814bf53bcbb41bb0fb5f22 Mon Sep 17 00:00:00 2001 From: Andrey Glushenkov Date: Thu, 6 Apr 2023 02:29:36 +0300 Subject: [PATCH 01/19] GPTQv2 support GPTQv2 support. 1. Adds dependency on `triton` 2. Refactors autograd_4bit to include both GPTQv1 and GPTQv2 3. Introduces new environment variable GPTQ_VERSION to select autograd_4bit version 4. Fixes triton kernels 5. Matrix multiplications are in fp16 --- Finetune4bConfig.py | 3 +- README.md | 10 ++ autograd_4bit/__init__.py | 21 +++ autograd_4bit/autograd_4bit_v1.py | 208 ++++++++++++++++++++++++++++ autograd_4bit/autograd_4bit_v2.py | 221 ++++++++++++++++++++++++++++++ finetune.py | 1 + requirements.txt | 1 + triton_test.py | 154 +++++++++++++++++++++ 8 files changed, 618 insertions(+), 1 deletion(-) create mode 100644 autograd_4bit/__init__.py create mode 100644 autograd_4bit/autograd_4bit_v1.py create mode 100644 autograd_4bit/autograd_4bit_v2.py create mode 100644 triton_test.py diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py index a8a33bf..3880462 100644 --- a/Finetune4bConfig.py +++ b/Finetune4bConfig.py @@ -95,5 +95,6 @@ class Finetune4bConfig: f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\ f"{self.logging_steps=}\n" +\ f"{self.checkpoint=}\n{self.skip=}\n" +\ - f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}" + f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}\n" +\ + f"{self.groupsize=}\n" return s.replace("self.", "") diff --git a/README.md b/README.md index 2326cde..161bb26 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,20 @@ pip install -r requirements.txt ~The same finetune script from https://github.com/tloen/alpaca-lora can be used.~
After installation, this script can be used: +GPTQv1: ``` python finetune.py ``` +or +``` +GPTQ_VERSION=1 python finetune.py +``` + +GPTQv2: +``` +GPTQ_VERSION=2 python finetune.py +``` # Inference diff --git a/autograd_4bit/__init__.py b/autograd_4bit/__init__.py new file mode 100644 index 0000000..bee84b3 --- /dev/null +++ b/autograd_4bit/__init__.py @@ -0,0 +1,21 @@ +import os +from colorama import init, Fore, Back, Style +init(autoreset=True) + +try: + GPTQ_VERSION = int(os.environ["GPTQ_VERSION"]) +except: + print(Style.BRIGHT + Fore.YELLOW + "GPTQ_VERSION environment not provided. Fallback to GPTQv1") + GPTQ_VERSION = 1 # Fallback version + +loader = None + + +if GPTQ_VERSION == 1: + from .autograd_4bit_v1 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram + print(Style.BRIGHT + Fore.GREEN + "GPTQv1 set") +elif GPTQ_VERSION == 2: + from .autograd_4bit_v2 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram + print(Style.BRIGHT + Fore.GREEN + "GPTQv2 set") +else: + raise ValueError("GPTQ_VERSION not set or invalid") \ No newline at end of file diff --git a/autograd_4bit/autograd_4bit_v1.py b/autograd_4bit/autograd_4bit_v1.py new file mode 100644 index 0000000..cb9d308 --- /dev/null +++ b/autograd_4bit/autograd_4bit_v1.py @@ -0,0 +1,208 @@ +import matmul_utils_4bit as mm4b +import torch +import torch.nn as nn +import time + + +class AutogradMatmul4bit(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, qweight, scales, zeros, groupsize=-1): + ctx.save_for_backward(qweight, scales, zeros) + ctx.groupsize = groupsize + if groupsize == -1: + output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros) + else: + output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize) + output = output.clone() + return output + + @staticmethod + def backward(ctx, grad_output): + qweight, scales, zeros = ctx.saved_tensors + groupsize = ctx.groupsize + if groupsize == -1: + grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True) + else: + grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True) + return grad, None, None, None, None + + +# Assumes layer is perfectly divisible into 256 * 256 blocks +class Autograd4bitQuantLinear(nn.Module): + + def __init__(self, in_features, out_features, groupsize=None): + super().__init__() + bits = 4 + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.register_buffer('zeros', torch.empty((out_features, 1))) + self.register_buffer('scales', torch.empty((out_features, 1))) + self.bias = nn.Parameter(torch.empty(out_features)) + self.register_buffer( + 'qweight', torch.empty((in_features // 256 * (bits * 8), out_features), dtype=torch.int) + ) + + + def forward(self, x): + if torch.is_grad_enabled(): + out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, + self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) + out.add_(self.bias) + else: + out = mm4b.matmul4bit(x, self.qweight, self.scales, + self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) + out.add_(self.bias) + return out + + +def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1): + if isinstance(module, Autograd4bitQuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + setattr( + module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize) + ) + for name1, child in module.named_children(): + make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize) + + +def model_to_half(model): + model.half() + for n, m in model.named_modules(): + if isinstance(m, Autograd4bitQuantLinear): + m.zeros = m.zeros.half() + m.scales = m.scales.half() + m.bias = m.bias.half() + print('Converted as Half.') + + +def model_to_float(model): + model.float() + for n, m in model.named_modules(): + if isinstance(m, Autograd4bitQuantLinear): + m.zeros = m.zeros.float() + m.scales = m.scales.float() + m.bias = m.bias.float() + print('Converted as Float.') + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers( + child, layers=layers, name=name + '.' + name1 if name != '' else name1 + )) + return res + + +def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048): + import accelerate + from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer + + print("Loading Model ...") + t0 = time.time() + + with accelerate.init_empty_weights(): + config = LlamaConfig.from_pretrained(config_path) + model = LlamaForCausalLM(config) + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) + model = accelerate.load_checkpoint_and_dispatch( + model=model, + checkpoint=model_path, + device_map=device_map, + no_split_module_classes=["LlamaDecoderLayer"] + ) + + model.seqlen = seqlen + + if half: + model_to_half(model) + + tokenizer = LlamaTokenizer.from_pretrained(config_path) + tokenizer.truncation_side = 'left' + + print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") + + return model, tokenizer + +def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None): + import accelerate + from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer + + if max_memory is None: + max_memory = {0: '24Gib', 'cpu': '48Gib'} + + print("Loading Model ...") + t0 = time.time() + + with accelerate.init_empty_weights(): + config = LlamaConfig.from_pretrained(config_path) + model = LlamaForCausalLM(config) + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) + accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'}) + + # rotary_emb fix + for n, m in model.named_modules(): + if 'rotary_emb' in n: + cos_cached = m.cos_cached.clone().cpu() + sin_cached = m.sin_cached.clone().cpu() + break + + if lora_path is not None: + from peft import PeftModel + from peft.tuners.lora import Linear4bitLt + model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32) + print('{} Lora Applied.'.format(lora_path)) + + model.seqlen = seqlen + + print('Apply half ...') + for n, m in model.named_modules(): + if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)): + m.zeros = m.zeros.half() + m.scales = m.scales.half() + m.bias = m.bias.half() + + print('Dispatching model ...') + device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"]) + model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0) + torch.cuda.empty_cache() + print('Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024)) + + # rotary_emb fix + for n, m in model.named_modules(): + if 'rotary_emb' in n: + if getattr(m, '_hf_hook', None): + if isinstance(m._hf_hook, accelerate.hooks.SequentialHook): + hooks = m._hf_hook.hooks + else: + hooks = [m._hf_hook] + for hook in hooks: + if hook.offload: + if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys(): + hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu() + hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu() + + tokenizer = LlamaTokenizer.from_pretrained(config_path) + tokenizer.truncation_side = 'left' + + print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") + + return model, tokenizer diff --git a/autograd_4bit/autograd_4bit_v2.py b/autograd_4bit/autograd_4bit_v2.py new file mode 100644 index 0000000..20c253d --- /dev/null +++ b/autograd_4bit/autograd_4bit_v2.py @@ -0,0 +1,221 @@ +from colorama import init, Fore, Back, Style +import torch +import torch.nn as nn +import time +import math +import triton +from triton_utils import matmul_248_kernel, trans_matmul_248_kernel + + +class AutogradMatmul4bit(torch.autograd.Function): + @staticmethod + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) + matmul_248_kernel[grid](input, qweight, output, + scales, qzeros, g_idx, + input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, + input.stride(0), input.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), qzeros.stride(0)) + + ctx.save_for_backward(qweight, scales, qzeros, g_idx) + ctx.input_shape, ctx.bits,ctx.maxq = input.shape,bits, maxq + return output + + @staticmethod + def backward(ctx, grad_output): + input_shape, bits, maxq = ctx.input_shape, ctx.bits, ctx.maxq + qweight, scales, qzeros, g_idx = ctx.saved_tensors + grade_input = None + + if ctx.needs_input_grad[0]: + grade_input = torch.empty((input_shape[0], input_shape[1]), device='cuda', dtype=torch.float32) + grid = lambda META: (triton.cdiv(input_shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(input_shape[1], META['BLOCK_SIZE_K']),) + trans_matmul_248_kernel[grid](grad_output, qweight, grade_input, + scales, qzeros, g_idx, + input_shape[0], qweight.shape[1], input_shape[1], bits, maxq, + grad_output.stride(0), grad_output.stride(1), + qweight.stride(0), qweight.stride(1), + grade_input.stride(0), grade_input.stride(1), + scales.stride(0), qzeros.stride(0)) + return grade_input, None, None, None, None, None, None + + +class Autograd4bitQuantLinear(nn.Module): + + def __init__(self, in_features, out_features, groupsize, bias=True): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.bits = 4 # Hardcoded 4-bits quantizations + self.maxq = 2 ** self.bits - 1 + self.groupsize = groupsize if groupsize != -1 else in_features + + self.register_buffer('qweight', torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(in_features / self.groupsize), out_features // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', torch.zeros((math.ceil(in_features / self.groupsize), out_features), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(in_features)], dtype = torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros(out_features,dtype=torch.float16)) + else: + self.bias = None + + def forward(self, x): + out_shape = x.shape[:-1] + (self.out_features, ) + out = AutogradMatmul4bit.apply(x.reshape(-1,x.shape[-1]), self.qweight, self.scales, + self.qzeros, self.g_idx, self.bits, self.maxq) + out = out + self.bias if self.bias is not None else out + return out.reshape(out_shape) + + +def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1): + if isinstance(module, Autograd4bitQuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + setattr( + module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize) + ) + for name1, child in module.named_children(): + make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize) + + +def model_to_half(model): + model.half() + for n, m in model.named_modules(): + if isinstance(m, Autograd4bitQuantLinear): + m.qzeros = m.qzeros.half() + m.scales = m.scales.half() + m.bias = m.bias.half() + print(Style.BRIGHT + Fore.YELLOW + 'Converted as Half.') + + +def model_to_float(model): + model.float() + for n, m in model.named_modules(): + if isinstance(m, Autograd4bitQuantLinear): + m.qzeros = m.qzeros.float() + m.scales = m.scales.float() + m.bias = m.bias.float() + print(Style.BRIGHT + Fore.YELLOW + 'Converted as Float.') + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers( + child, layers=layers, name=name + '.' + name1 if name != '' else name1 + )) + return res + + +def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048): + import accelerate + from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer + + print(Style.BRIGHT + Fore.CYAN + "Loading Model ...") + t0 = time.time() + + with accelerate.init_empty_weights(): + config = LlamaConfig.from_pretrained(config_path) + model = LlamaForCausalLM(config) + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) + model = accelerate.load_checkpoint_and_dispatch( + model=model, + checkpoint=model_path, + device_map=device_map, + no_split_module_classes=["LlamaDecoderLayer"] + ) + + model.seqlen = seqlen + + if half: + model_to_half(model) + + tokenizer = LlamaTokenizer.from_pretrained(config_path) + tokenizer.truncation_side = 'left' + + print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.") + + return model, tokenizer + +def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None): + import accelerate + from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer + + if max_memory is None: + max_memory = {0: '24Gib', 'cpu': '48Gib'} + + print(Style.BRIGHT + Fore.CYAN + "Loading Model ...") + t0 = time.time() + + with accelerate.init_empty_weights(): + config = LlamaConfig.from_pretrained(config_path) + model = LlamaForCausalLM(config) + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) + accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'}) + + # rotary_emb fix + for n, m in model.named_modules(): + if 'rotary_emb' in n: + cos_cached = m.cos_cached.clone().cpu() + sin_cached = m.sin_cached.clone().cpu() + break + + if lora_path is not None: + from peft import PeftModel + from peft.tuners.lora import Linear4bitLt + model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32) + print(Style.BRIGHT + Fore.GREEN + '{} Lora Applied.'.format(lora_path)) + + model.seqlen = seqlen + + print('Apply half ...') + for n, m in model.named_modules(): + if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)): + m.qzeros = m.qzeros.half() + m.scales = m.scales.half() + m.bias = m.bias.half() + + print('Dispatching model ...') + device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"]) + model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0) + torch.cuda.empty_cache() + print(Style.BRIGHT + Fore.YELLOW + 'Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024)) + + # rotary_emb fix + for n, m in model.named_modules(): + if 'rotary_emb' in n: + if getattr(m, '_hf_hook', None): + if isinstance(m._hf_hook, accelerate.hooks.SequentialHook): + hooks = m._hf_hook.hooks + else: + hooks = [m._hf_hook] + for hook in hooks: + if hook.offload: + if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys(): + hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu() + hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu() + + tokenizer = LlamaTokenizer.from_pretrained(config_path) + tokenizer.truncation_side = 'left' + + print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.") + + return model, tokenizer diff --git a/finetune.py b/finetune.py index f374e2b..80d2156 100644 --- a/finetune.py +++ b/finetune.py @@ -109,6 +109,7 @@ if not ft_config.skip: per_device_train_batch_size=ft_config.mbatch_size, gradient_accumulation_steps=ft_config.gradient_accumulation_steps, warmup_steps=ft_config.warmup_steps, + optim="adamw_torch", num_train_epochs=ft_config.epochs, learning_rate=ft_config.lr, fp16=True, diff --git a/requirements.txt b/requirements.txt index 605c0d1..9b117d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ bitsandbytes datasets sentencepiece safetensors +triton git+https://github.com/huggingface/transformers.git git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit git+https://github.com/sterlind/peft.git diff --git a/triton_test.py b/triton_test.py new file mode 100644 index 0000000..eeb77a9 --- /dev/null +++ b/triton_test.py @@ -0,0 +1,154 @@ +import torch + +import triton +import triton.language as tl + +# % +# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` +# decorator, which consumes: +# - A list of :code:`triton.Config` objects that define different configurations of +# meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try +# - An autotuning *key* whose change in values will trigger evaluation of all the +# provided configs + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. stride_am is how much to increase a_ptr + # by to get the element one row down (A has M rows) + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse + # See above `L2 Cache Optimizations` section for details + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers + # see above `Pointer Arithmetics` section for details + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + # Note that for simplicity, we don't apply a mask here. + # This means that if K is not a multiple of BLOCK_SIZE_K, + # this will access out-of-bounds memory and produce an + # error or (worse!) incorrect results. + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + # We accumulate along the K dimension + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # you can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul` +@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + +def matmul(a, b, activation=""): + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + assert a.is_contiguous(), "matrix A must be contiguous" + assert b.is_contiguous(), "matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + assert ( + K % 32 == 0 + ), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K" + # allocates output + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # 1D launch kernel where each block gets its own program. + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + ACTIVATION=activation, + ) + return c + + + +torch.manual_seed(0) +a = torch.randn((512, 512), device='cuda', dtype=torch.float16) +b = torch.randn((512, 512), device='cuda', dtype=torch.float16) +triton_output = matmul(a, b) +torch_output = torch.matmul(a, b) +print(f"triton_output={triton_output}") +print(f"torch_output={torch_output}") +if triton.testing.allclose(triton_output, torch_output): + print("✅ Triton and Torch match") +else: + print("❌ Triton and Torch differ") From 4a2d23aa293e61dd06b9294e5af05aa6789ce140 Mon Sep 17 00:00:00 2001 From: Andrey Glushenkov Date: Thu, 6 Apr 2023 02:31:06 +0300 Subject: [PATCH 02/19] Delete autograd_4bit.py File moved to autograd_4bit module --- autograd_4bit.py | 220 ----------------------------------------------- 1 file changed, 220 deletions(-) delete mode 100644 autograd_4bit.py diff --git a/autograd_4bit.py b/autograd_4bit.py deleted file mode 100644 index bb63cab..0000000 --- a/autograd_4bit.py +++ /dev/null @@ -1,220 +0,0 @@ -import matmul_utils_4bit as mm4b -import torch -import torch.nn as nn -import time -import math - - -class AutogradMatmul4bit(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, qweight, scales, zeros, groupsize=-1): - ctx.save_for_backward(qweight, scales, zeros) - ctx.groupsize = groupsize - if groupsize == -1: - output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros) - else: - output = mm4b._matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize) - output = output.clone() - return output - - @staticmethod - def backward(ctx, grad_output): - qweight, scales, zeros = ctx.saved_tensors - groupsize = ctx.groupsize - if groupsize == -1: - grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True) - else: - grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True) - return grad, None, None, None, None - - -# Assumes layer is perfectly divisible into 256 * 256 blocks -class Autograd4bitQuantLinear(nn.Module): - - def __init__(self, infeatures, outfeatures, groupsize=-1): - super().__init__() - bits = 4 - self.in_features = infeatures - self.out_features = outfeatures - self.bits = bits - self.groupsize = groupsize - if groupsize == -1: - self.register_buffer('zeros', torch.empty((outfeatures, 1))) - self.register_buffer('scales', torch.empty((outfeatures, 1))) - else: - self.register_buffer('qzeros', - torch.empty((math.ceil(infeatures/groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int) - ) - self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize), outfeatures))) - self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32)) - self.bias = nn.Parameter(torch.empty(outfeatures)) - self.register_buffer( - 'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int) - ) - - - def forward(self, x): - if torch.is_grad_enabled(): - out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, - self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) - out.add_(self.bias) - else: - out = mm4b.matmul4bit(x, self.qweight, self.scales, - self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) - out.add_(self.bias) - return out - - -def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1): - if isinstance(module, Autograd4bitQuantLinear): - return - for attr in dir(module): - tmp = getattr(module, attr) - name1 = name + '.' + attr if name != '' else attr - if name1 in names: - setattr( - module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize) - ) - for name1, child in module.named_children(): - make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize) - - -def model_to_half(model): - model.half() - for n, m in model.named_modules(): - if isinstance(m, Autograd4bitQuantLinear): - if m.groupsize == -1: - m.zeros = m.zeros.half() - m.scales = m.scales.half() - m.bias = m.bias.half() - print('Converted as Half.') - - -def model_to_float(model): - model.float() - for n, m in model.named_modules(): - if isinstance(m, Autograd4bitQuantLinear): - if m.groupsize == -1: - m.zeros = m.zeros.float() - m.scales = m.scales.float() - m.bias = m.bias.float() - print('Converted as Float.') - - -def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): - if type(module) in layers: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update(find_layers( - child, layers=layers, name=name + '.' + name1 if name != '' else name1 - )) - return res - - -def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048): - import accelerate - from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - - print("Loading Model ...") - t0 = time.time() - - with accelerate.init_empty_weights(): - config = LlamaConfig.from_pretrained(config_path) - model = LlamaForCausalLM(config) - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) - model = accelerate.load_checkpoint_and_dispatch( - model=model, - checkpoint=model_path, - device_map=device_map, - no_split_module_classes=["LlamaDecoderLayer"] - ) - - model.seqlen = seqlen - - if half: - model_to_half(model) - - tokenizer = LlamaTokenizer.from_pretrained(config_path) - tokenizer.truncation_side = 'left' - - print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") - - return model, tokenizer - -def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None): - import accelerate - from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - - if max_memory is None: - max_memory = {0: '24Gib', 'cpu': '48Gib'} - - print("Loading Model ...") - t0 = time.time() - - with accelerate.init_empty_weights(): - config = LlamaConfig.from_pretrained(config_path) - model = LlamaForCausalLM(config) - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) - accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'}) - - # rotary_emb fix - for n, m in model.named_modules(): - if 'rotary_emb' in n: - cos_cached = m.cos_cached.clone().cpu() - sin_cached = m.sin_cached.clone().cpu() - break - - if lora_path is not None: - from peft import PeftModel - from peft.tuners.lora import Linear4bitLt - model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32) - print('{} Lora Applied.'.format(lora_path)) - - model.seqlen = seqlen - - print('Apply half ...') - for n, m in model.named_modules(): - if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)): - if m.groupsize == -1: - m.zeros = m.zeros.half() - m.scales = m.scales.half() - m.bias = m.bias.half() - - print('Dispatching model ...') - device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"]) - model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0) - torch.cuda.empty_cache() - print('Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024)) - - # rotary_emb fix - for n, m in model.named_modules(): - if 'rotary_emb' in n: - if getattr(m, '_hf_hook', None): - if isinstance(m._hf_hook, accelerate.hooks.SequentialHook): - hooks = m._hf_hook.hooks - else: - hooks = [m._hf_hook] - for hook in hooks: - if hook.offload: - if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys(): - hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu() - hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu() - - tokenizer = LlamaTokenizer.from_pretrained(config_path) - tokenizer.truncation_side = 'left' - - print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") - - return model, tokenizer From 0d271d5d90dff890be586af6177f076fe4448ddc Mon Sep 17 00:00:00 2001 From: Andrey Glushenkov Date: Thu, 6 Apr 2023 02:38:06 +0300 Subject: [PATCH 03/19] Add files via upload Fix triton kernels --- triton_utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/triton_utils.py b/triton_utils.py index 7f50c5e..57f9e66 100644 --- a/triton_utils.py +++ b/triton_utils.py @@ -62,7 +62,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr, g_ptrs = g_ptr + offs_k # shifter is used to extract the N bits of each element in the 32-bit word from B scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits) + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) shifter = (offs_k % infearure_per_bits) * bits zeros_shifter = (offs_bn % infearure_per_bits) * bits @@ -78,12 +78,15 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr, zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros + 1) - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated # Now we need to unpack b (which is N-bit values) into 32-bit values b = (b >> shifter[:, None]) & maxq # Extract the N-bit values b = (b - zeros) * scales # Scale and shift + # ! Convert to fp16 + b = b.to(tl.float16) + a = a.to(tl.float16) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K @@ -93,7 +96,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr, c = accumulator.to(tl.float16) c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) + tl.store(c_ptrs, c, mask=c_mask) # code based https://github.com/fpgaminer/GPTQ-triton @triton.autotune( @@ -157,7 +160,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, # shifter is used to extract the N bits of each element in the 32-bit word from B scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales - zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros + zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros shifter = (offs_bk % infearure_per_bits) * bits zeros_shifter = (offs_n % infearure_per_bits) * bits @@ -178,6 +181,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, b = (b >> shifter[:, None]) & maxq # Extract the N-bit values b = (b - zeros) * scales # Scale and shift b = tl.trans(b) + # ! Convert to fp16 + b = b.to(tl.float16) + a = a.to(tl.float16) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_N @@ -188,7 +194,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, c = accumulator.to(tl.float16) c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) - tl.store(c_ptrs, accumulator, mask=c_mask) + tl.store(c_ptrs, c, mask=c_mask) def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): @@ -202,4 +208,3 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) return output - \ No newline at end of file From c991e2a091211316ccf4e5c53c807e08a6da5f79 Mon Sep 17 00:00:00 2001 From: Andrey Glushenkov Date: Thu, 6 Apr 2023 02:39:40 +0300 Subject: [PATCH 04/19] Delete triton_test.py --- triton_test.py | 154 ------------------------------------------------- 1 file changed, 154 deletions(-) delete mode 100644 triton_test.py diff --git a/triton_test.py b/triton_test.py deleted file mode 100644 index eeb77a9..0000000 --- a/triton_test.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch - -import triton -import triton.language as tl - -# % -# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` -# decorator, which consumes: -# - A list of :code:`triton.Config` objects that define different configurations of -# meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try -# - An autotuning *key* whose change in values will trigger evaluation of all the -# provided configs - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def matmul_kernel( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ACTIVATION: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse - # See above `L2 Cache Optimizations` section for details - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers - # see above `Pointer Arithmetics` section for details - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, K, BLOCK_SIZE_K): - # Note that for simplicity, we don't apply a mask here. - # This means that if K is not a multiple of BLOCK_SIZE_K, - # this will access out-of-bounds memory and produce an - # error or (worse!) incorrect results. - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - # We accumulate along the K dimension - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - # you can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - if ACTIVATION == "leaky_relu": - accumulator = leaky_relu(accumulator) - c = accumulator.to(tl.float16) - - # ----------------------------------------------------------- - # Write back the block of the output matrix C - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - -# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul` -@triton.jit -def leaky_relu(x): - x = x + 1 - return tl.where(x >= 0, x, 0.01 * x) - -def matmul(a, b, activation=""): - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - assert a.is_contiguous(), "matrix A must be contiguous" - assert b.is_contiguous(), "matrix B must be contiguous" - M, K = a.shape - K, N = b.shape - assert ( - K % 32 == 0 - ), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K" - # allocates output - c = torch.empty((M, N), device=a.device, dtype=a.dtype) - # 1D launch kernel where each block gets its own program. - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - ACTIVATION=activation, - ) - return c - - - -torch.manual_seed(0) -a = torch.randn((512, 512), device='cuda', dtype=torch.float16) -b = torch.randn((512, 512), device='cuda', dtype=torch.float16) -triton_output = matmul(a, b) -torch_output = torch.matmul(a, b) -print(f"triton_output={triton_output}") -print(f"torch_output={torch_output}") -if triton.testing.allclose(triton_output, torch_output): - print("✅ Triton and Torch match") -else: - print("❌ Triton and Torch differ") From 085d9556f9f7413ac994a53d0ea150c6b43f515b Mon Sep 17 00:00:00 2001 From: John Smith Date: Thu, 6 Apr 2023 10:46:42 +0800 Subject: [PATCH 05/19] fix bug --- autograd_4bit.py | 6 +- triton_utils.py | 408 +++++++++++++++++++++++------------------------ 2 files changed, 207 insertions(+), 207 deletions(-) diff --git a/autograd_4bit.py b/autograd_4bit.py index bb63cab..b7e883f 100644 --- a/autograd_4bit.py +++ b/autograd_4bit.py @@ -48,7 +48,7 @@ class Autograd4bitQuantLinear(nn.Module): ) self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize), outfeatures))) self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32)) - self.bias = nn.Parameter(torch.empty(outfeatures)) + self.register_buffer('bias', torch.empty(outfeatures)) self.register_buffer( 'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int) ) @@ -58,11 +58,11 @@ class Autograd4bitQuantLinear(nn.Module): if torch.is_grad_enabled(): out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) - out.add_(self.bias) + out += self.bias else: out = mm4b.matmul4bit(x, self.qweight, self.scales, self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) - out.add_(self.bias) + out += self.bias return out diff --git a/triton_utils.py b/triton_utils.py index 7f50c5e..940c73c 100644 --- a/triton_utils.py +++ b/triton_utils.py @@ -1,205 +1,205 @@ -import triton -import triton.language as tl -import torch - -# code based https://github.com/fpgaminer/GPTQ-triton -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - ], - key=['M', 'N', 'K'], -) - -@triton.jit -def matmul_248_kernel(a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, g_ptr, - M, N, K, bits, maxq, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_scales, stride_zeros, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) - - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c = accumulator.to(tl.float16) - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - -# code based https://github.com/fpgaminer/GPTQ-triton -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - ], - key=['M', 'N', 'K'], -) - -@triton.jit -def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, g_ptr, - M, N, K, bits, maxq, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_scales, stride_zeros, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, N) float16 - B is of shape (K//8, N) int32 - C is of shape (M, K) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_k - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_k = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - offs_n = tl.arange(0, BLOCK_SIZE_N) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_bk - g_idx = tl.load(g_ptrs) - - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales - zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros - - shifter = (offs_bk % infearure_per_bits) * bits - zeros_shifter = (offs_n % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) - - for k in range(0, num_pid_n): - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) - - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - b = tl.trans(b) - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_N - b_ptrs += BLOCK_SIZE_N - scales_ptrs += BLOCK_SIZE_N - zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) - - c = accumulator.to(tl.float16) - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): - output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) - matmul_248_kernel[grid](input, qweight, output, - scales, qzeros, g_idx, - input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, - input.stride(0), input.stride(1), - qweight.stride(0), qweight.stride(1), - output.stride(0), output.stride(1), - scales.stride(0), qzeros.stride(0)) - return output +import triton +import triton.language as tl +import torch + +# code based https://github.com/fpgaminer/GPTQ-triton +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + ], + key=['M', 'N', 'K'], +) + +@triton.jit +def matmul_248_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, g_ptr, + M, N, K, bits, maxq, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c = accumulator.to(tl.float16) + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +# code based https://github.com/fpgaminer/GPTQ-triton +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + ], + key=['M', 'N', 'K'], +) + +@triton.jit +def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, g_ptr, + M, N, K, bits, maxq, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, N) float16 + B is of shape (K//8, N) int32 + C is of shape (M, K) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_k + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_k = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = tl.arange(0, BLOCK_SIZE_N) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_bk + g_idx = tl.load(g_ptrs) + + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales + zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros + + shifter = (offs_bk % infearure_per_bits) * bits + zeros_shifter = (offs_n % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + for k in range(0, num_pid_n): + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + b = tl.trans(b) + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_N + b_ptrs += BLOCK_SIZE_N + scales_ptrs += BLOCK_SIZE_N + zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) + + c = accumulator.to(tl.float16) + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): + output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) + matmul_248_kernel[grid](input, qweight, output, + scales, qzeros, g_idx, + input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, + input.stride(0), input.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), qzeros.stride(0)) + return output \ No newline at end of file From 9a02a88fb8484de9cb3b3dd1470eb16d69f633c9 Mon Sep 17 00:00:00 2001 From: John Smith Date: Thu, 6 Apr 2023 12:56:27 +0800 Subject: [PATCH 06/19] add patch for encode function to remove eos token at the beginning of left side --- text-generation-webui/custom_monkey_patch.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/text-generation-webui/custom_monkey_patch.py b/text-generation-webui/custom_monkey_patch.py index 6f586e3..0f4d370 100644 --- a/text-generation-webui/custom_monkey_patch.py +++ b/text-generation-webui/custom_monkey_patch.py @@ -5,6 +5,8 @@ from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear from peft import PeftModel from peft.tuners.lora import Linear4bitLt +patch_encode_func = False + def load_model_llama(*args, **kwargs): config_path = '../llama-13b-4bit/' @@ -41,4 +43,15 @@ shared.settings['name2'] = 'Assistant' shared.settings['chat_prompt_size_max'] = 2048 shared.settings['chat_prompt_size'] = 2048 +if patch_encode_func: + from modules import text_generation + text_generation.encode_old = text_generation.encode + def encode_patched(*args, **kwargs): + input_ids = text_generation.encode_old(*args, **kwargs) + if input_ids[0,0] == 0: + input_ids = input_ids[:, 1:] + return input_ids + text_generation.encode = encode_patched + print('Encode Function Patched.') + print('Monkey Patch Completed.') From 8020b3ec3b997084080bdb5b9cd6c4cc83d881ca Mon Sep 17 00:00:00 2001 From: John Smith Date: Thu, 6 Apr 2023 13:57:32 +0800 Subject: [PATCH 07/19] Update README.md --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 2326cde..27dcdf6 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Made some adjust for the code in peft and gptq for llama, and make it possible f * Added V2 model support (with groupsize, both inference + finetune) * Added some options on finetune: set default to use eos_token instead of padding, add resume_checkpoint to continue training * Added offload support. load_llama_model_4bit_low_ram_and_offload_to_cpu function can be used. +* Added monkey patch for text generation webui for fixing initial eos token issue. # Requirements gptq-for-llama
@@ -67,3 +68,10 @@ Use the command to run ``` python server.py ``` + +# Flash Attention + +It seems that we can apply a monkey patch for llama model. To use it, simply download the file from [MonkeyPatch](https://github.com/lm-sys/FastChat/blob/daa9c11080ceced2bd52c3e0027e4f64b1512683/fastchat/train/llama_flash_attn_monkey_patch.py). And also, flash-attention is needed, and currently do not support pytorch 2.0. +``` +pip install flash-attn +``` From 3ea18575c7af99356b52f884dc9dc6eec8c50852 Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 13:49:12 +0200 Subject: [PATCH 08/19] Use flash attention monkeypatch --- finetune.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/finetune.py b/finetune.py index f374e2b..998b1d6 100644 --- a/finetune.py +++ b/finetune.py @@ -16,6 +16,9 @@ } ] """ +from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn + +replace_llama_attn_with_flash_attn() import sys From 7b18b39dd8f29545f3e728a7777a100cfe64bb0a Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 13:49:36 +0200 Subject: [PATCH 09/19] Create llama_flash_attn_monkey_patch.py --- monkeypatch/llama_flash_attn_monkey_patch.py | 143 +++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 monkeypatch/llama_flash_attn_monkey_patch.py diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000..1d48bb5 --- /dev/null +++ b/monkeypatch/llama_flash_attn_monkey_patch.py @@ -0,0 +1,143 @@ +from typing import List, Optional, Tuple + +import torch +from torch import nn + +import transformers +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb + +from einops import rearrange + +from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads}).") + self.q_proj = nn.Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.k_proj = nn.Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.v_proj = nn.Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.o_proj = nn.Linear( + num_heads * self.head_dim, + hidden_size, + bias=False, + ) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel + + attention_mask: [bsz, q_len] + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] + + kv_seq_len = key_states.shape[-2] + offset = 0 + if past_key_value is not None: + offset = past_key_value[0].shape[-2] + kv_seq_len += offset + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, + key_states, + cos, + sin, + offset=offset) + # [bsz, nh, t, hd] + assert not output_attentions, "output_attentions is not supported" + assert not use_cache, "use_cache is not supported" + assert past_key_value is None, "past_key_value is not supported" + + # Flash attention codes from + # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + + # transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + # We have disabled _prepare_decoder_attention_mask in LlamaModel + # the attention_mask should be the same as the key_padding_mask + key_padding_mask = attention_mask + + + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = q_len + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, + device=qkv.device) + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, + softmax_scale=None, causal=True + ) + output = rearrange(output, '(b s) ... -> b s ...', b=bsz) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_attn_unpadded_qkvpacked_func( + x_unpad, cu_q_lens, max_s, 0.0, + softmax_scale=None, causal=True + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, bsz, q_len), + 'b s (h d) -> b s h d', h=nheads) + return self.o_proj(rearrange(output, + 'b s h d -> b s (h d)')), None, None + + +# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask +def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + inputs_embeds, past_key_values_length): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention From 30bf938d03ada62a680820529b0db025484a67c1 Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 13:50:25 +0200 Subject: [PATCH 10/19] Update requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 605c0d1..2db1f01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ bitsandbytes datasets sentencepiece safetensors +flash-attn git+https://github.com/huggingface/transformers.git git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit git+https://github.com/sterlind/peft.git From 7770e76c9c143eaab04657d082d3f3406ebede71 Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 17:32:01 +0200 Subject: [PATCH 11/19] Fix args of flash attention --- monkeypatch/llama_flash_attn_monkey_patch.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py index 1d48bb5..dd2589d 100644 --- a/monkeypatch/llama_flash_attn_monkey_patch.py +++ b/monkeypatch/llama_flash_attn_monkey_patch.py @@ -4,7 +4,7 @@ import torch from torch import nn import transformers -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb +from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb from einops import rearrange @@ -16,13 +16,14 @@ class LlamaAttention(nn.Module): def __init__( self, - hidden_size: int, - num_heads: int, + config: LlamaConfig, ): super().__init__() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads self.hidden_size = hidden_size self.num_heads = num_heads - self.head_dim = hidden_size // num_heads + self.head_dim = self.hidden_size // num_heads if (self.head_dim * num_heads) != self.hidden_size: raise ValueError( From 2bf5d42f287127427e09951c78c4b023203b7b22 Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 17:46:15 +0200 Subject: [PATCH 12/19] Add position_ids to flash attention --- monkeypatch/llama_flash_attn_monkey_patch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py index dd2589d..0c80227 100644 --- a/monkeypatch/llama_flash_attn_monkey_patch.py +++ b/monkeypatch/llama_flash_attn_monkey_patch.py @@ -60,6 +60,7 @@ class LlamaAttention(nn.Module): hidden_states: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], @@ -80,16 +81,15 @@ class LlamaAttention(nn.Module): # [bsz, nh, q_len, hd] kv_seq_len = key_states.shape[-2] - offset = 0 if past_key_value is not None: - offset = past_key_value[0].shape[-2] - kv_seq_len += offset + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, - offset=offset) + position_ids) # [bsz, nh, t, hd] assert not output_attentions, "output_attentions is not supported" assert not use_cache, "use_cache is not supported" From 778035152d7ba63a2f980eea98142d5e6eec7fe3 Mon Sep 17 00:00:00 2001 From: yamashi Date: Fri, 7 Apr 2023 00:42:34 +0200 Subject: [PATCH 13/19] Update arg_parser.py --- arg_parser.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arg_parser.py b/arg_parser.py index a985380..be18154 100644 --- a/arg_parser.py +++ b/arg_parser.py @@ -66,6 +66,8 @@ def parse_commandline(): # Multi GPU Support parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch") + + parser_training.add_argument("--flash_attention", help="enables flash attention, can improve performance and reduce VRAM use") return vars(parser.parse_args()) @@ -102,4 +104,5 @@ def get_config() -> Finetune4bConfig: use_eos_token=args["use_eos_token"]!=0, groupsize=args["groupsize"], local_rank=args["local_rank"], + flash_attention=args["flash_attention"], ) From 95cd390d2528dcaaa618b6c2f4765b136e4e313b Mon Sep 17 00:00:00 2001 From: yamashi Date: Fri, 7 Apr 2023 00:43:15 +0200 Subject: [PATCH 14/19] Update Finetune4bConfig.py --- Finetune4bConfig.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py index a8a33bf..09c535c 100644 --- a/Finetune4bConfig.py +++ b/Finetune4bConfig.py @@ -15,7 +15,7 @@ class Finetune4bConfig: warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int, checkpoint: bool, skip: bool, verbose: bool, txt_row_thd: int, use_eos_token: bool, groupsize: int, - local_rank: int, + local_rank: int, flash_attention: bool ): """ Args: @@ -48,6 +48,7 @@ class Finetune4bConfig: use_eos_token (bool): Use Eos token instead of padding with 0 groupsize (int): Group size of V2 model, use -1 to load V1 model local_rank (int): local rank if using torch.distributed.launch + flash_attention (bool): Enables flash attention """ self.dataset = dataset self.ds_type = ds_type @@ -84,6 +85,7 @@ class Finetune4bConfig: if self.ddp: self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size self.groupsize = groupsize + self.flash_attention = flash_attention def __str__(self) -> str: From c5aa7fb6951ac573089daf7a1aa990bc68270f1f Mon Sep 17 00:00:00 2001 From: yamashi Date: Fri, 7 Apr 2023 00:43:36 +0200 Subject: [PATCH 15/19] Update finetune.py --- finetune.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/finetune.py b/finetune.py index 998b1d6..eea0e66 100644 --- a/finetune.py +++ b/finetune.py @@ -16,9 +16,13 @@ } ] """ -from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn +# Early load config to replace attn if needed +from arg_parser import get_config +ft_config = get_config() -replace_llama_attn_with_flash_attn() +if ft_config.flash_attention: + from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn + replace_llama_attn_with_flash_attn() import sys @@ -32,10 +36,9 @@ from autograd_4bit import load_llama_model_4bit_low_ram from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel # ! Config -from arg_parser import get_config import train_data -ft_config = get_config() + # * Show loaded parameters if ft_config.local_rank == 0: From dba3773b30bd0b72a744d9d77da56c7b2681ef59 Mon Sep 17 00:00:00 2001 From: John Smith Date: Fri, 7 Apr 2023 15:34:06 +0800 Subject: [PATCH 16/19] add triton backend support for v2 model --- .gitignore | 4 +- Finetune4bConfig.py | 5 +- arg_parser.py | 7 +- .../autograd_4bit_v1.py => autograd_4bit.py | 124 ++++- autograd_4bit/__init__.py | 21 - autograd_4bit/autograd_4bit_v2.py | 221 --------- custom_autotune.py | 167 +++++++ finetune.py | 16 +- requirements.txt | 1 + triton_utils.py | 449 ++++++++++-------- 10 files changed, 536 insertions(+), 479 deletions(-) rename autograd_4bit/autograd_4bit_v1.py => autograd_4bit.py (57%) delete mode 100644 autograd_4bit/__init__.py delete mode 100644 autograd_4bit/autograd_4bit_v2.py create mode 100644 custom_autotune.py diff --git a/.gitignore b/.gitignore index fab98f9..e531c1e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ alpaca_lora/ repository/ __pycache__/ llama-13b-4bit -llama-13b-4bit.pt \ No newline at end of file +llama-13b-4bit.pt +text-generation-webui/ +repository/ diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py index 06eb26e..459102a 100644 --- a/Finetune4bConfig.py +++ b/Finetune4bConfig.py @@ -15,7 +15,7 @@ class Finetune4bConfig: warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int, checkpoint: bool, skip: bool, verbose: bool, txt_row_thd: int, use_eos_token: bool, groupsize: int, - local_rank: int, flash_attention: bool + local_rank: int, flash_attention: bool, backend: str ): """ Args: @@ -86,6 +86,7 @@ class Finetune4bConfig: self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size self.groupsize = groupsize self.flash_attention = flash_attention + self.backend = backend def __str__(self) -> str: @@ -98,5 +99,5 @@ class Finetune4bConfig: f"{self.logging_steps=}\n" +\ f"{self.checkpoint=}\n{self.skip=}\n" +\ f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}\n" +\ - f"{self.groupsize=}\n" + f"{self.groupsize=}\n{self.backend=}\n" return s.replace("self.", "") diff --git a/arg_parser.py b/arg_parser.py index be18154..b83c939 100644 --- a/arg_parser.py +++ b/arg_parser.py @@ -67,7 +67,11 @@ def parse_commandline(): # Multi GPU Support parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch") - parser_training.add_argument("--flash_attention", help="enables flash attention, can improve performance and reduce VRAM use") + # Flash Attention + parser_training.add_argument("--flash_attention", action="store_true", help="enables flash attention, can improve performance and reduce VRAM use") + + # Train Backend + parser_training.add_argument("--backend", type=str, default='cuda', help="Backend to use. Triton or Cuda.") return vars(parser.parse_args()) @@ -105,4 +109,5 @@ def get_config() -> Finetune4bConfig: groupsize=args["groupsize"], local_rank=args["local_rank"], flash_attention=args["flash_attention"], + backend=args["backend"], ) diff --git a/autograd_4bit/autograd_4bit_v1.py b/autograd_4bit.py similarity index 57% rename from autograd_4bit/autograd_4bit_v1.py rename to autograd_4bit.py index 6432f2d..cf0faa1 100644 --- a/autograd_4bit/autograd_4bit_v1.py +++ b/autograd_4bit.py @@ -2,12 +2,15 @@ import matmul_utils_4bit as mm4b import torch import torch.nn as nn import time +import math +from torch.cuda.amp import custom_bwd, custom_fwd -class AutogradMatmul4bit(torch.autograd.Function): +class AutogradMatmul4bitCuda(torch.autograd.Function): @staticmethod - def forward(ctx, x, qweight, scales, zeros, groupsize=-1): + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq, groupsize=-1): ctx.save_for_backward(qweight, scales, zeros) ctx.groupsize = groupsize if groupsize == -1: @@ -18,42 +21,116 @@ class AutogradMatmul4bit(torch.autograd.Function): return output @staticmethod + @custom_bwd def backward(ctx, grad_output): qweight, scales, zeros = ctx.saved_tensors groupsize = ctx.groupsize - if groupsize == -1: - grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True) - else: - grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True) - return grad, None, None, None, None + if ctx.needs_input_grad[0]: + if groupsize == -1: + grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True) + else: + grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True) + return grad, None, None, None, None, None, None, None + + +try: + import triton_utils as tu + + class AutogradMatmul4bitTriton(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize=-1): + output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq) + ctx.save_for_backward(qweight, scales, qzeros, g_idx) + ctx.bits, ctx.maxq = bits, maxq + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + qweight, scales, qzeros, g_idx = ctx.saved_tensors + bits, maxq = ctx.bits, ctx.maxq + grad_input = None + + if ctx.needs_input_grad[0]: + grad_input = tu.triton_matmul_transpose(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) + return grad_input, None, None, None, None, None, None, None + +except ImportError: + print('Triton not found. Please run "pip install triton".') + + +AutogradMatmul4bit = AutogradMatmul4bitCuda +backend = 'cuda' + + +def switch_backend_to(to_backend): + global AutogradMatmul4bit + global backend + if to_backend == 'cuda': + AutogradMatmul4bit = AutogradMatmul4bitCuda + backend = 'cuda' + print('Using CUDA implementation.') + elif to_backend == 'triton': + # detect if AutogradMatmul4bitTriton is defined + if 'AutogradMatmul4bitTriton' not in globals(): + raise ValueError('Triton not found. Please install triton_utils.') + AutogradMatmul4bit = AutogradMatmul4bitTriton + backend = 'triton' + print('Using Triton implementation.') + else: + raise ValueError('Backend not supported.') + + +def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize): + if backend == 'cuda': + return mm4b.matmul4bit(x, qweight, scales, qzeros, groupsize) + elif backend == 'triton': + assert qzeros.dtype == torch.int32 + return tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq) + else: + raise ValueError('Backend not supported.') # Assumes layer is perfectly divisible into 256 * 256 blocks class Autograd4bitQuantLinear(nn.Module): - def __init__(self, in_features, out_features, groupsize=None): + def __init__(self, in_features, out_features, groupsize=-1): super().__init__() bits = 4 self.in_features = in_features self.out_features = out_features self.bits = bits - self.register_buffer('zeros', torch.empty((out_features, 1))) - self.register_buffer('scales', torch.empty((out_features, 1))) - self.bias = nn.Parameter(torch.empty(out_features)) + self.maxq = 2 ** self.bits - 1 + self.groupsize = groupsize + if groupsize == -1: + self.register_buffer('zeros', torch.empty((out_features, 1))) + self.register_buffer('scales', torch.empty((out_features, 1))) + else: + self.register_buffer('qzeros', + torch.empty((math.ceil(in_features/groupsize), out_features // 256 * (bits * 8)), dtype=torch.int32) + ) + self.register_buffer('scales', torch.empty((math.ceil(in_features/groupsize), out_features))) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(in_features)], dtype = torch.int32)) + self.register_buffer('bias', torch.empty(out_features)) self.register_buffer( - 'qweight', torch.empty((in_features // 256 * (bits * 8), out_features), dtype=torch.int) + 'qweight', torch.empty((in_features // 256 * (bits * 8), out_features), dtype=torch.int32) ) def forward(self, x): if torch.is_grad_enabled(): out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, - self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) - out += self.bias + self.qzeros if self.groupsize != -1 else self.zeros, + self.g_idx, self.bits, self.maxq, + self.groupsize) else: - out = mm4b.matmul4bit(x, self.qweight, self.scales, - self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) - out += self.bias + out = matmul4bit_with_backend(x, self.qweight, self.scales, + self.qzeros if self.groupsize != -1 else self.zeros, + self.g_idx, self.bits, self.maxq, + self.groupsize) + out += self.bias return out @@ -75,7 +152,8 @@ def model_to_half(model): model.half() for n, m in model.named_modules(): if isinstance(m, Autograd4bitQuantLinear): - m.zeros = m.zeros.half() + if m.groupsize == -1: + m.zeros = m.zeros.half() m.scales = m.scales.half() m.bias = m.bias.half() print('Converted as Half.') @@ -85,7 +163,8 @@ def model_to_float(model): model.float() for n, m in model.named_modules(): if isinstance(m, Autograd4bitQuantLinear): - m.zeros = m.zeros.float() + if m.groupsize == -1: + m.zeros = m.zeros.float() m.scales = m.scales.float() m.bias = m.bias.float() print('Converted as Float.') @@ -137,7 +216,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa return model, tokenizer -def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None): +def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None): import accelerate from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer @@ -176,7 +255,8 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lo print('Apply half ...') for n, m in model.named_modules(): if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)): - m.zeros = m.zeros.half() + if m.groupsize == -1: + m.zeros = m.zeros.half() m.scales = m.scales.half() m.bias = m.bias.half() @@ -206,3 +286,5 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lo print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") return model, tokenizer + +load_llama_model_4bit_low_ram_and_offload_to_cpu = load_llama_model_4bit_low_ram_and_offload diff --git a/autograd_4bit/__init__.py b/autograd_4bit/__init__.py deleted file mode 100644 index bee84b3..0000000 --- a/autograd_4bit/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -import os -from colorama import init, Fore, Back, Style -init(autoreset=True) - -try: - GPTQ_VERSION = int(os.environ["GPTQ_VERSION"]) -except: - print(Style.BRIGHT + Fore.YELLOW + "GPTQ_VERSION environment not provided. Fallback to GPTQv1") - GPTQ_VERSION = 1 # Fallback version - -loader = None - - -if GPTQ_VERSION == 1: - from .autograd_4bit_v1 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram - print(Style.BRIGHT + Fore.GREEN + "GPTQv1 set") -elif GPTQ_VERSION == 2: - from .autograd_4bit_v2 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram - print(Style.BRIGHT + Fore.GREEN + "GPTQv2 set") -else: - raise ValueError("GPTQ_VERSION not set or invalid") \ No newline at end of file diff --git a/autograd_4bit/autograd_4bit_v2.py b/autograd_4bit/autograd_4bit_v2.py deleted file mode 100644 index 20c253d..0000000 --- a/autograd_4bit/autograd_4bit_v2.py +++ /dev/null @@ -1,221 +0,0 @@ -from colorama import init, Fore, Back, Style -import torch -import torch.nn as nn -import time -import math -import triton -from triton_utils import matmul_248_kernel, trans_matmul_248_kernel - - -class AutogradMatmul4bit(torch.autograd.Function): - @staticmethod - def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): - output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) - matmul_248_kernel[grid](input, qweight, output, - scales, qzeros, g_idx, - input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, - input.stride(0), input.stride(1), - qweight.stride(0), qweight.stride(1), - output.stride(0), output.stride(1), - scales.stride(0), qzeros.stride(0)) - - ctx.save_for_backward(qweight, scales, qzeros, g_idx) - ctx.input_shape, ctx.bits,ctx.maxq = input.shape,bits, maxq - return output - - @staticmethod - def backward(ctx, grad_output): - input_shape, bits, maxq = ctx.input_shape, ctx.bits, ctx.maxq - qweight, scales, qzeros, g_idx = ctx.saved_tensors - grade_input = None - - if ctx.needs_input_grad[0]: - grade_input = torch.empty((input_shape[0], input_shape[1]), device='cuda', dtype=torch.float32) - grid = lambda META: (triton.cdiv(input_shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(input_shape[1], META['BLOCK_SIZE_K']),) - trans_matmul_248_kernel[grid](grad_output, qweight, grade_input, - scales, qzeros, g_idx, - input_shape[0], qweight.shape[1], input_shape[1], bits, maxq, - grad_output.stride(0), grad_output.stride(1), - qweight.stride(0), qweight.stride(1), - grade_input.stride(0), grade_input.stride(1), - scales.stride(0), qzeros.stride(0)) - return grade_input, None, None, None, None, None, None - - -class Autograd4bitQuantLinear(nn.Module): - - def __init__(self, in_features, out_features, groupsize, bias=True): - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.bits = 4 # Hardcoded 4-bits quantizations - self.maxq = 2 ** self.bits - 1 - self.groupsize = groupsize if groupsize != -1 else in_features - - self.register_buffer('qweight', torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32)) - self.register_buffer('qzeros', torch.zeros((math.ceil(in_features / self.groupsize), out_features // 32 * self.bits), dtype=torch.int32)) - self.register_buffer('scales', torch.zeros((math.ceil(in_features / self.groupsize), out_features), dtype=torch.float16)) - self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(in_features)], dtype = torch.int32)) - if bias: - self.register_buffer('bias', torch.zeros(out_features,dtype=torch.float16)) - else: - self.bias = None - - def forward(self, x): - out_shape = x.shape[:-1] + (self.out_features, ) - out = AutogradMatmul4bit.apply(x.reshape(-1,x.shape[-1]), self.qweight, self.scales, - self.qzeros, self.g_idx, self.bits, self.maxq) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) - - -def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1): - if isinstance(module, Autograd4bitQuantLinear): - return - for attr in dir(module): - tmp = getattr(module, attr) - name1 = name + '.' + attr if name != '' else attr - if name1 in names: - setattr( - module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize) - ) - for name1, child in module.named_children(): - make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize) - - -def model_to_half(model): - model.half() - for n, m in model.named_modules(): - if isinstance(m, Autograd4bitQuantLinear): - m.qzeros = m.qzeros.half() - m.scales = m.scales.half() - m.bias = m.bias.half() - print(Style.BRIGHT + Fore.YELLOW + 'Converted as Half.') - - -def model_to_float(model): - model.float() - for n, m in model.named_modules(): - if isinstance(m, Autograd4bitQuantLinear): - m.qzeros = m.qzeros.float() - m.scales = m.scales.float() - m.bias = m.bias.float() - print(Style.BRIGHT + Fore.YELLOW + 'Converted as Float.') - - -def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): - if type(module) in layers: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update(find_layers( - child, layers=layers, name=name + '.' + name1 if name != '' else name1 - )) - return res - - -def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048): - import accelerate - from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - - print(Style.BRIGHT + Fore.CYAN + "Loading Model ...") - t0 = time.time() - - with accelerate.init_empty_weights(): - config = LlamaConfig.from_pretrained(config_path) - model = LlamaForCausalLM(config) - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) - model = accelerate.load_checkpoint_and_dispatch( - model=model, - checkpoint=model_path, - device_map=device_map, - no_split_module_classes=["LlamaDecoderLayer"] - ) - - model.seqlen = seqlen - - if half: - model_to_half(model) - - tokenizer = LlamaTokenizer.from_pretrained(config_path) - tokenizer.truncation_side = 'left' - - print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.") - - return model, tokenizer - -def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None): - import accelerate - from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - - if max_memory is None: - max_memory = {0: '24Gib', 'cpu': '48Gib'} - - print(Style.BRIGHT + Fore.CYAN + "Loading Model ...") - t0 = time.time() - - with accelerate.init_empty_weights(): - config = LlamaConfig.from_pretrained(config_path) - model = LlamaForCausalLM(config) - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - make_quant_for_4bit_autograd(model, layers, groupsize=groupsize) - accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'}) - - # rotary_emb fix - for n, m in model.named_modules(): - if 'rotary_emb' in n: - cos_cached = m.cos_cached.clone().cpu() - sin_cached = m.sin_cached.clone().cpu() - break - - if lora_path is not None: - from peft import PeftModel - from peft.tuners.lora import Linear4bitLt - model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32) - print(Style.BRIGHT + Fore.GREEN + '{} Lora Applied.'.format(lora_path)) - - model.seqlen = seqlen - - print('Apply half ...') - for n, m in model.named_modules(): - if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)): - m.qzeros = m.qzeros.half() - m.scales = m.scales.half() - m.bias = m.bias.half() - - print('Dispatching model ...') - device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"]) - model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0) - torch.cuda.empty_cache() - print(Style.BRIGHT + Fore.YELLOW + 'Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024)) - - # rotary_emb fix - for n, m in model.named_modules(): - if 'rotary_emb' in n: - if getattr(m, '_hf_hook', None): - if isinstance(m._hf_hook, accelerate.hooks.SequentialHook): - hooks = m._hf_hook.hooks - else: - hooks = [m._hf_hook] - for hook in hooks: - if hook.offload: - if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys(): - hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu() - hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu() - - tokenizer = LlamaTokenizer.from_pretrained(config_path) - tokenizer.truncation_side = 'left' - - print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.") - - return model, tokenizer diff --git a/custom_autotune.py b/custom_autotune.py new file mode 100644 index 0000000..8bafea7 --- /dev/null +++ b/custom_autotune.py @@ -0,0 +1,167 @@ +#https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. +""" + +import builtins +import math +import time +from typing import Dict + +import triton + + +class Autotuner(triton.KernelInterface): + def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False): + ''' + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results + ''' + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench(kernel_call, rep=40) + except triton.compiler.OutOfResources: + return float('inf') + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, + num_warps=config.num_warps) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python + .. code-block:: python + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + reset the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + """ + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) + + return decorator diff --git a/finetune.py b/finetune.py index 82dbde5..576ec31 100644 --- a/finetune.py +++ b/finetune.py @@ -24,6 +24,12 @@ if ft_config.flash_attention: from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn replace_llama_attn_with_flash_attn() +import autograd_4bit +if ft_config.backend.lower() == 'triton': + autograd_4bit.switch_backend_to('triton') +else: + autograd_4bit.switch_backend_to('cuda') + import sys import peft @@ -65,10 +71,16 @@ lora_config = LoraConfig( if ft_config.lora_apply_dir is None: model = get_peft_model(model, lora_config) else: + device_map = ft_config.device_map if ft_config.ddp: - model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map="auto", torch_dtype=torch.float32) # ! Direct copy from inference.py + device_map = {'': 0} else: - model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32) + if torch.cuda.device_count() > 1: + device_map = "auto" + else: + device_map = {'': 0} + print('Device map for lora:', device_map) + model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map=device_map, torch_dtype=torch.float32) print(ft_config.lora_apply_dir, 'loaded') diff --git a/requirements.txt b/requirements.txt index 95a23a6..9832253 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ sentencepiece safetensors flash-attn triton +colorama git+https://github.com/huggingface/transformers.git git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit git+https://github.com/sterlind/peft.git diff --git a/triton_utils.py b/triton_utils.py index 57f9e66..9722628 100644 --- a/triton_utils.py +++ b/triton_utils.py @@ -1,210 +1,239 @@ -import triton -import triton.language as tl -import torch - -# code based https://github.com/fpgaminer/GPTQ-triton -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - ], - key=['M', 'N', 'K'], -) - -@triton.jit -def matmul_248_kernel(a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, g_ptr, - M, N, K, bits, maxq, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_scales, stride_zeros, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) - - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - # ! Convert to fp16 - b = b.to(tl.float16) - a = a.to(tl.float16) - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c = accumulator.to(tl.float16) - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - -# code based https://github.com/fpgaminer/GPTQ-triton -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - ], - key=['M', 'N', 'K'], -) - -@triton.jit -def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, g_ptr, - M, N, K, bits, maxq, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_scales, stride_zeros, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, N) float16 - B is of shape (K//8, N) int32 - C is of shape (M, K) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_k - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_k = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - offs_n = tl.arange(0, BLOCK_SIZE_N) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_bk - g_idx = tl.load(g_ptrs) - - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales - zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros - - shifter = (offs_bk % infearure_per_bits) * bits - zeros_shifter = (offs_n % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) - - for k in range(0, num_pid_n): - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) - - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - b = tl.trans(b) - # ! Convert to fp16 - b = b.to(tl.float16) - a = a.to(tl.float16) - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_N - b_ptrs += BLOCK_SIZE_N - scales_ptrs += BLOCK_SIZE_N - zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) - - c = accumulator.to(tl.float16) - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) - tl.store(c_ptrs, c, mask=c_mask) - - -def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): - output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) - matmul_248_kernel[grid](input, qweight, output, - scales, qzeros, g_idx, - input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, - input.stride(0), input.stride(1), - qweight.stride(0), qweight.stride(1), - output.stride(0), output.stride(1), - scales.stride(0), qzeros.stride(0)) - return output +import triton +import triton.language as tl +import torch +import custom_autotune + + +# code based https://github.com/fpgaminer/GPTQ-triton +@custom_autotune.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + # These provided a benefit on a 3090 + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N'], + nearest_power_of_two=True, +) + + +@triton.jit +def matmul_248_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, g_ptr, + M, N, K, bits, maxq, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + # ! Convert to fp16 + b = b.to(tl.float16) + a = a.to(tl.float16) + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c = accumulator.to(tl.float16) + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# code based https://github.com/fpgaminer/GPTQ-triton +@custom_autotune.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + # These provided a benefit on a 3090 + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'K'], + nearest_power_of_two=True, +) + + +@triton.jit +def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, g_ptr, + M, N, K, bits, maxq, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, N) float16 + B is of shape (K//8, N) int32 + C is of shape (M, K) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_k + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_k = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = tl.arange(0, BLOCK_SIZE_N) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_bk + g_idx = tl.load(g_ptrs) + + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales + zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros + + shifter = (offs_bk % infearure_per_bits) * bits + zeros_shifter = (offs_n % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + for k in range(0, num_pid_n): + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + b = tl.trans(b) + # ! Convert to fp16 + b = b.to(tl.float16) + a = a.to(tl.float16) + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_N + b_ptrs += BLOCK_SIZE_N + scales_ptrs += BLOCK_SIZE_N + zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) + + c = accumulator.to(tl.float16) + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) + tl.store(c_ptrs, c, mask=c_mask) + + +def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): + assert input.shape[1] == qweight.shape[0] * 32 // bits + output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) + matmul_248_kernel[grid](input, qweight, output, + scales, qzeros, g_idx, + input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, + input.stride(0), input.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), qzeros.stride(0)) + return output + + +def triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq): + assert input.shape[1] == qweight.shape[1] + output_shape = (input.shape[0], qweight.shape[0] * 32 // bits) + output = torch.empty((output_shape[0], output_shape[1]), device=scales.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape[1], META['BLOCK_SIZE_K']),) + trans_matmul_248_kernel[grid](input, qweight, output, + scales, qzeros, g_idx, + input.shape[0], qweight.shape[1], output_shape[1], bits, maxq, + input.stride(0), input.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), qzeros.stride(0)) + return output From 32904da1ff2fb033439d88da8d6256163d83a8ea Mon Sep 17 00:00:00 2001 From: John Smith Date: Fri, 7 Apr 2023 15:50:55 +0800 Subject: [PATCH 17/19] fix bug on triton matmul --- autograd_4bit.py | 1 + triton_utils.py | 19 +++++++++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/autograd_4bit.py b/autograd_4bit.py index cf0faa1..55430ec 100644 --- a/autograd_4bit.py +++ b/autograd_4bit.py @@ -44,6 +44,7 @@ try: output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq) ctx.save_for_backward(qweight, scales, qzeros, g_idx) ctx.bits, ctx.maxq = bits, maxq + output = output.clone() return output @staticmethod diff --git a/triton_utils.py b/triton_utils.py index 9722628..7afcf46 100644 --- a/triton_utils.py +++ b/triton_utils.py @@ -211,7 +211,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): - assert input.shape[1] == qweight.shape[0] * 32 // bits + assert input.shape[-1] == qweight.shape[0] * 32 // bits + outshape = input.shape[:-1] + (qweight.shape[1],) + input = input.reshape(-1, input.shape[-1]) output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16) grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) matmul_248_kernel[grid](input, qweight, output, @@ -221,19 +223,24 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): qweight.stride(0), qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + output = output.reshape(outshape) return output def triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq): - assert input.shape[1] == qweight.shape[1] - output_shape = (input.shape[0], qweight.shape[0] * 32 // bits) - output = torch.empty((output_shape[0], output_shape[1]), device=scales.device, dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape[1], META['BLOCK_SIZE_K']),) + assert input.shape[-1] == qweight.shape[1] + out_dim = qweight.shape[0] * 32 // bits + outshape = input.shape[:-1] + (out_dim,) + input = input.reshape(-1, input.shape[-1]) + output_shape_mid = (input.shape[0], out_dim) + output = torch.empty((output_shape_mid[0], output_shape_mid[1]), device=scales.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape_mid[1], META['BLOCK_SIZE_K']),) trans_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, - input.shape[0], qweight.shape[1], output_shape[1], bits, maxq, + input.shape[0], qweight.shape[1], output_shape_mid[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + output = output.reshape(outshape) return output From b01b10eb4d6d132d3118c814d1dc3747b94ec543 Mon Sep 17 00:00:00 2001 From: John Smith Date: Fri, 7 Apr 2023 15:58:38 +0800 Subject: [PATCH 18/19] Colorized output --- autograd_4bit.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/autograd_4bit.py b/autograd_4bit.py index 55430ec..5d19908 100644 --- a/autograd_4bit.py +++ b/autograd_4bit.py @@ -4,6 +4,8 @@ import torch.nn as nn import time import math from torch.cuda.amp import custom_bwd, custom_fwd +from colorama import init, Fore, Back, Style +init(autoreset=True) class AutogradMatmul4bitCuda(torch.autograd.Function): @@ -72,14 +74,14 @@ def switch_backend_to(to_backend): if to_backend == 'cuda': AutogradMatmul4bit = AutogradMatmul4bitCuda backend = 'cuda' - print('Using CUDA implementation.') + print(Style.BRIGHT + Fore.GREEN + 'Using CUDA implementation.') elif to_backend == 'triton': # detect if AutogradMatmul4bitTriton is defined if 'AutogradMatmul4bitTriton' not in globals(): raise ValueError('Triton not found. Please install triton_utils.') AutogradMatmul4bit = AutogradMatmul4bitTriton backend = 'triton' - print('Using Triton implementation.') + print(Style.BRIGHT + Fore.GREEN + 'Using Triton implementation.') else: raise ValueError('Backend not supported.') @@ -157,7 +159,7 @@ def model_to_half(model): m.zeros = m.zeros.half() m.scales = m.scales.half() m.bias = m.bias.half() - print('Converted as Half.') + print(Style.BRIGHT + Fore.YELLOW + 'Converted as Half.') def model_to_float(model): @@ -168,7 +170,7 @@ def model_to_float(model): m.zeros = m.zeros.float() m.scales = m.scales.float() m.bias = m.bias.float() - print('Converted as Float.') + print(Style.BRIGHT + Fore.YELLOW + 'Converted as Float.') def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): @@ -186,7 +188,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa import accelerate from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - print("Loading Model ...") + print(Style.BRIGHT + Fore.CYAN + "Loading Model ...") t0 = time.time() with accelerate.init_empty_weights(): @@ -213,7 +215,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa tokenizer = LlamaTokenizer.from_pretrained(config_path) tokenizer.truncation_side = 'left' - print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") + print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.") return model, tokenizer @@ -224,7 +226,7 @@ def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path if max_memory is None: max_memory = {0: '24Gib', 'cpu': '48Gib'} - print("Loading Model ...") + print(Style.BRIGHT + Fore.CYAN + "Loading Model ...") t0 = time.time() with accelerate.init_empty_weights(): @@ -249,7 +251,7 @@ def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path from peft import PeftModel from peft.tuners.lora import Linear4bitLt model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32) - print('{} Lora Applied.'.format(lora_path)) + print(Style.BRIGHT + Fore.GREEN + '{} Lora Applied.'.format(lora_path)) model.seqlen = seqlen @@ -265,7 +267,7 @@ def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"]) model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0) torch.cuda.empty_cache() - print('Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024)) + print(Style.BRIGHT + Fore.YELLOW + 'Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024)) # rotary_emb fix for n, m in model.named_modules(): @@ -284,7 +286,7 @@ def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path tokenizer = LlamaTokenizer.from_pretrained(config_path) tokenizer.truncation_side = 'left' - print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") + print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.") return model, tokenizer From f91d4cbb593b097f5dfb60866a04e90044414da6 Mon Sep 17 00:00:00 2001 From: John Smith Date: Fri, 7 Apr 2023 16:10:36 +0800 Subject: [PATCH 19/19] Update README.md --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4256ac0..d5336fd 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ Made some adjust for the code in peft and gptq for llama, and make it possible f * Added some options on finetune: set default to use eos_token instead of padding, add resume_checkpoint to continue training * Added offload support. load_llama_model_4bit_low_ram_and_offload_to_cpu function can be used. * Added monkey patch for text generation webui for fixing initial eos token issue. +* Added Flash attention support. (Use --flash-attention) +* Added Triton backend to support model using groupsize and act-order. (Use --backend=triton) # Requirements gptq-for-llama
@@ -82,6 +84,4 @@ python server.py # Flash Attention It seems that we can apply a monkey patch for llama model. To use it, simply download the file from [MonkeyPatch](https://github.com/lm-sys/FastChat/blob/daa9c11080ceced2bd52c3e0027e4f64b1512683/fastchat/train/llama_flash_attn_monkey_patch.py). And also, flash-attention is needed, and currently do not support pytorch 2.0. -``` -pip install flash-attn -``` +Just add --flash-attention to use it for finetuning.