add triton backend support for v2 model
This commit is contained in:
parent
9351f49542
commit
dba3773b30
|
|
@ -2,4 +2,6 @@ alpaca_lora/
|
||||||
repository/
|
repository/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
llama-13b-4bit
|
llama-13b-4bit
|
||||||
llama-13b-4bit.pt
|
llama-13b-4bit.pt
|
||||||
|
text-generation-webui/
|
||||||
|
repository/
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ class Finetune4bConfig:
|
||||||
warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
|
warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
|
||||||
checkpoint: bool, skip: bool, verbose: bool,
|
checkpoint: bool, skip: bool, verbose: bool,
|
||||||
txt_row_thd: int, use_eos_token: bool, groupsize: int,
|
txt_row_thd: int, use_eos_token: bool, groupsize: int,
|
||||||
local_rank: int, flash_attention: bool
|
local_rank: int, flash_attention: bool, backend: str
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -86,6 +86,7 @@ class Finetune4bConfig:
|
||||||
self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
|
self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
|
||||||
self.groupsize = groupsize
|
self.groupsize = groupsize
|
||||||
self.flash_attention = flash_attention
|
self.flash_attention = flash_attention
|
||||||
|
self.backend = backend
|
||||||
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
@ -98,5 +99,5 @@ class Finetune4bConfig:
|
||||||
f"{self.logging_steps=}\n" +\
|
f"{self.logging_steps=}\n" +\
|
||||||
f"{self.checkpoint=}\n{self.skip=}\n" +\
|
f"{self.checkpoint=}\n{self.skip=}\n" +\
|
||||||
f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}\n" +\
|
f"{self.world_size=}\n{self.ddp=}\n{self.device_map=}\n" +\
|
||||||
f"{self.groupsize=}\n"
|
f"{self.groupsize=}\n{self.backend=}\n"
|
||||||
return s.replace("self.", "")
|
return s.replace("self.", "")
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,11 @@ def parse_commandline():
|
||||||
# Multi GPU Support
|
# Multi GPU Support
|
||||||
parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch")
|
parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch")
|
||||||
|
|
||||||
parser_training.add_argument("--flash_attention", help="enables flash attention, can improve performance and reduce VRAM use")
|
# Flash Attention
|
||||||
|
parser_training.add_argument("--flash_attention", action="store_true", help="enables flash attention, can improve performance and reduce VRAM use")
|
||||||
|
|
||||||
|
# Train Backend
|
||||||
|
parser_training.add_argument("--backend", type=str, default='cuda', help="Backend to use. Triton or Cuda.")
|
||||||
|
|
||||||
return vars(parser.parse_args())
|
return vars(parser.parse_args())
|
||||||
|
|
||||||
|
|
@ -105,4 +109,5 @@ def get_config() -> Finetune4bConfig:
|
||||||
groupsize=args["groupsize"],
|
groupsize=args["groupsize"],
|
||||||
local_rank=args["local_rank"],
|
local_rank=args["local_rank"],
|
||||||
flash_attention=args["flash_attention"],
|
flash_attention=args["flash_attention"],
|
||||||
|
backend=args["backend"],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,15 @@ import matmul_utils_4bit as mm4b
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import time
|
import time
|
||||||
|
import math
|
||||||
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||||
|
|
||||||
|
|
||||||
class AutogradMatmul4bit(torch.autograd.Function):
|
class AutogradMatmul4bitCuda(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
|
@custom_fwd(cast_inputs=torch.float16)
|
||||||
|
def forward(ctx, x, qweight, scales, zeros, g_idx, bits, maxq, groupsize=-1):
|
||||||
ctx.save_for_backward(qweight, scales, zeros)
|
ctx.save_for_backward(qweight, scales, zeros)
|
||||||
ctx.groupsize = groupsize
|
ctx.groupsize = groupsize
|
||||||
if groupsize == -1:
|
if groupsize == -1:
|
||||||
|
|
@ -18,42 +21,116 @@ class AutogradMatmul4bit(torch.autograd.Function):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@custom_bwd
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
qweight, scales, zeros = ctx.saved_tensors
|
qweight, scales, zeros = ctx.saved_tensors
|
||||||
groupsize = ctx.groupsize
|
groupsize = ctx.groupsize
|
||||||
if groupsize == -1:
|
if ctx.needs_input_grad[0]:
|
||||||
grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
|
if groupsize == -1:
|
||||||
else:
|
grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
|
||||||
grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True)
|
else:
|
||||||
return grad, None, None, None, None
|
grad = mm4b._matmul4bit_v2_recons(grad_output, qweight, scales, zeros, groupsize=groupsize, transpose=True)
|
||||||
|
return grad, None, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import triton_utils as tu
|
||||||
|
|
||||||
|
class AutogradMatmul4bitTriton(torch.autograd.Function):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@custom_fwd(cast_inputs=torch.float16)
|
||||||
|
def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize=-1):
|
||||||
|
output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||||
|
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
|
||||||
|
ctx.bits, ctx.maxq = bits, maxq
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@custom_bwd
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
qweight, scales, qzeros, g_idx = ctx.saved_tensors
|
||||||
|
bits, maxq = ctx.bits, ctx.maxq
|
||||||
|
grad_input = None
|
||||||
|
|
||||||
|
if ctx.needs_input_grad[0]:
|
||||||
|
grad_input = tu.triton_matmul_transpose(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||||
|
return grad_input, None, None, None, None, None, None, None
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print('Triton not found. Please run "pip install triton".')
|
||||||
|
|
||||||
|
|
||||||
|
AutogradMatmul4bit = AutogradMatmul4bitCuda
|
||||||
|
backend = 'cuda'
|
||||||
|
|
||||||
|
|
||||||
|
def switch_backend_to(to_backend):
|
||||||
|
global AutogradMatmul4bit
|
||||||
|
global backend
|
||||||
|
if to_backend == 'cuda':
|
||||||
|
AutogradMatmul4bit = AutogradMatmul4bitCuda
|
||||||
|
backend = 'cuda'
|
||||||
|
print('Using CUDA implementation.')
|
||||||
|
elif to_backend == 'triton':
|
||||||
|
# detect if AutogradMatmul4bitTriton is defined
|
||||||
|
if 'AutogradMatmul4bitTriton' not in globals():
|
||||||
|
raise ValueError('Triton not found. Please install triton_utils.')
|
||||||
|
AutogradMatmul4bit = AutogradMatmul4bitTriton
|
||||||
|
backend = 'triton'
|
||||||
|
print('Using Triton implementation.')
|
||||||
|
else:
|
||||||
|
raise ValueError('Backend not supported.')
|
||||||
|
|
||||||
|
|
||||||
|
def matmul4bit_with_backend(x, qweight, scales, qzeros, g_idx, bits, maxq, groupsize):
|
||||||
|
if backend == 'cuda':
|
||||||
|
return mm4b.matmul4bit(x, qweight, scales, qzeros, groupsize)
|
||||||
|
elif backend == 'triton':
|
||||||
|
assert qzeros.dtype == torch.int32
|
||||||
|
return tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||||
|
else:
|
||||||
|
raise ValueError('Backend not supported.')
|
||||||
|
|
||||||
|
|
||||||
# Assumes layer is perfectly divisible into 256 * 256 blocks
|
# Assumes layer is perfectly divisible into 256 * 256 blocks
|
||||||
class Autograd4bitQuantLinear(nn.Module):
|
class Autograd4bitQuantLinear(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_features, out_features, groupsize=None):
|
def __init__(self, in_features, out_features, groupsize=-1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
bits = 4
|
bits = 4
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
self.register_buffer('zeros', torch.empty((out_features, 1)))
|
self.maxq = 2 ** self.bits - 1
|
||||||
self.register_buffer('scales', torch.empty((out_features, 1)))
|
self.groupsize = groupsize
|
||||||
self.bias = nn.Parameter(torch.empty(out_features))
|
if groupsize == -1:
|
||||||
|
self.register_buffer('zeros', torch.empty((out_features, 1)))
|
||||||
|
self.register_buffer('scales', torch.empty((out_features, 1)))
|
||||||
|
else:
|
||||||
|
self.register_buffer('qzeros',
|
||||||
|
torch.empty((math.ceil(in_features/groupsize), out_features // 256 * (bits * 8)), dtype=torch.int32)
|
||||||
|
)
|
||||||
|
self.register_buffer('scales', torch.empty((math.ceil(in_features/groupsize), out_features)))
|
||||||
|
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(in_features)], dtype = torch.int32))
|
||||||
|
self.register_buffer('bias', torch.empty(out_features))
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
'qweight', torch.empty((in_features // 256 * (bits * 8), out_features), dtype=torch.int)
|
'qweight', torch.empty((in_features // 256 * (bits * 8), out_features), dtype=torch.int32)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if torch.is_grad_enabled():
|
if torch.is_grad_enabled():
|
||||||
out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
|
out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
|
||||||
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
|
self.qzeros if self.groupsize != -1 else self.zeros,
|
||||||
out += self.bias
|
self.g_idx, self.bits, self.maxq,
|
||||||
|
self.groupsize)
|
||||||
else:
|
else:
|
||||||
out = mm4b.matmul4bit(x, self.qweight, self.scales,
|
out = matmul4bit_with_backend(x, self.qweight, self.scales,
|
||||||
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
|
self.qzeros if self.groupsize != -1 else self.zeros,
|
||||||
out += self.bias
|
self.g_idx, self.bits, self.maxq,
|
||||||
|
self.groupsize)
|
||||||
|
out += self.bias
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -75,7 +152,8 @@ def model_to_half(model):
|
||||||
model.half()
|
model.half()
|
||||||
for n, m in model.named_modules():
|
for n, m in model.named_modules():
|
||||||
if isinstance(m, Autograd4bitQuantLinear):
|
if isinstance(m, Autograd4bitQuantLinear):
|
||||||
m.zeros = m.zeros.half()
|
if m.groupsize == -1:
|
||||||
|
m.zeros = m.zeros.half()
|
||||||
m.scales = m.scales.half()
|
m.scales = m.scales.half()
|
||||||
m.bias = m.bias.half()
|
m.bias = m.bias.half()
|
||||||
print('Converted as Half.')
|
print('Converted as Half.')
|
||||||
|
|
@ -85,7 +163,8 @@ def model_to_float(model):
|
||||||
model.float()
|
model.float()
|
||||||
for n, m in model.named_modules():
|
for n, m in model.named_modules():
|
||||||
if isinstance(m, Autograd4bitQuantLinear):
|
if isinstance(m, Autograd4bitQuantLinear):
|
||||||
m.zeros = m.zeros.float()
|
if m.groupsize == -1:
|
||||||
|
m.zeros = m.zeros.float()
|
||||||
m.scales = m.scales.float()
|
m.scales = m.scales.float()
|
||||||
m.bias = m.bias.float()
|
m.bias = m.bias.float()
|
||||||
print('Converted as Float.')
|
print('Converted as Float.')
|
||||||
|
|
@ -137,7 +216,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=Fa
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
|
def load_llama_model_4bit_low_ram_and_offload(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
|
||||||
import accelerate
|
import accelerate
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
|
|
@ -176,7 +255,8 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lo
|
||||||
print('Apply half ...')
|
print('Apply half ...')
|
||||||
for n, m in model.named_modules():
|
for n, m in model.named_modules():
|
||||||
if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
|
if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
|
||||||
m.zeros = m.zeros.half()
|
if m.groupsize == -1:
|
||||||
|
m.zeros = m.zeros.half()
|
||||||
m.scales = m.scales.half()
|
m.scales = m.scales.half()
|
||||||
m.bias = m.bias.half()
|
m.bias = m.bias.half()
|
||||||
|
|
||||||
|
|
@ -206,3 +286,5 @@ def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lo
|
||||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
load_llama_model_4bit_low_ram_and_offload_to_cpu = load_llama_model_4bit_low_ram_and_offload
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
import os
|
|
||||||
from colorama import init, Fore, Back, Style
|
|
||||||
init(autoreset=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
GPTQ_VERSION = int(os.environ["GPTQ_VERSION"])
|
|
||||||
except:
|
|
||||||
print(Style.BRIGHT + Fore.YELLOW + "GPTQ_VERSION environment not provided. Fallback to GPTQv1")
|
|
||||||
GPTQ_VERSION = 1 # Fallback version
|
|
||||||
|
|
||||||
loader = None
|
|
||||||
|
|
||||||
|
|
||||||
if GPTQ_VERSION == 1:
|
|
||||||
from .autograd_4bit_v1 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram
|
|
||||||
print(Style.BRIGHT + Fore.GREEN + "GPTQv1 set")
|
|
||||||
elif GPTQ_VERSION == 2:
|
|
||||||
from .autograd_4bit_v2 import Autograd4bitQuantLinear, load_llama_model_4bit_low_ram
|
|
||||||
print(Style.BRIGHT + Fore.GREEN + "GPTQv2 set")
|
|
||||||
else:
|
|
||||||
raise ValueError("GPTQ_VERSION not set or invalid")
|
|
||||||
|
|
@ -1,221 +0,0 @@
|
||||||
from colorama import init, Fore, Back, Style
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import time
|
|
||||||
import math
|
|
||||||
import triton
|
|
||||||
from triton_utils import matmul_248_kernel, trans_matmul_248_kernel
|
|
||||||
|
|
||||||
|
|
||||||
class AutogradMatmul4bit(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
|
|
||||||
output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
|
|
||||||
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
|
|
||||||
matmul_248_kernel[grid](input, qweight, output,
|
|
||||||
scales, qzeros, g_idx,
|
|
||||||
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
|
|
||||||
input.stride(0), input.stride(1),
|
|
||||||
qweight.stride(0), qweight.stride(1),
|
|
||||||
output.stride(0), output.stride(1),
|
|
||||||
scales.stride(0), qzeros.stride(0))
|
|
||||||
|
|
||||||
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
|
|
||||||
ctx.input_shape, ctx.bits,ctx.maxq = input.shape,bits, maxq
|
|
||||||
return output
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
input_shape, bits, maxq = ctx.input_shape, ctx.bits, ctx.maxq
|
|
||||||
qweight, scales, qzeros, g_idx = ctx.saved_tensors
|
|
||||||
grade_input = None
|
|
||||||
|
|
||||||
if ctx.needs_input_grad[0]:
|
|
||||||
grade_input = torch.empty((input_shape[0], input_shape[1]), device='cuda', dtype=torch.float32)
|
|
||||||
grid = lambda META: (triton.cdiv(input_shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(input_shape[1], META['BLOCK_SIZE_K']),)
|
|
||||||
trans_matmul_248_kernel[grid](grad_output, qweight, grade_input,
|
|
||||||
scales, qzeros, g_idx,
|
|
||||||
input_shape[0], qweight.shape[1], input_shape[1], bits, maxq,
|
|
||||||
grad_output.stride(0), grad_output.stride(1),
|
|
||||||
qweight.stride(0), qweight.stride(1),
|
|
||||||
grade_input.stride(0), grade_input.stride(1),
|
|
||||||
scales.stride(0), qzeros.stride(0))
|
|
||||||
return grade_input, None, None, None, None, None, None
|
|
||||||
|
|
||||||
|
|
||||||
class Autograd4bitQuantLinear(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_features, out_features, groupsize, bias=True):
|
|
||||||
super().__init__()
|
|
||||||
self.in_features = in_features
|
|
||||||
self.out_features = out_features
|
|
||||||
self.bits = 4 # Hardcoded 4-bits quantizations
|
|
||||||
self.maxq = 2 ** self.bits - 1
|
|
||||||
self.groupsize = groupsize if groupsize != -1 else in_features
|
|
||||||
|
|
||||||
self.register_buffer('qweight', torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32))
|
|
||||||
self.register_buffer('qzeros', torch.zeros((math.ceil(in_features / self.groupsize), out_features // 32 * self.bits), dtype=torch.int32))
|
|
||||||
self.register_buffer('scales', torch.zeros((math.ceil(in_features / self.groupsize), out_features), dtype=torch.float16))
|
|
||||||
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(in_features)], dtype = torch.int32))
|
|
||||||
if bias:
|
|
||||||
self.register_buffer('bias', torch.zeros(out_features,dtype=torch.float16))
|
|
||||||
else:
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out_shape = x.shape[:-1] + (self.out_features, )
|
|
||||||
out = AutogradMatmul4bit.apply(x.reshape(-1,x.shape[-1]), self.qweight, self.scales,
|
|
||||||
self.qzeros, self.g_idx, self.bits, self.maxq)
|
|
||||||
out = out + self.bias if self.bias is not None else out
|
|
||||||
return out.reshape(out_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def make_quant_for_4bit_autograd(module, names, name='', groupsize=-1):
|
|
||||||
if isinstance(module, Autograd4bitQuantLinear):
|
|
||||||
return
|
|
||||||
for attr in dir(module):
|
|
||||||
tmp = getattr(module, attr)
|
|
||||||
name1 = name + '.' + attr if name != '' else attr
|
|
||||||
if name1 in names:
|
|
||||||
setattr(
|
|
||||||
module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features, groupsize=groupsize)
|
|
||||||
)
|
|
||||||
for name1, child in module.named_children():
|
|
||||||
make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1, groupsize=groupsize)
|
|
||||||
|
|
||||||
|
|
||||||
def model_to_half(model):
|
|
||||||
model.half()
|
|
||||||
for n, m in model.named_modules():
|
|
||||||
if isinstance(m, Autograd4bitQuantLinear):
|
|
||||||
m.qzeros = m.qzeros.half()
|
|
||||||
m.scales = m.scales.half()
|
|
||||||
m.bias = m.bias.half()
|
|
||||||
print(Style.BRIGHT + Fore.YELLOW + 'Converted as Half.')
|
|
||||||
|
|
||||||
|
|
||||||
def model_to_float(model):
|
|
||||||
model.float()
|
|
||||||
for n, m in model.named_modules():
|
|
||||||
if isinstance(m, Autograd4bitQuantLinear):
|
|
||||||
m.qzeros = m.qzeros.float()
|
|
||||||
m.scales = m.scales.float()
|
|
||||||
m.bias = m.bias.float()
|
|
||||||
print(Style.BRIGHT + Fore.YELLOW + 'Converted as Float.')
|
|
||||||
|
|
||||||
|
|
||||||
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
|
|
||||||
if type(module) in layers:
|
|
||||||
return {name: module}
|
|
||||||
res = {}
|
|
||||||
for name1, child in module.named_children():
|
|
||||||
res.update(find_layers(
|
|
||||||
child, layers=layers, name=name + '.' + name1 if name != '' else name1
|
|
||||||
))
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def load_llama_model_4bit_low_ram(config_path, model_path, groupsize=-1, half=False, device_map="auto", seqlen=2048):
|
|
||||||
import accelerate
|
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
|
||||||
|
|
||||||
print(Style.BRIGHT + Fore.CYAN + "Loading Model ...")
|
|
||||||
t0 = time.time()
|
|
||||||
|
|
||||||
with accelerate.init_empty_weights():
|
|
||||||
config = LlamaConfig.from_pretrained(config_path)
|
|
||||||
model = LlamaForCausalLM(config)
|
|
||||||
model = model.eval()
|
|
||||||
layers = find_layers(model)
|
|
||||||
for name in ['lm_head']:
|
|
||||||
if name in layers:
|
|
||||||
del layers[name]
|
|
||||||
make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
|
|
||||||
model = accelerate.load_checkpoint_and_dispatch(
|
|
||||||
model=model,
|
|
||||||
checkpoint=model_path,
|
|
||||||
device_map=device_map,
|
|
||||||
no_split_module_classes=["LlamaDecoderLayer"]
|
|
||||||
)
|
|
||||||
|
|
||||||
model.seqlen = seqlen
|
|
||||||
|
|
||||||
if half:
|
|
||||||
model_to_half(model)
|
|
||||||
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(config_path)
|
|
||||||
tokenizer.truncation_side = 'left'
|
|
||||||
|
|
||||||
print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
|
||||||
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
def load_llama_model_4bit_low_ram_and_offload_to_cpu(config_path, model_path, lora_path=None, groupsize=-1, seqlen=2048, max_memory=None):
|
|
||||||
import accelerate
|
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
|
||||||
|
|
||||||
if max_memory is None:
|
|
||||||
max_memory = {0: '24Gib', 'cpu': '48Gib'}
|
|
||||||
|
|
||||||
print(Style.BRIGHT + Fore.CYAN + "Loading Model ...")
|
|
||||||
t0 = time.time()
|
|
||||||
|
|
||||||
with accelerate.init_empty_weights():
|
|
||||||
config = LlamaConfig.from_pretrained(config_path)
|
|
||||||
model = LlamaForCausalLM(config)
|
|
||||||
model = model.eval()
|
|
||||||
layers = find_layers(model)
|
|
||||||
for name in ['lm_head']:
|
|
||||||
if name in layers:
|
|
||||||
del layers[name]
|
|
||||||
make_quant_for_4bit_autograd(model, layers, groupsize=groupsize)
|
|
||||||
accelerate.load_checkpoint_in_model(model, checkpoint=model_path, device_map={'': 'cpu'})
|
|
||||||
|
|
||||||
# rotary_emb fix
|
|
||||||
for n, m in model.named_modules():
|
|
||||||
if 'rotary_emb' in n:
|
|
||||||
cos_cached = m.cos_cached.clone().cpu()
|
|
||||||
sin_cached = m.sin_cached.clone().cpu()
|
|
||||||
break
|
|
||||||
|
|
||||||
if lora_path is not None:
|
|
||||||
from peft import PeftModel
|
|
||||||
from peft.tuners.lora import Linear4bitLt
|
|
||||||
model = PeftModel.from_pretrained(model, lora_path, device_map={'': 'cpu'}, torch_dtype=torch.float32)
|
|
||||||
print(Style.BRIGHT + Fore.GREEN + '{} Lora Applied.'.format(lora_path))
|
|
||||||
|
|
||||||
model.seqlen = seqlen
|
|
||||||
|
|
||||||
print('Apply half ...')
|
|
||||||
for n, m in model.named_modules():
|
|
||||||
if isinstance(m, Autograd4bitQuantLinear) or ((lora_path is not None) and isinstance(m, Linear4bitLt)):
|
|
||||||
m.qzeros = m.qzeros.half()
|
|
||||||
m.scales = m.scales.half()
|
|
||||||
m.bias = m.bias.half()
|
|
||||||
|
|
||||||
print('Dispatching model ...')
|
|
||||||
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
|
|
||||||
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True, main_device=0)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
print(Style.BRIGHT + Fore.YELLOW + 'Total {:.2f} Gib VRAM used.'.format(torch.cuda.memory_allocated() / 1024 / 1024))
|
|
||||||
|
|
||||||
# rotary_emb fix
|
|
||||||
for n, m in model.named_modules():
|
|
||||||
if 'rotary_emb' in n:
|
|
||||||
if getattr(m, '_hf_hook', None):
|
|
||||||
if isinstance(m._hf_hook, accelerate.hooks.SequentialHook):
|
|
||||||
hooks = m._hf_hook.hooks
|
|
||||||
else:
|
|
||||||
hooks = [m._hf_hook]
|
|
||||||
for hook in hooks:
|
|
||||||
if hook.offload:
|
|
||||||
if n + '.sin_cached' not in hook.weights_map.dataset.state_dict.keys():
|
|
||||||
hook.weights_map.dataset.state_dict[n + '.sin_cached'] = sin_cached.clone().cpu()
|
|
||||||
hook.weights_map.dataset.state_dict[n + '.cos_cached'] = cos_cached.clone().cpu()
|
|
||||||
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(config_path)
|
|
||||||
tokenizer.truncation_side = 'left'
|
|
||||||
|
|
||||||
print(Style.BRIGHT + Fore.GREEN + f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
|
||||||
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
@ -0,0 +1,167 @@
|
||||||
|
#https://github.com/fpgaminer/GPTQ-triton
|
||||||
|
"""
|
||||||
|
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import triton
|
||||||
|
|
||||||
|
|
||||||
|
class Autotuner(triton.KernelInterface):
|
||||||
|
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
|
||||||
|
'''
|
||||||
|
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||||
|
'perf_model': performance model used to predict running time with different configs, returns running time
|
||||||
|
'top_k': number of configs to bench
|
||||||
|
'early_config_prune'(optional): a function used to do early pruning (e.g. of num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
|
||||||
|
:param nearest_power_of_two: (optional) whether to round key arguments to the nearest power of two when caching tuning results
|
||||||
|
'''
|
||||||
|
if not configs:
|
||||||
|
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
|
||||||
|
else:
|
||||||
|
self.configs = configs
|
||||||
|
self.key_idx = [arg_names.index(k) for k in key]
|
||||||
|
self.nearest_power_of_two = nearest_power_of_two
|
||||||
|
self.cache = {}
|
||||||
|
# hook to reset all required tensors to zeros before relaunching a kernel
|
||||||
|
self.hook = lambda args: 0
|
||||||
|
if reset_to_zero is not None:
|
||||||
|
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
||||||
|
|
||||||
|
def _hook(args):
|
||||||
|
for i in self.reset_idx:
|
||||||
|
args[i].zero_()
|
||||||
|
self.hook = _hook
|
||||||
|
self.arg_names = arg_names
|
||||||
|
# prune configs
|
||||||
|
if prune_configs_by:
|
||||||
|
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||||
|
if 'early_config_prune' in prune_configs_by:
|
||||||
|
early_config_prune = prune_configs_by['early_config_prune']
|
||||||
|
else:
|
||||||
|
perf_model, top_k, early_config_prune = None, None, None
|
||||||
|
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||||
|
self.early_config_prune = early_config_prune
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def _bench(self, *args, config, **meta):
|
||||||
|
# check for conflicts, i.e. meta-parameters both provided
|
||||||
|
# as kwargs and by the autotuner
|
||||||
|
conflicts = meta.keys() & config.kwargs.keys()
|
||||||
|
if conflicts:
|
||||||
|
raise ValueError(
|
||||||
|
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||||
|
" Make sure that you don't re-define auto-tuned symbols."
|
||||||
|
)
|
||||||
|
# augment meta-parameters with tunable ones
|
||||||
|
current = dict(meta, **config.kwargs)
|
||||||
|
|
||||||
|
def kernel_call():
|
||||||
|
if config.pre_hook:
|
||||||
|
config.pre_hook(self.nargs)
|
||||||
|
self.hook(args)
|
||||||
|
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||||
|
try:
|
||||||
|
# In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses
|
||||||
|
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
|
||||||
|
return triton.testing.do_bench(kernel_call, rep=40)
|
||||||
|
except triton.compiler.OutOfResources:
|
||||||
|
return float('inf')
|
||||||
|
|
||||||
|
def run(self, *args, **kwargs):
|
||||||
|
self.nargs = dict(zip(self.arg_names, args))
|
||||||
|
if len(self.configs) > 1:
|
||||||
|
key = tuple(args[i] for i in self.key_idx)
|
||||||
|
|
||||||
|
# This reduces the amount of autotuning by rounding the keys to the nearest power of two
|
||||||
|
# In my testing this gives decent results, and greatly reduces the amount of tuning required
|
||||||
|
if self.nearest_power_of_two:
|
||||||
|
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
|
||||||
|
|
||||||
|
if key not in self.cache:
|
||||||
|
# prune configs
|
||||||
|
pruned_configs = self.prune_configs(kwargs)
|
||||||
|
bench_start = time.time()
|
||||||
|
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||||
|
for config in pruned_configs}
|
||||||
|
bench_end = time.time()
|
||||||
|
self.bench_time = bench_end - bench_start
|
||||||
|
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||||
|
self.hook(args)
|
||||||
|
self.configs_timings = timings
|
||||||
|
config = self.cache[key]
|
||||||
|
else:
|
||||||
|
config = self.configs[0]
|
||||||
|
self.best_config = config
|
||||||
|
if config.pre_hook is not None:
|
||||||
|
config.pre_hook(self.nargs)
|
||||||
|
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||||
|
|
||||||
|
def prune_configs(self, kwargs):
|
||||||
|
pruned_configs = self.configs
|
||||||
|
if self.early_config_prune:
|
||||||
|
pruned_configs = self.early_config_prune(self.configs, self.nargs)
|
||||||
|
if self.perf_model:
|
||||||
|
top_k = self.configs_top_k
|
||||||
|
if isinstance(top_k, float) and top_k <= 1.0:
|
||||||
|
top_k = int(len(self.configs) * top_k)
|
||||||
|
if len(pruned_configs) > top_k:
|
||||||
|
est_timing = {
|
||||||
|
config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
|
||||||
|
num_warps=config.num_warps)
|
||||||
|
for config in pruned_configs
|
||||||
|
}
|
||||||
|
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||||
|
return pruned_configs
|
||||||
|
|
||||||
|
def warmup(self, *args, **kwargs):
|
||||||
|
self.nargs = dict(zip(self.arg_names, args))
|
||||||
|
for config in self.prune_configs(kwargs):
|
||||||
|
self.fn.warmup(
|
||||||
|
*args,
|
||||||
|
num_warps=config.num_warps,
|
||||||
|
num_stages=config.num_stages,
|
||||||
|
**kwargs,
|
||||||
|
**config.kwargs,
|
||||||
|
)
|
||||||
|
self.nargs = None
|
||||||
|
|
||||||
|
|
||||||
|
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
|
||||||
|
"""
|
||||||
|
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||||
|
.. highlight:: python
|
||||||
|
.. code-block:: python
|
||||||
|
@triton.autotune(configs=[
|
||||||
|
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||||
|
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||||
|
],
|
||||||
|
key=['x_size'] # the two above configs will be evaluated anytime
|
||||||
|
# the value of x_size changes
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def kernel(x_ptr, x_size, **META):
|
||||||
|
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||||
|
:note: When all the configurations are evaluated, the kernel will run multiple times.
|
||||||
|
This means that whatever value the kernel updates will be updated multiple times.
|
||||||
|
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||||
|
resets the value of the provided tensor to `zero` before running any configuration.
|
||||||
|
:param configs: a list of :code:`triton.Config` objects
|
||||||
|
:type configs: list[triton.Config]
|
||||||
|
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||||
|
:type key: list[str]
|
||||||
|
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||||
|
'perf_model': performance model used to predict running time with different configs, returns running time
|
||||||
|
'top_k': number of configs to bench
|
||||||
|
'early_config_prune'(optional): a function used to do early pruning (e.g. of num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
|
||||||
|
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||||
|
:type reset_to_zero: list[str]
|
||||||
|
"""
|
||||||
|
def decorator(fn):
|
||||||
|
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)
|
||||||
|
|
||||||
|
return decorator
|
||||||
16
finetune.py
16
finetune.py
|
|
@ -24,6 +24,12 @@ if ft_config.flash_attention:
|
||||||
from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
||||||
replace_llama_attn_with_flash_attn()
|
replace_llama_attn_with_flash_attn()
|
||||||
|
|
||||||
|
import autograd_4bit
|
||||||
|
if ft_config.backend.lower() == 'triton':
|
||||||
|
autograd_4bit.switch_backend_to('triton')
|
||||||
|
else:
|
||||||
|
autograd_4bit.switch_backend_to('cuda')
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import peft
|
import peft
|
||||||
|
|
@ -65,10 +71,16 @@ lora_config = LoraConfig(
|
||||||
if ft_config.lora_apply_dir is None:
|
if ft_config.lora_apply_dir is None:
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
else:
|
else:
|
||||||
|
device_map = ft_config.device_map
|
||||||
if ft_config.ddp:
|
if ft_config.ddp:
|
||||||
model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map="auto", torch_dtype=torch.float32) # ! Direct copy from inference.py
|
device_map = {'': 0}
|
||||||
else:
|
else:
|
||||||
model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map={'': 0}, torch_dtype=torch.float32)
|
if torch.cuda.device_count() > 1:
|
||||||
|
device_map = "auto"
|
||||||
|
else:
|
||||||
|
device_map = {'': 0}
|
||||||
|
print('Device map for lora:', device_map)
|
||||||
|
model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map=device_map, torch_dtype=torch.float32)
|
||||||
print(ft_config.lora_apply_dir, 'loaded')
|
print(ft_config.lora_apply_dir, 'loaded')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ sentencepiece
|
||||||
safetensors
|
safetensors
|
||||||
flash-attn
|
flash-attn
|
||||||
triton
|
triton
|
||||||
|
colorama
|
||||||
git+https://github.com/huggingface/transformers.git
|
git+https://github.com/huggingface/transformers.git
|
||||||
git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
|
git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
|
||||||
git+https://github.com/sterlind/peft.git
|
git+https://github.com/sterlind/peft.git
|
||||||
|
|
|
||||||
449
triton_utils.py
449
triton_utils.py
|
|
@ -1,210 +1,239 @@
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
import torch
|
import torch
|
||||||
|
import custom_autotune
|
||||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
|
||||||
@triton.autotune(
|
|
||||||
configs=[
|
# code based https://github.com/fpgaminer/GPTQ-triton
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
@custom_autotune.autotune(
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
configs=[
|
||||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
# These provided a benefit on a 3090
|
||||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
],
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
key=['M', 'N', 'K'],
|
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
)
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
@triton.jit
|
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
],
|
||||||
scales_ptr, zeros_ptr, g_ptr,
|
key=['M', 'N'],
|
||||||
M, N, K, bits, maxq,
|
nearest_power_of_two=True,
|
||||||
stride_am, stride_ak,
|
)
|
||||||
stride_bk, stride_bn,
|
|
||||||
stride_cm, stride_cn,
|
|
||||||
stride_scales, stride_zeros,
|
@triton.jit
|
||||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
||||||
GROUP_SIZE_M: tl.constexpr):
|
scales_ptr, zeros_ptr, g_ptr,
|
||||||
"""
|
M, N, K, bits, maxq,
|
||||||
Compute the matrix multiplication C = A x B.
|
stride_am, stride_ak,
|
||||||
A is of shape (M, K) float16
|
stride_bk, stride_bn,
|
||||||
B is of shape (K//8, N) int32
|
stride_cm, stride_cn,
|
||||||
C is of shape (M, N) float16
|
stride_scales, stride_zeros,
|
||||||
scales is of shape (G, N) float16
|
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||||
zeros is of shape (G, N) float16
|
GROUP_SIZE_M: tl.constexpr):
|
||||||
g_ptr is of shape (K) int32
|
"""
|
||||||
"""
|
Compute the matrix multiplication C = A x B.
|
||||||
infearure_per_bits = 32 // bits
|
A is of shape (M, K) float16
|
||||||
|
B is of shape (K//8, N) int32
|
||||||
pid = tl.program_id(axis=0)
|
C is of shape (M, N) float16
|
||||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
scales is of shape (G, N) float16
|
||||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
zeros is of shape (G, N) float16
|
||||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
g_ptr is of shape (K) int32
|
||||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
"""
|
||||||
group_id = pid // num_pid_in_group
|
infearure_per_bits = 32 // bits
|
||||||
first_pid_m = group_id * GROUP_SIZE_M
|
|
||||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
pid = tl.program_id(axis=0)
|
||||||
pid_m = first_pid_m + (pid % group_size_m)
|
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
group_id = pid // num_pid_in_group
|
||||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
a_mask = (offs_am[:, None] < M)
|
pid_m = first_pid_m + (pid % group_size_m)
|
||||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||||
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
|
||||||
g_ptrs = g_ptr + offs_k
|
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
scales_ptrs = scales_ptr + offs_bn[None, :]
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
|
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||||
|
a_mask = (offs_am[:, None] < M)
|
||||||
shifter = (offs_k % infearure_per_bits) * bits
|
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
g_ptrs = g_ptr + offs_k
|
||||||
|
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||||
for k in range(0, num_pid_k):
|
scales_ptrs = scales_ptr + offs_bn[None, :]
|
||||||
g_idx = tl.load(g_ptrs)
|
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
|
||||||
|
|
||||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
shifter = (offs_k % infearure_per_bits) * bits
|
||||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
||||||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
|
|
||||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
for k in range(0, num_pid_k):
|
||||||
zeros = (zeros + 1)
|
g_idx = tl.load(g_ptrs)
|
||||||
|
|
||||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||||
|
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
|
||||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||||
b = (b - zeros) * scales # Scale and shift
|
zeros = (zeros + 1)
|
||||||
# ! Convert to fp16
|
|
||||||
b = b.to(tl.float16)
|
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||||
a = a.to(tl.float16)
|
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||||
|
|
||||||
accumulator += tl.dot(a, b)
|
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||||
a_ptrs += BLOCK_SIZE_K
|
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||||
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
b = (b - zeros) * scales # Scale and shift
|
||||||
g_ptrs += BLOCK_SIZE_K
|
# ! Convert to fp16
|
||||||
|
b = b.to(tl.float16)
|
||||||
c = accumulator.to(tl.float16)
|
a = a.to(tl.float16)
|
||||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
|
||||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
accumulator += tl.dot(a, b)
|
||||||
tl.store(c_ptrs, c, mask=c_mask)
|
a_ptrs += BLOCK_SIZE_K
|
||||||
|
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
||||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
g_ptrs += BLOCK_SIZE_K
|
||||||
@triton.autotune(
|
|
||||||
configs=[
|
c = accumulator.to(tl.float16)
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
# code based https://github.com/fpgaminer/GPTQ-triton
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
@custom_autotune.autotune(
|
||||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
configs=[
|
||||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
],
|
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
key=['M', 'N', 'K'],
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
)
|
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
|
# These provided a benefit on a 3090
|
||||||
@triton.jit
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
scales_ptr, zeros_ptr, g_ptr,
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
M, N, K, bits, maxq,
|
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
stride_am, stride_ak,
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
stride_bk, stride_bn,
|
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
stride_cm, stride_cn,
|
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||||
stride_scales, stride_zeros,
|
],
|
||||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
key=['M', 'K'],
|
||||||
GROUP_SIZE_M: tl.constexpr):
|
nearest_power_of_two=True,
|
||||||
"""
|
)
|
||||||
Compute the matrix multiplication C = A x B.
|
|
||||||
A is of shape (M, N) float16
|
|
||||||
B is of shape (K//8, N) int32
|
@triton.jit
|
||||||
C is of shape (M, K) float16
|
def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
||||||
scales is of shape (G, N) float16
|
scales_ptr, zeros_ptr, g_ptr,
|
||||||
zeros is of shape (G, N) float16
|
M, N, K, bits, maxq,
|
||||||
g_ptr is of shape (K) int32
|
stride_am, stride_ak,
|
||||||
"""
|
stride_bk, stride_bn,
|
||||||
infearure_per_bits = 32 // bits
|
stride_cm, stride_cn,
|
||||||
|
stride_scales, stride_zeros,
|
||||||
pid = tl.program_id(axis=0)
|
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
GROUP_SIZE_M: tl.constexpr):
|
||||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
"""
|
||||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
Compute the matrix multiplication C = A x B.
|
||||||
num_pid_in_group = GROUP_SIZE_M * num_pid_k
|
A is of shape (M, N) float16
|
||||||
group_id = pid // num_pid_in_group
|
B is of shape (K//8, N) int32
|
||||||
first_pid_m = group_id * GROUP_SIZE_M
|
C is of shape (M, K) float16
|
||||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
scales is of shape (G, N) float16
|
||||||
pid_m = first_pid_m + (pid % group_size_m)
|
zeros is of shape (G, N) float16
|
||||||
pid_k = (pid % num_pid_in_group) // group_size_m
|
g_ptr is of shape (K) int32
|
||||||
|
"""
|
||||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
infearure_per_bits = 32 // bits
|
||||||
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
|
||||||
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
pid = tl.program_id(axis=0)
|
||||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
a_mask = (offs_am[:, None] < M)
|
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
num_pid_in_group = GROUP_SIZE_M * num_pid_k
|
||||||
g_ptrs = g_ptr + offs_bk
|
group_id = pid // num_pid_in_group
|
||||||
g_idx = tl.load(g_ptrs)
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
pid_m = first_pid_m + (pid % group_size_m)
|
||||||
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
|
pid_k = (pid % num_pid_in_group) // group_size_m
|
||||||
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
|
|
||||||
|
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
shifter = (offs_bk % infearure_per_bits) * bits
|
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||||
zeros_shifter = (offs_n % infearure_per_bits) * bits
|
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||||
|
a_mask = (offs_am[:, None] < M)
|
||||||
for k in range(0, num_pid_n):
|
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||||
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
g_ptrs = g_ptr + offs_bk
|
||||||
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
g_idx = tl.load(g_ptrs)
|
||||||
|
|
||||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||||
zeros = (zeros + 1)
|
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
|
||||||
|
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
|
||||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
|
||||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
shifter = (offs_bk % infearure_per_bits) * bits
|
||||||
|
zeros_shifter = (offs_n % infearure_per_bits) * bits
|
||||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
|
||||||
b = (b - zeros) * scales # Scale and shift
|
for k in range(0, num_pid_n):
|
||||||
b = tl.trans(b)
|
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||||
# ! Convert to fp16
|
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||||
b = b.to(tl.float16)
|
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||||
a = a.to(tl.float16)
|
|
||||||
|
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||||
accumulator += tl.dot(a, b)
|
zeros = (zeros + 1)
|
||||||
a_ptrs += BLOCK_SIZE_N
|
|
||||||
b_ptrs += BLOCK_SIZE_N
|
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||||
scales_ptrs += BLOCK_SIZE_N
|
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||||
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
|
|
||||||
|
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||||
c = accumulator.to(tl.float16)
|
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
|
b = (b - zeros) * scales # Scale and shift
|
||||||
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
b = tl.trans(b)
|
||||||
tl.store(c_ptrs, c, mask=c_mask)
|
# ! Convert to fp16
|
||||||
|
b = b.to(tl.float16)
|
||||||
|
a = a.to(tl.float16)
|
||||||
def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
|
||||||
output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
|
accumulator += tl.dot(a, b)
|
||||||
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
|
a_ptrs += BLOCK_SIZE_N
|
||||||
matmul_248_kernel[grid](input, qweight, output,
|
b_ptrs += BLOCK_SIZE_N
|
||||||
scales, qzeros, g_idx,
|
scales_ptrs += BLOCK_SIZE_N
|
||||||
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
|
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
|
||||||
input.stride(0), input.stride(1),
|
|
||||||
qweight.stride(0), qweight.stride(1),
|
c = accumulator.to(tl.float16)
|
||||||
output.stride(0), output.stride(1),
|
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
|
||||||
scales.stride(0), qzeros.stride(0))
|
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
||||||
return output
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
    """
    Launch ``matmul_248_kernel`` to multiply ``input`` by the packed weight ``qweight``.

    ``input`` is indexed as 2-D (rows, in_features) and ``qweight`` as 2-D
    (packed_rows, out_features). The assertion requires
    ``in_features == packed_rows * 32 // bits``, i.e. each 32-bit word of
    ``qweight`` holds ``32 // bits`` values along the input-feature axis.

    :param input: 2-D activation tensor; its shape and strides are passed to the kernel.
    :param qweight: 2-D packed weight tensor.
    :param scales: scales tensor; the output is allocated on its device.
    :param qzeros: zero-point tensor (unpacked inside the kernel with ``bits`` / ``maxq``).
    :param g_idx: group-index tensor passed to the kernel as ``g_ptr``.
    :param bits: bit width of one quantized value.
    :param maxq: mask applied by the kernel after shifting out a quantized value.
    :return: a new ``torch.float16`` tensor of shape ``(input.shape[0], qweight.shape[1])``.
    """
    rows, in_features = input.shape[0], input.shape[1]
    packed_rows, out_features = qweight.shape[0], qweight.shape[1]
    assert in_features == packed_rows * 32 // bits

    output = torch.empty((rows, out_features), device=scales.device, dtype=torch.float16)

    def grid(META):
        # One program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile, flattened to 1-D.
        tiles_m = triton.cdiv(rows, META['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(out_features, META['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n,)

    matmul_248_kernel[grid](
        input, qweight, output,
        scales, qzeros, g_idx,
        rows, out_features, in_features, bits, maxq,  # kernel's M, N, K, bits, maxq
        input.stride(0), input.stride(1),
        qweight.stride(0), qweight.stride(1),
        output.stride(0), output.stride(1),
        scales.stride(0), qzeros.stride(0),
    )
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq):
    """
    Launch ``trans_matmul_248_kernel`` to multiply ``input`` by the transposed packed weight.

    ``input`` is indexed as 2-D (rows, out_features) and must satisfy
    ``input.shape[1] == qweight.shape[1]``. The result has
    ``qweight.shape[0] * 32 // bits`` columns, i.e. the unpacked length of
    ``qweight``'s first axis.

    :param input: 2-D tensor; its shape and strides are passed to the kernel.
    :param qweight: 2-D packed weight tensor.
    :param scales: scales tensor; the output is allocated on its device.
    :param qzeros: zero-point tensor (unpacked inside the kernel with ``bits`` / ``maxq``).
    :param g_idx: group-index tensor passed to the kernel as ``g_ptr``.
    :param bits: bit width of one quantized value.
    :param maxq: mask applied by the kernel after shifting out a quantized value.
    :return: a new ``torch.float16`` tensor of shape
        ``(input.shape[0], qweight.shape[0] * 32 // bits)``.
    """
    assert input.shape[1] == qweight.shape[1]

    rows = input.shape[0]
    unpacked_features = qweight.shape[0] * 32 // bits
    output = torch.empty((rows, unpacked_features), device=scales.device, dtype=torch.float16)

    def grid(META):
        # One program per (BLOCK_SIZE_M x BLOCK_SIZE_K) output tile, flattened to 1-D.
        tiles_m = triton.cdiv(input.shape[0], META['BLOCK_SIZE_M'])
        tiles_k = triton.cdiv(unpacked_features, META['BLOCK_SIZE_K'])
        return (tiles_m * tiles_k,)

    trans_matmul_248_kernel[grid](
        input, qweight, output,
        scales, qzeros, g_idx,
        input.shape[0], qweight.shape[1], unpacked_features, bits, maxq,  # kernel's M, N, K, bits, maxq
        input.stride(0), input.stride(1),
        qweight.stride(0), qweight.stride(1),
        output.stride(0), output.stride(1),
        scales.stride(0), qzeros.stride(0),
    )
    return output
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue