parent
4a2d23aa29
commit
0d271d5d90
|
|
@ -62,7 +62,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
|||
g_ptrs = g_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_bn[None, :]
|
||||
zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits)
|
||||
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
|
||||
|
||||
shifter = (offs_k % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
||||
|
|
@ -78,12 +78,15 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
|||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
# ! Convert to fp16
|
||||
b = b.to(tl.float16)
|
||||
a = a.to(tl.float16)
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K
|
||||
|
|
@ -93,7 +96,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
|||
c = accumulator.to(tl.float16)
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
||||
@triton.autotune(
|
||||
|
|
@ -157,7 +160,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
|||
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
|
||||
zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros
|
||||
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
|
||||
|
||||
shifter = (offs_bk % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_n % infearure_per_bits) * bits
|
||||
|
|
@ -178,6 +181,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
|||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
b = tl.trans(b)
|
||||
# ! Convert to fp16
|
||||
b = b.to(tl.float16)
|
||||
a = a.to(tl.float16)
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_N
|
||||
|
|
@ -188,7 +194,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
|
|||
c = accumulator.to(tl.float16)
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
|
|
@ -202,4 +208,3 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
|||
output.stride(0), output.stride(1),
|
||||
scales.stride(0), qzeros.stride(0))
|
||||
return output
|
||||
|
||||
Loading…
Reference in New Issue