fix bug on triton matmul

parent dba3773b30
commit 32904da1ff
@@ -44,6 +44,7 @@ try:
             output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq)
             ctx.save_for_backward(qweight, scales, qzeros, g_idx)
             ctx.bits, ctx.maxq = bits, maxq
+            output = output.clone()
             return output

         @staticmethod
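The only functional change in this first hunk is the added output.clone(). The commit message does not explain it; a plausible reading, which is my assumption rather than anything stated here, is that cloning decouples the tensor returned from forward from any buffer that is later written in place. A minimal illustration of that aliasing hazard (not code from this commit):

import torch

buf = torch.zeros(4)
alias = buf              # returning buf directly hands out an alias
snapshot = buf.clone()   # what the added clone() provides instead
buf.add_(1.0)            # a later in-place write, e.g. buffer reuse
print(alias)             # tensor([1., 1., 1., 1.]) -- the alias sees the write
print(snapshot)          # tensor([0., 0., 0., 0.]) -- the clone is unaffected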
@@ -211,7 +211,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,


 def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
-    assert input.shape[1] == qweight.shape[0] * 32 // bits
+    assert input.shape[-1] == qweight.shape[0] * 32 // bits
+    outshape = input.shape[:-1] + (qweight.shape[1],)
+    input = input.reshape(-1, input.shape[-1])
     output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16)
     grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
     matmul_248_kernel[grid](input, qweight, output,
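The substance of the fix in both matmul wrappers is the standard flatten-and-restore pattern: remember the leading (batch) dimensions, collapse the input to 2-D for a kernel that only understands (M, K) operands, then reshape the result back. A minimal sketch of that pattern, using a plain torch.matmul as a stand-in for matmul_248_kernel and a hypothetical batched_matmul name:

import torch

def batched_matmul(input, weight):
    outshape = input.shape[:-1] + (weight.shape[1],)  # remember leading batch dims
    input2d = input.reshape(-1, input.shape[-1])      # collapse them: (..., K) -> (M, K)
    out2d = input2d @ weight                          # the 2-D kernel runs unchanged
    return out2d.reshape(outshape)                    # restore: (M, N) -> (..., N)

x = torch.randn(2, 3, 8)    # e.g. (batch, seq, features); 2-D input still works
w = torch.randn(8, 16)
assert batched_matmul(x, w).shape == (2, 3, 16)

This is what the new outshape / reshape lines do around each kernel launch; before the fix, a 3-D activation of shape (batch, seq, features) would generally trip the input.shape[1] assert, since shape[1] is then the sequence length rather than the feature width.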
@@ -221,19 +223,24 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq):
                             qweight.stride(0), qweight.stride(1),
                             output.stride(0), output.stride(1),
                             scales.stride(0), qzeros.stride(0))
+    output = output.reshape(outshape)
     return output


 def triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq):
-    assert input.shape[1] == qweight.shape[1]
-    output_shape = (input.shape[0], qweight.shape[0] * 32 // bits)
-    output = torch.empty((output_shape[0], output_shape[1]), device=scales.device, dtype=torch.float16)
-    grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape[1], META['BLOCK_SIZE_K']),)
+    assert input.shape[-1] == qweight.shape[1]
+    out_dim = qweight.shape[0] * 32 // bits
+    outshape = input.shape[:-1] + (out_dim,)
+    input = input.reshape(-1, input.shape[-1])
+    output_shape_mid = (input.shape[0], out_dim)
+    output = torch.empty((output_shape_mid[0], output_shape_mid[1]), device=scales.device, dtype=torch.float16)
+    grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape_mid[1], META['BLOCK_SIZE_K']),)
     trans_matmul_248_kernel[grid](input, qweight, output,
                                   scales, qzeros, g_idx,
-                                  input.shape[0], qweight.shape[1], output_shape[1], bits, maxq,
+                                  input.shape[0], qweight.shape[1], output_shape_mid[1], bits, maxq,
                                   input.stride(0), input.stride(1),
                                   qweight.stride(0), qweight.stride(1),
                                   output.stride(0), output.stride(1),
                                   scales.stride(0), qzeros.stride(0))
+    output = output.reshape(outshape)
     return output
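For reference on the shape arithmetic: qweight stores GPTQ-style packed weights, with each int32 element holding 32 // bits quantized values (the 248 in the kernel names refers to the supported 2-, 4-, and 8-bit widths), so qweight.shape[0] * 32 // bits recovers the unpacked feature dimension. A quick worked check, with illustrative numbers of my own choosing:

bits = 4                               # one of the supported widths: 2, 4, 8
infeatures = 4096                      # hypothetical layer width
packed_rows = infeatures * bits // 32  # 512 int32 rows store all 4096 values
assert packed_rows * 32 // bits == infeatures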