fix bug in triton matmul
parent dba3773b30
commit 32904da1ff
@@ -44,6 +44,7 @@ try:
             output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
             ctx.save_for_backward(qweight, scales, qzeros, g_idx)
             ctx.bits, ctx.maxq = bits, maxq
+            output = output.clone()
             return output
 
         @staticmethod
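The hunk above edits the autograd forward that wraps the Triton matmul; the one added line clones the kernel output before returning it. A minimal sketch of that pattern, assuming a class name of QuantLinearFunction and substituting a shape-only stub for tu.triton_matmul so it runs without Triton installed:

import torch

def _matmul_stub(x, qweight, scales, qzeros, g_idx, bits, maxq):
    # Stand-in for tu.triton_matmul; reproduces only the output shape,
    # not the dequantized matmul.
    return torch.zeros(x.shape[0], qweight.shape[1], dtype=torch.float16)

class QuantLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq):
        output = _matmul_stub(x, qweight, scales, qzeros, g_idx, bits, maxq)
        ctx.save_for_backward(qweight, scales, qzeros, g_idx)
        ctx.bits, ctx.maxq = bits, maxq
        # Added by this commit: return a copy of the kernel output
        # rather than the buffer the kernel wrote into.
        output = output.clone()
        return output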
@@ -211,7 +211,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
 
 
 def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
-    assert input.shape[1] == qweight.shape[0] * 32 // bits
+    assert input.shape[-1] == qweight.shape[0] * 32 // bits
+    outshape = input.shape[:-1] + (qweight.shape[1],)
+    input = input.reshape(-1, input.shape[-1])
     output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16)
     grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
     matmul_248_kernel[grid](input, qweight, output,
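This hunk is the core of the fix: the old assert only accepted 2-D inputs (input.shape[1]), while the new code checks the last dimension, flattens any leading dimensions before the kernel launch, and restores them afterwards via outshape. A small sketch of that shape handling, with a hypothetical name (matmul_nd) and the kernel launch elided so a [batch, seq, in_features] input can be traced through:

import torch

def matmul_nd(input, qweight, bits=4):
    # Same shape bookkeeping as the patched triton_matmul: each packed
    # int32 row of qweight holds 32 // bits quantized weights.
    in_features = qweight.shape[0] * 32 // bits
    assert input.shape[-1] == in_features
    outshape = input.shape[:-1] + (qweight.shape[1],)
    flat = input.reshape(-1, input.shape[-1])            # collapse leading dims
    out = torch.zeros(flat.shape[0], qweight.shape[1])   # matmul_248_kernel writes here in the real code
    return out.reshape(outshape)                         # restore leading dims

x = torch.randn(2, 5, 64)                   # [batch, seq, in_features]
qw = torch.zeros(8, 16, dtype=torch.int32)  # 8 packed rows * 32 // 4 = 64 input features
print(matmul_nd(x, qw).shape)               # torch.Size([2, 5, 16])

With the reshape in place, activations with extra leading dimensions no longer need to be flattened by the caller before hitting the quantized matmul.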
@@ -221,19 +223,24 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
                             qweight.stride(0), qweight.stride(1),
                             output.stride(0), output.stride(1),
                             scales.stride(0), qzeros.stride(0))
+    output = output.reshape(outshape)
     return output
 
 
 def triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq):
-    assert input.shape[1] == qweight.shape[1]
-    output_shape = (input.shape[0], qweight.shape[0] * 32 // bits)
-    output = torch.empty((output_shape[0], output_shape[1]), device=scales.device, dtype=torch.float16)
-    grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape[1], META['BLOCK_SIZE_K']),)
+    assert input.shape[-1] == qweight.shape[1]
+    out_dim = qweight.shape[0] * 32 // bits
+    outshape = input.shape[:-1] + (out_dim,)
+    input = input.reshape(-1, input.shape[-1])
+    output_shape_mid = (input.shape[0], out_dim)
+    output = torch.empty((output_shape_mid[0], output_shape_mid[1]), device=scales.device, dtype=torch.float16)
+    grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape_mid[1], META['BLOCK_SIZE_K']),)
     trans_matmul_248_kernel[grid](input, qweight, output,
                                   scales, qzeros, g_idx,
-                                  input.shape[0], qweight.shape[1], output_shape[1], bits, maxq,
+                                  input.shape[0], qweight.shape[1], output_shape_mid[1], bits, maxq,
                                   input.stride(0), input.stride(1),
                                   qweight.stride(0), qweight.stride(1),
                                   output.stride(0), output.stride(1),
                                   scales.stride(0), qzeros.stride(0))
+    output = output.reshape(outshape)
     return output
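triton_matmul_transpose gets the same treatment, with the output width derived from the packed weight rows (qweight.shape[0] * 32 // bits, i.e. 32 // bits quantized values per int32 row). A parallel sketch of the shape bookkeeping, again with a hypothetical name (matmul_transpose_nd) and the trans_matmul_248_kernel launch elided:

import torch

def matmul_transpose_nd(input, qweight, bits=4):
    # Mirrors the patched triton_matmul_transpose: the product maps
    # back to the unpacked feature dimension of the quantized weight.
    assert input.shape[-1] == qweight.shape[1]
    out_dim = qweight.shape[0] * 32 // bits
    outshape = input.shape[:-1] + (out_dim,)
    flat = input.reshape(-1, input.shape[-1])
    out = torch.zeros(flat.shape[0], out_dim)   # trans_matmul_248_kernel writes here in the real code
    return out.reshape(outshape)

g = torch.randn(2, 5, 16)                       # e.g. a gradient w.r.t. the layer output
qw = torch.zeros(8, 16, dtype=torch.int32)
print(matmul_transpose_nd(g, qw).shape)         # torch.Size([2, 5, 64])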