diff --git a/autograd_4bit.py b/autograd_4bit.py index cf0faa1..55430ec 100644 --- a/autograd_4bit.py +++ b/autograd_4bit.py @@ -44,6 +44,7 @@ try: output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bits, maxq) ctx.save_for_backward(qweight, scales, qzeros, g_idx) ctx.bits, ctx.maxq = bits, maxq + output = output.clone() return output @staticmethod diff --git a/triton_utils.py b/triton_utils.py index 9722628..7afcf46 100644 --- a/triton_utils.py +++ b/triton_utils.py @@ -211,7 +211,9 @@ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): - assert input.shape[1] == qweight.shape[0] * 32 // bits + assert input.shape[-1] == qweight.shape[0] * 32 // bits + outshape = input.shape[:-1] + (qweight.shape[1],) + input = input.reshape(-1, input.shape[-1]) output = torch.empty((input.shape[0], qweight.shape[1]), device=scales.device, dtype=torch.float16) grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),) matmul_248_kernel[grid](input, qweight, output, @@ -221,19 +223,24 @@ def triton_matmul(input, qweight, scales, qzeros, g_idx, bits, maxq): qweight.stride(0), qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + output = output.reshape(outshape) return output def triton_matmul_transpose(input, qweight, scales, qzeros, g_idx, bits, maxq): - assert input.shape[1] == qweight.shape[1] - output_shape = (input.shape[0], qweight.shape[0] * 32 // bits) - output = torch.empty((output_shape[0], output_shape[1]), device=scales.device, dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape[1], META['BLOCK_SIZE_K']),) + assert input.shape[-1] == qweight.shape[1] + out_dim = qweight.shape[0] * 32 // bits + outshape = input.shape[:-1] + (out_dim,) + input = input.reshape(-1, input.shape[-1]) + output_shape_mid = (input.shape[0], out_dim) + output = torch.empty((output_shape_mid[0], output_shape_mid[1]), device=scales.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_shape_mid[1], META['BLOCK_SIZE_K']),) trans_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, - input.shape[0], qweight.shape[1], output_shape[1], bits, maxq, + input.shape[0], qweight.shape[1], output_shape_mid[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + output = output.reshape(outshape) return output