This commit is contained in:
John Smith 2023-04-06 10:46:42 +08:00
parent 86387a0a35
commit 085d9556f9
2 changed files with 207 additions and 207 deletions

View File

@ -48,7 +48,7 @@ class Autograd4bitQuantLinear(nn.Module):
) )
self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize), outfeatures))) self.register_buffer('scales', torch.empty((math.ceil(infeatures/groupsize), outfeatures)))
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32)) self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32))
self.bias = nn.Parameter(torch.empty(outfeatures)) self.register_buffer('bias', torch.empty(outfeatures))
self.register_buffer( self.register_buffer(
'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int) 'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
) )
@ -58,11 +58,11 @@ class Autograd4bitQuantLinear(nn.Module):
if torch.is_grad_enabled(): if torch.is_grad_enabled():
out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, out = AutogradMatmul4bit.apply(x, self.qweight, self.scales,
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
out.add_(self.bias) out += self.bias
else: else:
out = mm4b.matmul4bit(x, self.qweight, self.scales, out = mm4b.matmul4bit(x, self.qweight, self.scales,
self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize) self.qzeros if self.groupsize != -1 else self.zeros, self.groupsize)
out.add_(self.bias) out += self.bias
return out return out

View File

@ -1,205 +1,205 @@
import triton import triton
import triton.language as tl import triton.language as tl
import torch import torch
# code based https://github.com/fpgaminer/GPTQ-triton # code based https://github.com/fpgaminer/GPTQ-triton
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
], ],
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
) )
@triton.jit @triton.jit
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr, scales_ptr, zeros_ptr, g_ptr,
M, N, K, bits, maxq, M, N, K, bits, maxq,
stride_am, stride_ak, stride_am, stride_ak,
stride_bk, stride_bn, stride_bk, stride_bn,
stride_cm, stride_cn, stride_cm, stride_cn,
stride_scales, stride_zeros, stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr): GROUP_SIZE_M: tl.constexpr):
""" """
Compute the matrix multiplication C = A x B. Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16 A is of shape (M, K) float16
B is of shape (K//8, N) int32 B is of shape (K//8, N) int32
C is of shape (M, N) float16 C is of shape (M, N) float16
scales is of shape (G, N) float16 scales is of shape (G, N) float16
zeros is of shape (G, N) float16 zeros is of shape (G, N) float16
g_ptr is of shape (K) int32 g_ptr is of shape (K) int32
""" """
infearure_per_bits = 32 // bits infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m) pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M) a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times # b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B # shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :] scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits) zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k): for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs) g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values # Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K g_ptrs += BLOCK_SIZE_K
c = accumulator.to(tl.float16) c = accumulator.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
# code based https://github.com/fpgaminer/GPTQ-triton # code based https://github.com/fpgaminer/GPTQ-triton
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
], ],
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
) )
@triton.jit @triton.jit
def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
scales_ptr, zeros_ptr, g_ptr, scales_ptr, zeros_ptr, g_ptr,
M, N, K, bits, maxq, M, N, K, bits, maxq,
stride_am, stride_ak, stride_am, stride_ak,
stride_bk, stride_bn, stride_bk, stride_bn,
stride_cm, stride_cn, stride_cm, stride_cn,
stride_scales, stride_zeros, stride_scales, stride_zeros,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr): GROUP_SIZE_M: tl.constexpr):
""" """
Compute the matrix multiplication C = A x B. Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16 A is of shape (M, N) float16
B is of shape (K//8, N) int32 B is of shape (K//8, N) int32
C is of shape (M, K) float16 C is of shape (M, K) float16
scales is of shape (G, N) float16 scales is of shape (G, N) float16
zeros is of shape (G, N) float16 zeros is of shape (G, N) float16
g_ptr is of shape (K) int32 g_ptr is of shape (K) int32
""" """
infearure_per_bits = 32 // bits infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_k num_pid_in_group = GROUP_SIZE_M * num_pid_k
group_id = pid // num_pid_in_group group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m) pid_m = first_pid_m + (pid % group_size_m)
pid_k = (pid % num_pid_in_group) // group_size_m pid_k = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N) offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = (offs_am[:, None] < M) a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times # b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs) g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit word from B # shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits zeros_shifter = (offs_n % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for k in range(0, num_pid_n): for k in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values # Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift b = (b - zeros) * scales # Scale and shift
b = tl.trans(b) b = tl.trans(b)
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_N a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
c = accumulator.to(tl.float16) c = accumulator.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16) output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
matmul_248_kernel[grid](input, qweight, output, matmul_248_kernel[grid](input, qweight, output,
scales, qzeros, g_idx, scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
input.stride(0), input.stride(1), input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1), qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1), output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0)) scales.stride(0), qzeros.stride(0))
return output return output