155 lines
7.0 KiB
Python
155 lines
7.0 KiB
Python
import torch
|
|
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
# %%
# :code:`triton.jit`'ed functions can be auto-tuned by using the :code:`triton.autotune`
# decorator, which consumes:
#   - A list of :code:`triton.Config` objects that define different configurations of
#     meta-parameters (e.g., BLOCK_SIZE_M) and compilation options (e.g., num_warps) to try
#   - An autotuning *key* whose change in values will trigger evaluation of all the
#     provided configs
|
|
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.

    A has shape (M, K), B has shape (K, N) and C has shape (M, N).

    Each program instance computes one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of C,
    accumulating in float32 and storing the result as float16.

    Args:
        a_ptr, b_ptr, c_ptr: base pointers of A, B and C.
        M, N, K: matrix dimensions.
        stride_am, stride_ak: element strides of A along its rows / columns.
        stride_bk, stride_bn: element strides of B along its rows / columns.
        stride_cm, stride_cn: element strides of C along its rows / columns.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: tile sizes (compile-time).
        GROUP_SIZE_M: number of tile-rows per group in the program ordering.
        ACTIVATION: when equal to "leaky_relu", `leaky_relu` is applied to the
            float32 accumulator before the cast to float16.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M tile-rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate.
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    #
    # Row/column offsets are wrapped modulo M / N so that tiles hanging over
    # the edge of the matrix still read valid addresses. The values computed
    # for those wrapped rows/columns are discarded by the store mask below.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Mask the tail of the K dimension so that a K which is not a
        # multiple of BLOCK_SIZE_K neither reads out of bounds nor
        # contributes garbage: masked lanes load 0 and add nothing.
        k_remaining = K - k
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        # We accumulate along the K dimension
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    # NOTE: the output element type is hard-coded to fp16 here.
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C.
    # These offsets are NOT wrapped: the mask drops rows/columns past the edge.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
|
|
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
|
|
@triton.jit
def leaky_relu(x):
    """Apply leaky ReLU element-wise with a negative slope of 0.01.

    Returns `x` where `x >= 0` and `0.01 * x` elsewhere. The input is not
    shifted before the activation is applied.
    """
    return tl.where(x >= 0, x, 0.01 * x)
|
|
|
|
def matmul(a, b, activation=""):
    """Compute ``a @ b`` by launching `matmul_kernel`.

    Args:
        a: 2-D contiguous tensor; its second dimension must equal ``b``'s first.
        b: 2-D contiguous tensor.
        activation: forwarded to the kernel as its ``ACTIVATION`` meta-parameter.

    Returns:
        A new tensor of shape ``(a.shape[0], b.shape[1])`` allocated on
        ``a.device`` with ``a.dtype``.

    Raises:
        AssertionError: if the inner dimensions differ, either operand is
            non-contiguous, or the shared dimension is not a multiple of 32.
    """
    # Validate the operands before anything is allocated or launched.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    assert a.is_contiguous(), "matrix A must be contiguous"
    assert b.is_contiguous(), "matrix B must be contiguous"
    M, _ = a.shape
    K, N = b.shape
    k_is_block_aligned = K % 32 == 0
    assert k_is_block_aligned, "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"

    # Output buffer the kernel writes into.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    def grid(META):
        # 1D launch: one program per output tile of the chosen config.
        tiles_m = triton.cdiv(M, META['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(N, META['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n,)

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        ACTIVATION=activation,
    )
    return c
|
|
|
|
|
|
|
|
# Sanity check: run the Triton matmul on two random 512x512 fp16 CUDA matrices
# and compare the result against torch.matmul.
torch.manual_seed(0)
# Both operands are drawn from the seeded generator, `a` first and then `b`.
a, b = (
    torch.randn((512, 512), device='cuda', dtype=torch.float16)
    for _ in range(2)
)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output={triton_output}")
print(f"torch_output={torch_output}")
# NOTE(review): `triton.testing.allclose` may not exist in newer Triton
# releases — confirm against the Triton version this file targets.
outputs_match = triton.testing.allclose(triton_output, torch_output)
print("✅ Triton and Torch match" if outputs_match else "❌ Triton and Torch differ")
|