fix bug on triton matmul

This commit is contained in:
John Smith 2023-04-07 15:50:55 +08:00
parent dba3773b30
commit 32904da1ff
2 changed files with 14 additions and 6 deletions

View File

@ -44,6 +44,7 @@ try:
output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.bits, ctx.maxq = bits, maxq
output = output.clone()
return output
@staticmethod

View File

@ -211,7 +211,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
assert input.shape[1] == qweight.shape[0] * 32 // bits
assert input.shape[-1] == qweight.shape[0] * 32 // bits
outshape = input.shape[:-1] + (qweight.shape[1],)
input = input.reshape(-1, input.shape[-1])
output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
matmul_248_kernel[grid](input, qweight, output,
@ -221,19 +223,24 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0))
output = output.reshape(outshape)
return output
def triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq):
assert input.shape[1] == qweight.shape[1]
output_shape = (input.shape[0], qweight.shape[0] * 32 // bits)
output = torch.empty((output_shape[0], output_shape[1]), device=scales.device, dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape[1], META['BLOCK_SIZE_K']),)
assert input.shape[-1] == qweight.shape[1]
out_dim = qweight.shape[0] * 32 // bits
outshape = input.shape[:-1] + (out_dim,)
input = input.reshape(-1, input.shape[-1])
output_shape_mid = (input.shape[0], out_dim)
output = torch.empty((output_shape_mid[0], output_shape_mid[1]), device=scales.device, dtype=torch.float16)
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape_mid[1], META['BLOCK_SIZE_K']),)
trans_matmul_248_kernel[grid](input, qweight, output,
scales, qzeros, g_idx,
input.shape[0], qweight.shape[1], output_shape[1], bits, maxq,
input.shape[0], qweight.shape[1], output_shape_mid[1], bits, maxq,
input.stride(0), input.stride(1),
qweight.stride(0), qweight.stride(1),
output.stride(0), output.stride(1),
scales.stride(0), qzeros.stride(0))
output = output.reshape(outshape)
return output