add triton backend support for v2 model

This commit is contained in:
John Smith 2023-04-07 15:34:06 +08:00
parent 9351f49542
commit dba3773b30
10 changed files with 536 additions and 479 deletions

4
.gitignore vendored
View File

@ -2,4 +2,6 @@ alpaca_lora/
repository/
__pycache__/
llama-13b-4bit
llama-13b-4bit.pt
llama-13b-4bit.pt
text-generation-webui/
repository/

View File

@ -15,7 +15,7 @@ class Finetune4bConfig:
warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
checkpoint: bool, skip: bool, verbose: bool,
txt_row_thd: int, use_eos_token: bool, groupsize: int,
local_rank: int, flash_attention: bool
local_rank: int, flash_attention: bool, backend: str
):
"""
Args:
@ -86,6 +86,7 @@ class Finetune4bConfig:
self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
self.groupsize = groupsize
self.flash_attention = flash_attention
self.backend = backend
def __str__(self) -> str:
@ -98,5 +99,5 @@ class Finetune4bConfig:
f"{self.logging_steps=}\n" +\
f"{self.checkpoint=}\n{self.skip=}\n" +\
f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}\n" +\
f"{self.groupsize=}\n"
f"{self.groupsize=}\n{self.backend=}\n"
return s.replace("self.", "")

View File

@ -67,7 +67,11 @@ def parse_commandline():
# Multi GPU Support
parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch")
parser_training.add_argument("--flash_attention", help="enables flash attention, can improve performance and reduce VRAM use")
# Flash Attention
parser_training.add_argument("--flash_attention", action="store_true", help="enables flash attention, can improve performance and reduce VRAM use")
# Train Backend
parser_training.add_argument("--backend", type=str, default='cuda', help="Backend to use. Triton or Cuda.")
return vars(parser.parse_args())
@ -105,4 +109,5 @@ def get_config() -> Finetune4bConfig:
groupsize=args["groupsize"],
local_rank=args["local_rank"],
flash_attention=args["flash_attention"],
backend=args["backend"],
)

View File

@ -2,12 +2,15 @@ import matmul_utils_4bit as mm4b
import torch
import torch.nn as nn
import time
import math
from torch.cuda.amp import custom_bwd, custom_fwd
class AutogradMatmul4bit(torch.autograd.Function):
class AutogradMatmul4bitCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq, groupsize=-1):
ctx.save_for_backward(qweight, scales, zeros)
ctx.groupsize = groupsize
if groupsize == -1:
@ -18,42 +21,116 @@ class AutogradMatmul4bit(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
qweight, scales, zeros = ctx.saved_tensors
groupsize = ctx.groupsize
if groupsize == -1:
grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
else:
grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True)
return grad, None, None, None, None
if ctx.needs_input_grad[0]:
if groupsize == -1:
grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
else:
grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True)
return grad, None, None, None, None, None, None, None
try:
import triton_utils as tu
class AutogradMatmul4bitTriton(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize=-1):
output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.bits, ctx.maxq = bits, maxq
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
qweight, scales, qzeros, g_idx = ctx.saved_tensors
bits, maxq = ctx.bits, ctx.maxq
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = tu.triton_matmul_transpose(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
return grad_input, None, None, None, None, None, None, None
except ImportError:
print('Triton not found. Please run "pip install triton".')
AutogradMatmul4bit = AutogradMatmul4bitCuda
backend = 'cuda'
def switch_backend_to(to_backend):
global AutogradMatmul4bit
global backend
if to_backend == 'cuda':
AutogradMatmul4bit = AutogradMatmul4bitCuda
backend = 'cuda'
print('Using CUDA implementation.')
elif to_backend == 'triton':
# detect if AutogradMatmul4bitTriton is defined
if 'AutogradMatmul4bitTriton' not in globals():
raise ValueError('Triton not found. Please install triton_utils.')
AutogradMatmul4bit = AutogradMatmul4bitTriton
backend = 'triton'
print('Using Triton implementation.')
else:
raise ValueError('Backend not supported.')
def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize):
if backend == 'cuda':
return mm4b.matmul4bit(x, qweight, scales, qzeros, groupsize)
elif backend == 'triton':
assert qzeros.dtype == torch.int32
return tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
else:
raise ValueError('Backend not supported.')
# Assumes layer is perfectly divisible into 256 * 256 blocks
class Autograd4bitQuantLinear(nn.Module):
def __init__(self, in_features, out_features, groupsize=None):
def __init__(self, in_features, out_features, groupsize=-1):
super().__init__()
bits = 4
self.in_features = in_features
self.out_features = out_features
self.bits = bits
self.register_buffer('zeros', torch.empty((out_features, 1)))
self.register_buffer('scales', torch.empty((out_features, 1)))
self.bias = nn.Parameter(torch.empty(out_features))
self.maxq = 2 ** self.bits - 1
self.groupsize = groupsize
if groupsize == -1:
self.register_buffer('zeros', torch.empty((out_features, 1)))
self.register_buffer('scales', torch.empty((out_features, 1)))
else:
self.register_buffer('qzeros',
torch.empty((math.ceil(in_features/groupsize), out_features // 256 * (bits * 8)), dtype=torch.int32)
)
self.register_buffer('scales', torch.empty((math.ceil(in_features/groupsize), out_features)))
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(in_features)], dtype = torch.int32))
self.register_buffer('bias', torch.empty(out_features))
self.register_buffer(
'qweight', torch.empty((in_features // 256 * (bits * 8), out_features), dtype=torch.int)
'qweight', torch.empty((in_features // 256 * (bits * 8), out_features), dtype=torch.int32)
)
def forward(self, x):
if torch.is_grad_enabled():
out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
out += self.bias
self.qzeros if self.groupsize != -1 else self.zeros,
self.g_idx, self.bits, self.maxq,
self.groupsize)
else:
out = mm4b.matmul4bit(x, self.qweight, self.scales,
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
out += self.bias
out = matmul4bit_with_backend(x, self.qweight, self.scales,
self.qzeros if self.groupsize != -1 else self.zeros,
self.g_idx, self.bits, self.maxq,
self.groupsize)
out += self.bias
return out
@ -75,7 +152,8 @@ def model_to_half(model):
model.half()
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear):
m.zeros = m.zeros.half()
if m.groupsize == -1:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
print('Converted as Half.')
@ -85,7 +163,8 @@ def model_to_float(model):
model.float()
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear):
m.zeros = m.zeros.float()
if m.groupsize == -1:
m.zeros = m.zeros.float()
m.scales = m.scales.float()
m.bias = m.bias.float()
print('Converted as Float.')
@ -137,7 +216,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa
return model, tokenizer
def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
import accelerate
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
@ -176,7 +255,8 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lo
print('Apply half ...')
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
m.zeros = m.zeros.half()
if m.groupsize == -1:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
@ -206,3 +286,5 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lo
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer
load_llama_model_4bit_low_ram_and_offload_to_cpu = load_llama_model_4bit_low_ram_and_offload

View File

@ -1,21 +0,0 @@
import os
from colorama import init, Fore, Back, Style
init(autoreset=True)
try:
GPTQ_VERSION = int(os.environ["GPTQ_VERSION"])
except:
print(Style.BRIGHT + Fore.YELLOW + "GPTQ_VERSION environment not provided. Fallback to GPTQv1")
GPTQ_VERSION = 1 # Fallback version
loader = None
if GPTQ_VERSION == 1:
from .autograd_4bit_v1 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram
print(Style.BRIGHT + Fore.GREEN + "GPTQv1 set")
elif GPTQ_VERSION == 2:
from .autograd_4bit_v2 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram
print(Style.BRIGHT + Fore.GREEN + "GPTQv2 set")
else:
raise ValueError("GPTQ_VERSION not set or invalid")

View File

@ -1,221 +0,0 @@
from colorama import init, Fore, Back, Style
import torch
import torch.nn as nn
import time
import math
import triton
from triton_utils import matmul_248_kernel, trans_matmul_248_kernel
class AutogradMatmul4bit(torch.autograd.Function):
@staticmethod
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
matmul_248_kernel[grid](input, qweight, output,
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0))
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.input_shape, ctx.bits,ctx.maxq = input.shape,bits, maxq
return output
@staticmethod
def backward(ctx, grad_output):
input_shape, bits, maxq = ctx.input_shape, ctx.bits, ctx.maxq
qweight, scales, qzeros, g_idx = ctx.saved_tensors
grade_input = None
if ctx.needs_input_grad[0]:
grade_input = torch.empty((input_shape[0], input_shape[1]), device='cuda', dtype=torch.float32)
grid = lambda META: (triton.cdiv(input_shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(input_shape[1], META['BLOCK_SIZE_K']),)
trans_matmul_248_kernel[grid](grad_output, qweight, grade_input,
scales, qzeros, g_idx,
input_shape[0], qweight.shape[1], input_shape[1], bits, maxq,
grad_output.stride(0), grad_output.stride(1),
qweight.stride(0), qweight.stride(1),
grade_input.stride(0), grade_input.stride(1),
scales.stride(0), qzeros.stride(0))
return grade_input, None, None, None, None, None, None
class Autograd4bitQuantLinear(nn.Module):
def __init__(self, in_features, out_features, groupsize, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.bits = 4 # Hardcoded 4-bits quantizations
self.maxq = 2 ** self.bits - 1
self.groupsize = groupsize if groupsize != -1 else in_features
self.register_buffer('qweight', torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32))
self.register_buffer('qzeros', torch.zeros((math.ceil(in_features / self.groupsize), out_features // 32 * self.bits), dtype=torch.int32))
self.register_buffer('scales', torch.zeros((math.ceil(in_features / self.groupsize), out_features), dtype=torch.float16))
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(in_features)], dtype = torch.int32))
if bias:
self.register_buffer('bias', torch.zeros(out_features,dtype=torch.float16))
else:
self.bias = None
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out = AutogradMatmul4bit.apply(x.reshape(-1,x.shape[-1]), self.qweight, self.scales,
self.qzeros, self.g_idx, self.bits, self.maxq)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
if isinstance(module, Autograd4bitQuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr
if name1 in names:
setattr(
module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize)
)
for name1, child in module.named_children():
make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize)
def model_to_half(model):
model.half()
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear):
m.qzeros = m.qzeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
print(Style.BRIGHT + Fore.YELLOW + 'Converted as Half.')
def model_to_float(model):
model.float()
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear):
m.qzeros = m.qzeros.float()
m.scales = m.scales.float()
m.bias = m.bias.float()
print(Style.BRIGHT + Fore.YELLOW + 'Converted as Float.')
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(
child, layers=layers, name=name + '.' + name1 if name != '' else name1
))
return res
def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
import accelerate
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
print(Style.BRIGHT + Fore.CYAN + "Loading Model ...")
t0 = time.time()
with accelerate.init_empty_weights():
config = LlamaConfig.from_pretrained(config_path)
model = LlamaForCausalLM(config)
model = model.eval()
layers = find_layers(model)
for name in ['lm_head']:
if name in layers:
del layers[name]
make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
model = accelerate.load_checkpoint_and_dispatch(
model=model,
checkpoint=model_path,
device_map=device_map,
no_split_module_classes=["LlamaDecoderLayer"]
)
model.seqlen = seqlen
if half:
model_to_half(model)
tokenizer = LlamaTokenizer.from_pretrained(config_path)
tokenizer.truncation_side = 'left'
print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer
def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
import accelerate
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
if max_memory is None:
max_memory = {0: '24Gib', 'cpu': '48Gib'}
print(Style.BRIGHT + Fore.CYAN + "Loading Model ...")
t0 = time.time()
with accelerate.init_empty_weights():
config = LlamaConfig.from_pretrained(config_path)
model = LlamaForCausalLM(config)
model = model.eval()
layers = find_layers(model)
for name in ['lm_head']:
if name in layers:
del layers[name]
make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'})
# rotary_emb fix
for n, m in model.named_modules():
if 'rotary_emb' in n:
cos_cached = m.cos_cached.clone().cpu()
sin_cached = m.sin_cached.clone().cpu()
break
if lora_path is not None:
from peft import PeftModel
from peft.tuners.lora import Linear4bitLt
model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32)
print(Style.BRIGHT + Fore.GREEN + '{} Lora Applied.'.format(lora_path))
model.seqlen = seqlen
print('Apply half ...')
for n, m in model.named_modules():
if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
m.qzeros = m.qzeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
print('Dispatching model ...')
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0)
torch.cuda.empty_cache()
print(Style.BRIGHT + Fore.YELLOW + 'Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024))
# rotary_emb fix
for n, m in model.named_modules():
if 'rotary_emb' in n:
if getattr(m, '_hf_hook', None):
if isinstance(m._hf_hook, accelerate.hooks.SequentialHook):
hooks = m._hf_hook.hooks
else:
hooks = [m._hf_hook]
for hook in hooks:
if hook.offload:
if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys():
hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu()
hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu()
tokenizer = LlamaTokenizer.from_pretrained(config_path)
tokenizer.truncation_side = 'left'
print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer

167
custom_autotune.py Normal file
View File

@ -0,0 +1,167 @@
#https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""
import builtins
import math
import time
from typing import Dict
import triton
class Autotuner(triton.KernelInterface):
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
'''
if not configs:
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.nearest_power_of_two = nearest_power_of_two
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
try:
# In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
return triton.testing.do_bench(kernel_call, rep=40)
except triton.compiler.OutOfResources:
return float('inf')
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple(args[i] for i in self.key_idx)
# This reduces the amount of autotuning by rounding the keys to the nearest power of two
# In my testing this gives decent results, and greatly reduces the amount of tuning required
if self.nearest_power_of_two:
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
def prune_configs(self, kwargs):
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
num_warps=config.num_warps)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)
return decorator

View File

@ -24,6 +24,12 @@ if ft_config.flash_attention:
from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
import autograd_4bit
if ft_config.backend.lower() == 'triton':
autograd_4bit.switch_backend_to('triton')
else:
autograd_4bit.switch_backend_to('cuda')
import sys
import peft
@ -65,10 +71,16 @@ lora_config = LoraConfig(
if ft_config.lora_apply_dir is None:
model = get_peft_model(model, lora_config)
else:
device_map = ft_config.device_map
if ft_config.ddp:
model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map="auto", torch_dtype=torch.float32) # ! Direct copy from inference.py
device_map = {'': 0}
else:
model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32)
if torch.cuda.device_count() > 1:
device_map = "auto"
else:
device_map = {'': 0}
print('Device map for lora:', device_map)
model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map=device_map, torch_dtype=torch.float32)
print(ft_config.lora_apply_dir, 'loaded')

View File

@ -6,6 +6,7 @@ sentencepiece
safetensors
flash-attn
triton
colorama
git+https://github.com/huggingface/transformers.git
git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
git+https://github.com/sterlind/peft.git

View File

@ -1,210 +1,239 @@
import triton
import triton.language as tl
import torch
# code based https://github.com/fpgaminer/GPTQ-triton
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr,
M, N, K, bits, maxq,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
# ! Convert to fp16
b = b.to(tl.float16)
a = a.to(tl.float16)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c = accumulator.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# code based https://github.com/fpgaminer/GPTQ-triton
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
@triton.jit
def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr,
M, N, K, bits, maxq,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16
B is of shape (K//8, N) int32
C is of shape (M, K) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_k
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_k = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for k in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
b = tl.trans(b)
# ! Convert to fp16
b = b.to(tl.float16)
a = a.to(tl.float16)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
c = accumulator.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, c, mask=c_mask)
def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
matmul_248_kernel[grid](input, qweight, output,
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0))
return output
import triton
import triton.language as tl
import torch
import custom_autotune
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# These provided a benefit on a 3090
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
],
key=['M', 'N'],
nearest_power_of_two=True,
)
@triton.jit
def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr,
M, N, K, bits, maxq,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
# ! Convert to fp16
b = b.to(tl.float16)
a = a.to(tl.float16)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c = accumulator.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
# These provided a benefit on a 3090
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
],
key=['M', 'K'],
nearest_power_of_two=True,
)
@triton.jit
def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr,
M, N, K, bits, maxq,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16
B is of shape (K//8, N) int32
C is of shape (M, K) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_k
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_k = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for k in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
b = tl.trans(b)
# ! Convert to fp16
b = b.to(tl.float16)
a = a.to(tl.float16)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
c = accumulator.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, c, mask=c_mask)
def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
assert input.shape[1] == qweight.shape[0] * 32 // bits
output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
matmul_248_kernel[grid](input, qweight, output,
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0))
return output
def triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq):
assert input.shape[1] == qweight.shape[1]
output_shape = (input.shape[0], qweight.shape[0] * 32 // bits)
output = torch.empty((output_shape[0], output_shape[1]), device=scales.device, dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape[1], META['BLOCK_SIZE_K']),)
trans_matmul_248_kernel[grid](input, qweight, output,
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], output_shape[1], bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0))
return output