From 0d271d5d90dff890be586af6177f076fe4448ddc Mon Sep 17 00:00:00 2001 From: Andrey Glushenkov Date: Thu, 6 Apr 2023 02:38:06 +0300 Subject: [PATCH] Add files via upload Fix triton kernels --- triton_utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/triton_utils.py b/triton_utils.py index 7f50c5e..57f9e66 100644 --- a/triton_utils.py +++ b/triton_utils.py @@ -62,7 +62,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr, g_ptrs = g_ptr + offs_k # shifter is used to extract the N bits of each element in the 32-bit word from B scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits) + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) shifter = (offs_k % infearure_per_bits) * bits zeros_shifter = (offs_bn % infearure_per_bits) * bits @@ -78,12 +78,15 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr, zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros + 1) - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated # Now we need to unpack b (which is N-bit values) into 32-bit values b = (b >> shifter[:, None]) & maxq # Extract the N-bit values b = (b - zeros) * scales # Scale and shift + # ! Convert to fp16 + b = b.to(tl.float16) + a = a.to(tl.float16) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K @@ -93,7 +96,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr, c = accumulator.to(tl.float16) c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) + tl.store(c_ptrs, c, mask=c_mask) # code based https://github.com/fpgaminer/GPTQ-triton @triton.autotune( @@ -157,7 +160,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, # shifter is used to extract the N bits of each element in the 32-bit word from B scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales - zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros + zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros shifter = (offs_bk % infearure_per_bits) * bits zeros_shifter = (offs_n % infearure_per_bits) * bits @@ -178,6 +181,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, b = (b >> shifter[:, None]) & maxq # Extract the N-bit values b = (b - zeros) * scales # Scale and shift b = tl.trans(b) + # ! Convert to fp16 + b = b.to(tl.float16) + a = a.to(tl.float16) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_N @@ -188,7 +194,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, c = accumulator.to(tl.float16) c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) - tl.store(c_ptrs, accumulator, mask=c_mask) + tl.store(c_ptrs, c, mask=c_mask) def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): @@ -202,4 +208,3 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) return output - \ No newline at end of file