Add files via upload

Fix triton kernels
Andrey Glushenkov 2023-04-06 02:38:06 +03:00 committed by GitHub
parent 4a2d23aa29
commit 0d271d5d90
1 changed file with 11 additions and 6 deletions


@@ -78,12 +78,15 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     zeros = (zeros >> zeros_shifter[None, :]) & maxq
     zeros = (zeros + 1)

-    a = tl.load(a_ptrs, mask=a_mask, other=0.)   # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+    a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
     b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

     # Now we need to unpack b (which is N-bit values) into 32-bit values
     b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
     b = (b - zeros) * scales  # Scale and shift
+    # ! Convert to fp16
+    b = b.to(tl.float16)
+    a = a.to(tl.float16)

     accumulator += tl.dot(a, b)
     a_ptrs += BLOCK_SIZE_K
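(Side note, not part of the diff: the unpack-and-dequantize arithmetic in the hunk above can be sketched in plain PyTorch. This is a minimal illustration assuming 4-bit weights packed eight per int32 word along the row dimension; the tensor names, shapes, and per-column quantization layout are made up for the example and are not taken from the file.)

import torch

# Sketch of the kernel's dequantization math, assuming bits=4 (maxq=15)
# and 8 values packed into each int32 word. Shapes are illustrative only.
bits = 4
maxq = (1 << bits) - 1
vals_per_word = 32 // bits

packed = torch.randint(0, 2**31 - 1, (2, 8), dtype=torch.int32)  # 16 rows packed into 2 words per column
scales = torch.rand(8, dtype=torch.float16)                      # per-column scale
zeros = torch.randint(0, maxq, (8,), dtype=torch.int32)          # per-column zero point

rows = torch.arange(2 * vals_per_word)
shift = (rows % vals_per_word) * bits                            # bit offset of each unpacked row
b = (packed[rows // vals_per_word] >> shift[:, None]) & maxq     # extract the N-bit values
b = (b - (zeros + 1)).to(torch.float16) * scales                 # scale and shift in fp16, as the fixed kernel does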
@@ -93,7 +96,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     c = accumulator.to(tl.float16)
     c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
     c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
+    tl.store(c_ptrs, c, mask=c_mask)

 # code based https://github.com/fpgaminer/GPTQ-triton
 @triton.autotune(
@@ -178,6 +181,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
     b = (b - zeros) * scales  # Scale and shift
     b = tl.trans(b)
+    # ! Convert to fp16
+    b = b.to(tl.float16)
+    a = a.to(tl.float16)

     accumulator += tl.dot(a, b)
     a_ptrs += BLOCK_SIZE_N
@@ -188,7 +194,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     c = accumulator.to(tl.float16)
     c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
     c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
+    tl.store(c_ptrs, c, mask=c_mask)

 def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
@@ -202,4 +208,3 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
                      output.stride(0), output.stride(1),
                      scales.stride(0), qzeros.stride(0))
     return output
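(Side note, not part of the diff: a minimal call to the triton_matmul wrapper shown above might look like the following. The GPTQ-style tensor shapes, the group size of 128, and the 4-bit packing layout are assumptions for illustration, not something this diff specifies.)

import torch

# Hypothetical usage of triton_matmul from this module; all shapes,
# the group size (128) and bits=4 are assumptions, not from the diff.
M, K, N, bits, group = 16, 4096, 4096, 4, 128
maxq = (1 << bits) - 1

input = torch.randn(M, K, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (K * bits // 32, N), dtype=torch.int32, device="cuda")
scales = torch.rand(K // group, N, dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (K // group, N * bits // 32), dtype=torch.int32, device="cuda")
g_idx = torch.arange(K, device="cuda", dtype=torch.int32) // group

out = triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq)  # expected shape (M, N), fp16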