parent 4a2d23aa29
commit 0d271d5d90
@@ -62,7 +62,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     g_ptrs = g_ptr + offs_k
     # shifter is used to extract the N bits of each element in the 32-bit word from B
     scales_ptrs = scales_ptr + offs_bn[None, :]
-    zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits)
+    zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)

     shifter = (offs_k % infearure_per_bits) * bits
     zeros_shifter = (offs_bn % infearure_per_bits) * bits
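
A note on the shifter setup above: with N-bit quantization, infearure_per_bits = 32 // bits weights are packed into each 32-bit word of B, and (offs_k % infearure_per_bits) * bits is the bit offset of element k inside its word. A minimal plain-Python sketch of the same pack/unpack round trip (names here are illustrative, not from the kernel):

    bits = 4
    maxq = (1 << bits) - 1        # 0b1111 mask for 4-bit values
    per_word = 32 // bits         # plays the role of infearure_per_bits

    # Pack per_word small values into one 32-bit word, lowest element first
    vals = list(range(per_word))
    word = 0
    for i, v in enumerate(vals):
        word |= v << (i * bits)

    # Unpack with the kernel's shift-and-mask pattern
    shifter = [(k % per_word) * bits for k in range(per_word)]
    assert [(word >> s) & maxq for s in shifter] == vals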
@@ -78,12 +78,15 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     zeros = (zeros >> zeros_shifter[None, :]) & maxq
     zeros = (zeros + 1)

-    a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+    a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
     b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

     # Now we need to unpack b (which is N-bit values) into 32-bit values
     b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
     b = (b - zeros) * scales # Scale and shift
+    # ! Convert to fp16
+    b = b.to(tl.float16)
+    a = a.to(tl.float16)

     accumulator += tl.dot(a, b)
     a_ptrs += BLOCK_SIZE_K
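
For reference, the dequantization this loop body performs, written eagerly in PyTorch. This is a sketch under the assumption that the packed layouts match the kernel's, not a drop-in replacement; the + 1 mirrors zeros = (zeros + 1) above (undoing the bias applied when the zero-points were packed), and the final cast corresponds to the two .to(tl.float16) lines added here, presumably so tl.dot can take the fp16 tensor-core path rather than fp32:

    import torch

    def dequant_tile(b_words, zeros_words, scales, shifter, zeros_shifter, maxq):
        # b_words: packed int32 weights; zeros_words: packed int32 zero-points
        b = (b_words >> shifter[:, None]) & maxq             # extract N-bit weights
        zeros = ((zeros_words >> zeros_shifter[None, :]) & maxq) + 1
        return ((b - zeros) * scales).to(torch.float16)      # affine dequant, then fp16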
@@ -93,7 +96,7 @@ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     c = accumulator.to(tl.float16)
     c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
     c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
+    tl.store(c_ptrs, c, mask=c_mask)

 # code based https://github.com/fpgaminer/GPTQ-triton
 @triton.autotune(
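
The tl.store change is the substantive fix in this hunk: c is the accumulator already downcast to fp16 three lines earlier, while the old code stored the raw fp32 accumulator into what triton_matmul below appears to allocate as an fp16 buffer. Storing c makes the downcast explicit; the accumulate-in-fp32, store-in-fp16 pattern looks like:

    # Illustrative pattern, not the kernel verbatim:
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # ... K-loop: accumulator += tl.dot(a, b) ...
    c = accumulator.to(tl.float16)      # one downcast, after all accumulation
    tl.store(c_ptrs, c, mask=c_mask)    # dtype now matches the output buffer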
@@ -157,7 +160,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,

     # shifter is used to extract the N bits of each element in the 32-bit word from B
     scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
-    zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros
+    zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros

     shifter = (offs_bk % infearure_per_bits) * bits
     zeros_shifter = (offs_n % infearure_per_bits) * bits
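
Context for the g_idx[:, None] indexing above: each row k of the packed weight belongs to quantization group g_idx[k], and that group's scale and zero-point sit at row g_idx[k] of the scales/qzeros tables, hence the stride_scales / stride_zeros multipliers. A small sketch of the mapping in the simplest, contiguous-groups case (groupsize here is an illustrative assumption; the real g_idx comes from the quantizer):

    import torch

    K, groupsize = 4096, 128
    g_idx = torch.arange(K) // groupsize   # row k -> its quantization group
    # With scales of shape (num_groups, N), per-row scales are a gather:
    # row_scales = scales[g_idx]           # (K, N), what the pointer math indexes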
@@ -178,6 +181,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
     b = (b - zeros) * scales # Scale and shift
     b = tl.trans(b)
+    # ! Convert to fp16
+    b = b.to(tl.float16)
+    a = a.to(tl.float16)

     accumulator += tl.dot(a, b)
     a_ptrs += BLOCK_SIZE_N
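
The fp16 casts here mirror the ones added to matmul_248_kernel, applied after tl.trans(b) so the transposed tile is what feeds tl.dot. Judging from the c_mask below (offs_bk < K), this kernel produces an (M, K) output, i.e. it multiplies by the transposed dequantized weight, the shape needed for the input gradient in a backward pass; that reading is inferred, not stated in the diff. Shape bookkeeping as I read it:

    # Tile shapes in the transposed kernel (inferred, not stated in the diff):
    #   a: (BLOCK_SIZE_M, BLOCK_SIZE_N)            slice of the (M, N) input
    #   b: (BLOCK_SIZE_N, BLOCK_SIZE_K)            after tl.trans
    #   accumulator: (BLOCK_SIZE_M, BLOCK_SIZE_K)  -> (M, K) output
    b = tl.trans(b).to(tl.float16)   # chained form of the lines above
    a = a.to(tl.float16)
    accumulator += tl.dot(a, b)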
@@ -188,7 +194,7 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
     c = accumulator.to(tl.float16)
     c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
     c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
+    tl.store(c_ptrs, c, mask=c_mask)


 def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
@@ -202,4 +208,3 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
                         output.stride(0), output.stride(1),
                         scales.stride(0), qzeros.stride(0))
     return output
-
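
For completeness, a hedged usage sketch of the wrapper; the shapes and dtypes below are assumptions based on the packed layout discussed above (in-features packed 32 // bits per int32 along K, group tables of K // groupsize rows), not spelled out in this diff:

    import torch

    bits, groupsize = 4, 128
    maxq = (1 << bits) - 1
    M, K, N = 8, 4096, 4096
    pack = 32 // bits

    input = torch.randn(M, K, device="cuda", dtype=torch.float16)
    qweight = torch.randint(0, 2**31, (K // pack, N), device="cuda", dtype=torch.int32)
    scales = torch.randn(K // groupsize, N, device="cuda", dtype=torch.float16)
    qzeros = torch.randint(0, 2**31, (K // groupsize, N // pack), device="cuda", dtype=torch.int32)
    g_idx = (torch.arange(K, device="cuda") // groupsize).to(torch.int32)

    out = triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq)  # expected (M, N) fp16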