merge pull request in new branch

John Smith 2023-04-07 10:40:24 +08:00
commit 9351f49542
8 changed files with 476 additions and 228 deletions

View File

@@ -97,5 +97,6 @@ class Finetune4bConfig:
             f"{self.warmup_steps=}\n{self.save_steps=}\n{self.save_total_limit=}\n" +\
             f"{self.logging_steps=}\n" +\
             f"{self.checkpoint=}\n{self.skip=}\n" +\
-            f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}"
+            f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}\n" +\
+            f"{self.groupsize=}\n"
         return s.replace("self.", "")

View File

@@ -35,10 +35,20 @@ pip install -r requirements.txt
 ~The same finetune script from https://github.com/tloen/alpaca-lora can be used.~<br>
 After installation, this script can be used:
+GPTQv1:
 ```
 python finetune.py
 ```
+or
+```
+GPTQ_VERSION=1 python finetune.py
+```
+GPTQv2:
+```
+GPTQ_VERSION=2 python finetune.py
+```
 # Inference

autograd_4bit/__init__.py (new file, 21 lines)
View File

@@ -0,0 +1,21 @@
import os
from colorama import init, Fore, Back, Style

init(autoreset=True)

try:
    GPTQ_VERSION = int(os.environ["GPTQ_VERSION"])
except:
    print(Style.BRIGHT + Fore.YELLOW + "GPTQ_VERSION environment not provided. Fallback to GPTQv1")
    GPTQ_VERSION = 1  # Fallback version

loader = None

if GPTQ_VERSION == 1:
    from .autograd_4bit_v1 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram
    print(Style.BRIGHT + Fore.GREEN + "GPTQv1 set")
elif GPTQ_VERSION == 2:
    from .autograd_4bit_v2 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram
    print(Style.BRIGHT + Fore.GREEN + "GPTQv2 set")
else:
    raise ValueError("GPTQ_VERSION not set or invalid")
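
A minimal usage sketch of the selector above (the package name `autograd_4bit` comes from this new `__init__.py`; everything else is illustrative): the version only has to be chosen before the first import, so it can also be set from Python rather than on the shell command line.

```python
# Hedged sketch: pick GPTQv2 from Python before the package is imported.
import os

os.environ["GPTQ_VERSION"] = "2"  # must be set before the import below runs __init__.py

from autograd_4bit import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram

print(Autograd4bitQuantLinear, load_llama_model_4bit_low_ram)  # v2 implementations
```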

View File

@@ -2,7 +2,6 @@ import matmul_utils_4bit as mm4b
 import torch
 import torch.nn as nn
 import time
-import math
 
 
 class AutogradMatmul4bit(torch.autograd.Function):
@@ -32,25 +31,17 @@ class AutogradMatmul4bit(torch.autograd.Function):
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
 class Autograd4bitQuantLinear(nn.Module):
 
-    def __init__(self, infeatures, outfeatures, groupsize=-1):
+    def __init__(self, in_features, out_features, groupsize=None):
         super().__init__()
         bits = 4
-        self.in_features = infeatures
-        self.out_features = outfeatures
+        self.in_features = in_features
+        self.out_features = out_features
         self.bits = bits
-        self.groupsize = groupsize
-        if groupsize == -1:
-            self.register_buffer('zeros', torch.empty((outfeatures, 1)))
-            self.register_buffer('scales', torch.empty((outfeatures, 1)))
-        else:
-            self.register_buffer('qzeros',
-                torch.empty((math.ceil(infeatures/groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
-            )
-            self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize), outfeatures)))
-            self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32))
-        self.register_buffer('bias', torch.empty(outfeatures))
+        self.register_buffer('zeros', torch.empty((out_features, 1)))
+        self.register_buffer('scales', torch.empty((out_features, 1)))
+        self.bias = nn.Parameter(torch.empty(out_features))
         self.register_buffer(
-            'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
+            'qweight', torch.empty((in_features // 256 * (bits * 8), out_features), dtype=torch.int)
         )
@@ -84,8 +75,7 @@ def model_to_half(model):
     model.half()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            if m.groupsize == -1:
-                m.zeros = m.zeros.half()
+            m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
     print('Converted as Half.')
@@ -95,8 +85,7 @@ def model_to_float(model):
     model.float()
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear):
-            if m.groupsize == -1:
-                m.zeros = m.zeros.float()
+            m.zeros = m.zeros.float()
             m.scales = m.scales.float()
             m.bias = m.bias.float()
     print('Converted as Float.')
@@ -187,8 +176,7 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lo
     print('Apply half ...')
     for n, m in model.named_modules():
         if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
-            if m.groupsize == -1:
-                m.zeros = m.zeros.half()
+            m.zeros = m.zeros.half()
             m.scales = m.scales.half()
             m.bias = m.bias.half()
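
With the groupsize branch removed, the v1 layer always carries per-output-channel `zeros`/`scales` buffers and a plain `bias` parameter. A hedged sketch of the resulting layout (the import path is assumed from the new `__init__.py`, the repo's CUDA matmul extension must be installed, and the 4096x4096 shape is only an illustration):

```python
import torch
from autograd_4bit.autograd_4bit_v1 import Autograd4bitQuantLinear  # path assumed from __init__.py

layer = Autograd4bitQuantLinear(in_features=4096, out_features=4096)
print(layer.zeros.shape, layer.scales.shape)     # torch.Size([4096, 1]) torch.Size([4096, 1])
print(layer.qweight.shape, layer.qweight.dtype)  # torch.Size([512, 4096]) torch.int32
```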

View File

@@ -0,0 +1,221 @@
from colorama import init, Fore, Back, Style
import torch
import torch.nn as nn
import time
import math
import triton
from triton_utils import matmul_248_kernel, trans_matmul_248_kernel


class AutogradMatmul4bit(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
        grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
        matmul_248_kernel[grid](input, qweight, output,
                                scales, qzeros, g_idx,
                                input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
                                input.stride(0), input.stride(1),
                                qweight.stride(0), qweight.stride(1),
                                output.stride(0), output.stride(1),
                                scales.stride(0), qzeros.stride(0))
        ctx.save_for_backward(qweight, scales, qzeros, g_idx)
        ctx.input_shape, ctx.bits, ctx.maxq = input.shape, bits, maxq
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_shape, bits, maxq = ctx.input_shape, ctx.bits, ctx.maxq
        qweight, scales, qzeros, g_idx = ctx.saved_tensors
        grade_input = None
        if ctx.needs_input_grad[0]:
            grade_input = torch.empty((input_shape[0], input_shape[1]), device='cuda', dtype=torch.float32)
            grid = lambda META: (triton.cdiv(input_shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(input_shape[1], META['BLOCK_SIZE_K']),)
            trans_matmul_248_kernel[grid](grad_output, qweight, grade_input,
                                          scales, qzeros, g_idx,
                                          input_shape[0], qweight.shape[1], input_shape[1], bits, maxq,
                                          grad_output.stride(0), grad_output.stride(1),
                                          qweight.stride(0), qweight.stride(1),
                                          grade_input.stride(0), grade_input.stride(1),
                                          scales.stride(0), qzeros.stride(0))
        return grade_input, None, None, None, None, None, None


class Autograd4bitQuantLinear(nn.Module):

    def __init__(self, in_features, out_features, groupsize, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bits = 4  # Hardcoded 4-bits quantizations
        self.maxq = 2 ** self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else in_features
        self.register_buffer('qweight', torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32))
        self.register_buffer('qzeros', torch.zeros((math.ceil(in_features / self.groupsize), out_features // 32 * self.bits), dtype=torch.int32))
        self.register_buffer('scales', torch.zeros((math.ceil(in_features / self.groupsize), out_features), dtype=torch.float16))
        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(in_features)], dtype=torch.int32))
        if bias:
            self.register_buffer('bias', torch.zeros(out_features, dtype=torch.float16))
        else:
            self.bias = None

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features, )
        out = AutogradMatmul4bit.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales,
                                       self.qzeros, self.g_idx, self.bits, self.maxq)
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)


def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
    if isinstance(module, Autograd4bitQuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            setattr(
                module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize)
            )
    for name1, child in module.named_children():
        make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize)


def model_to_half(model):
    model.half()
    for n, m in model.named_modules():
        if isinstance(m, Autograd4bitQuantLinear):
            m.qzeros = m.qzeros.half()
            m.scales = m.scales.half()
            m.bias = m.bias.half()
    print(Style.BRIGHT + Fore.YELLOW + 'Converted as Half.')


def model_to_float(model):
    model.float()
    for n, m in model.named_modules():
        if isinstance(m, Autograd4bitQuantLinear):
            m.qzeros = m.qzeros.float()
            m.scales = m.scales.float()
            m.bias = m.bias.float()
    print(Style.BRIGHT + Fore.YELLOW + 'Converted as Float.')


def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res


def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
    import accelerate
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

    print(Style.BRIGHT + Fore.CYAN + "Loading Model ...")
    t0 = time.time()

    with accelerate.init_empty_weights():
        config = LlamaConfig.from_pretrained(config_path)
        model = LlamaForCausalLM(config)
        model = model.eval()
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
    model = accelerate.load_checkpoint_and_dispatch(
        model=model,
        checkpoint=model_path,
        device_map=device_map,
        no_split_module_classes=["LlamaDecoderLayer"]
    )

    model.seqlen = seqlen

    if half:
        model_to_half(model)

    tokenizer = LlamaTokenizer.from_pretrained(config_path)
    tokenizer.truncation_side = 'left'

    print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.")

    return model, tokenizer


def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
    import accelerate
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

    if max_memory is None:
        max_memory = {0: '24Gib', 'cpu': '48Gib'}

    print(Style.BRIGHT + Fore.CYAN + "Loading Model ...")
    t0 = time.time()

    with accelerate.init_empty_weights():
        config = LlamaConfig.from_pretrained(config_path)
        model = LlamaForCausalLM(config)
        model = model.eval()
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
    accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'})

    # rotary_emb fix
    for n, m in model.named_modules():
        if 'rotary_emb' in n:
            cos_cached = m.cos_cached.clone().cpu()
            sin_cached = m.sin_cached.clone().cpu()
            break

    if lora_path is not None:
        from peft import PeftModel
        from peft.tuners.lora import Linear4bitLt
        model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32)
        print(Style.BRIGHT + Fore.GREEN + '{} Lora Applied.'.format(lora_path))

    model.seqlen = seqlen

    print('Apply half ...')
    for n, m in model.named_modules():
        if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
            m.qzeros = m.qzeros.half()
            m.scales = m.scales.half()
            m.bias = m.bias.half()

    print('Dispatching model ...')
    device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
    model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0)
    torch.cuda.empty_cache()
    print(Style.BRIGHT + Fore.YELLOW + 'Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024))

    # rotary_emb fix
    for n, m in model.named_modules():
        if 'rotary_emb' in n:
            if getattr(m, '_hf_hook', None):
                if isinstance(m._hf_hook, accelerate.hooks.SequentialHook):
                    hooks = m._hf_hook.hooks
                else:
                    hooks = [m._hf_hook]
                for hook in hooks:
                    if hook.offload:
                        if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys():
                            hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu()
                            hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu()

    tokenizer = LlamaTokenizer.from_pretrained(config_path)
    tokenizer.truncation_side = 'left'

    print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.")

    return model, tokenizer
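
An end-to-end sketch of the v2 loader above, hedged: the paths are placeholders, a CUDA GPU is assumed, and the checkpoint must be a GPTQv2-quantized LLaMA model matching the chosen groupsize.

```python
import os
import torch

os.environ["GPTQ_VERSION"] = "2"  # route the package __init__ to the v2 implementation

from autograd_4bit import load_llama_model_4bit_low_ram

# Placeholder paths: a HF LLaMA config/tokenizer directory and a GPTQv2 4-bit checkpoint.
config_path = "./llama-7b-hf"
model_path = "./llama-7b-4bit-128g.safetensors"

model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=128, half=True)

batch = tokenizer("The quick brown fox", return_tensors="pt").to("cuda:0")
with torch.no_grad():
    out = model.generate(**batch, max_new_tokens=20)
print(tokenizer.decode(out[0]))
```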

View File

@@ -115,6 +115,7 @@ if not ft_config.skip:
         per_device_train_batch_size=ft_config.mbatch_size,
         gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
         warmup_steps=ft_config.warmup_steps,
+        optim="adamw_torch",
         num_train_epochs=ft_config.epochs,
         learning_rate=ft_config.lr,
         fp16=True,
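
For reference, `optim` is a standard `transformers.TrainingArguments` field. A standalone sketch of the same choice (all values other than `optim` are illustrative, not taken from this repo's config):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    optim="adamw_torch",  # explicitly select the PyTorch AdamW implementation
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,
)
print(args.optim)
```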

View File

@@ -5,6 +5,7 @@ datasets
 sentencepiece
 safetensors
 flash-attn
+triton
 git+https://github.com/huggingface/transformers.git
 git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
 git+https://github.com/sterlind/peft.git

View File

@ -1,205 +1,210 @@
import triton import triton
import triton.language as tl import triton.language as tl
import torch import torch
# code based https://github.com/fpgaminer/GPTQ-triton # code based https://github.com/fpgaminer/GPTQ-triton
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
], ],
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
) )
@triton.jit @triton.jit
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr, scales_ptr, zeros_ptr, g_ptr,
M, N, K, bits, maxq, M, N, K, bits, maxq,
stride_am, stride_ak, stride_am, stride_ak,
stride_bk, stride_bn, stride_bk, stride_bn,
stride_cm, stride_cn, stride_cm, stride_cn,
stride_scales, stride_zeros, stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr): GROUP_SIZE_M: tl.constexpr):
""" """
Compute the matrix multiplication C = A x B. Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16 A is of shape (M, K) float16
B is of shape (K//8, N) int32 B is of shape (K//8, N) int32
C is of shape (M, N) float16 C is of shape (M, N) float16
scales is of shape (G, N) float16 scales is of shape (G, N) float16
zeros is of shape (G, N) float16 zeros is of shape (G, N) float16
g_ptr is of shape (K) int32 g_ptr is of shape (K) int32
""" """
infearure_per_bits = 32 // bits infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m) pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M) a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times # b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B # shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :] scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits) zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k): for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs) g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values # Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift b = (b - zeros) * scales # Scale and shift
# ! Convert to fp16
accumulator += tl.dot(a, b) b = b.to(tl.float16)
a_ptrs += BLOCK_SIZE_K a = a.to(tl.float16)
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
c = accumulator.to(tl.float16) b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] g_ptrs += BLOCK_SIZE_K
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask) c = accumulator.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
# code based https://github.com/fpgaminer/GPTQ-triton c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
@triton.autotune( tl.store(c_ptrs, c, mask=c_mask)
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), # code based https://github.com/fpgaminer/GPTQ-triton
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), @triton.autotune(
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), configs=[
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
], triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
key=['M', 'N', 'K'], triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
) triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
],
@triton.jit key=['M', 'N', 'K'],
def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, )
scales_ptr, zeros_ptr, g_ptr,
M, N, K, bits, maxq, @triton.jit
stride_am, stride_ak, def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
stride_bk, stride_bn, scales_ptr, zeros_ptr, g_ptr,
stride_cm, stride_cn, M, N, K, bits, maxq,
stride_scales, stride_zeros, stride_am, stride_ak,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, stride_bk, stride_bn,
GROUP_SIZE_M: tl.constexpr): stride_cm, stride_cn,
""" stride_scales, stride_zeros,
Compute the matrix multiplication C = A x B. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
A is of shape (M, N) float16 GROUP_SIZE_M: tl.constexpr):
B is of shape (K//8, N) int32 """
C is of shape (M, K) float16 Compute the matrix multiplication C = A x B.
scales is of shape (G, N) float16 A is of shape (M, N) float16
zeros is of shape (G, N) float16 B is of shape (K//8, N) int32
g_ptr is of shape (K) int32 C is of shape (M, K) float16
""" scales is of shape (G, N) float16
infearure_per_bits = 32 // bits zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
pid = tl.program_id(axis=0) """
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) infearure_per_bits = 32 // bits
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid = tl.program_id(axis=0)
num_pid_in_group = GROUP_SIZE_M * num_pid_k num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
group_id = pid // num_pid_in_group num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
first_pid_m = group_id * GROUP_SIZE_M num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) num_pid_in_group = GROUP_SIZE_M * num_pid_k
pid_m = first_pid_m + (pid % group_size_m) group_id = pid // num_pid_in_group
pid_k = (pid % num_pid_in_group) // group_size_m first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) pid_m = first_pid_m + (pid % group_size_m)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) pid_k = (pid % num_pid_in_group) // group_size_m
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
a_mask = (offs_am[:, None] < M) offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
# b_ptrs is set up such that it repeats elements along the K axis 8 times offs_n = tl.arange(0, BLOCK_SIZE_N)
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk a_mask = (offs_am[:, None] < M)
g_idx = tl.load(g_ptrs) # b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
# shifter is used to extract the N bits of each element in the 32-bit word from B g_ptrs = g_ptr + offs_bk
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales g_idx = tl.load(g_ptrs)
zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros
# shifter is used to extract the N bits of each element in the 32-bit word from B
shifter = (offs_bk % infearure_per_bits) * bits scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_shifter = (offs_n % infearure_per_bits) * bits zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
shifter = (offs_bk % infearure_per_bits) * bits
for k in range(0, num_pid_n): zeros_shifter = (offs_n % infearure_per_bits) * bits
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) for k in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
zeros = (zeros >> zeros_shifter[None, :]) & maxq scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros + 1) zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) zeros = (zeros >> zeros_shifter[None, :]) & maxq
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated zeros = (zeros + 1)
# Now we need to unpack b (which is N-bit values) into 32-bit values a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
b = (b - zeros) * scales # Scale and shift
b = tl.trans(b) # Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
accumulator += tl.dot(a, b) b = (b - zeros) * scales # Scale and shift
a_ptrs += BLOCK_SIZE_N b = tl.trans(b)
b_ptrs += BLOCK_SIZE_N # ! Convert to fp16
scales_ptrs += BLOCK_SIZE_N b = b.to(tl.float16)
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) a = a.to(tl.float16)
c = accumulator.to(tl.float16) accumulator += tl.dot(a, b)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] a_ptrs += BLOCK_SIZE_N
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) b_ptrs += BLOCK_SIZE_N
tl.store(c_ptrs, accumulator, mask=c_mask) scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): c = accumulator.to(tl.float16)
output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16) c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
matmul_248_kernel[grid](input, qweight, output, tl.store(c_ptrs, c, mask=c_mask)
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
input.stride(0), input.stride(1), def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
qweight.stride(0), qweight.stride(1), output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
output.stride(0), output.stride(1), grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
scales.stride(0), qzeros.stride(0)) matmul_248_kernel[grid](input, qweight, output,
return output scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0))
return output
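
A hedged shape-layout sketch for `triton_matmul` above: the packed tensors are random, so the numerical result is meaningless; it only illustrates the layouts the kernel docstrings describe, and assumes a CUDA device with Triton installed.

```python
import torch
from triton_utils import triton_matmul

M, K, N, bits, groupsize = 8, 4096, 4096, 4, 128
maxq = 2 ** bits - 1
G = K // groupsize  # number of quantization groups

x = torch.randn(M, K, dtype=torch.float16, device='cuda')                                      # A: (M, K) fp16
qweight = torch.randint(0, 2**31 - 1, (K // 32 * bits, N), dtype=torch.int32, device='cuda')   # packed 4-bit weights
qzeros = torch.randint(0, 2**31 - 1, (G, N // 32 * bits), dtype=torch.int32, device='cuda')    # packed zero points
scales = torch.rand(G, N, dtype=torch.float16, device='cuda')
g_idx = torch.arange(K, device='cuda', dtype=torch.int32) // groupsize

out = triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
print(out.shape, out.dtype)  # torch.Size([8, 4096]) torch.float16
```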