diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py index 7fca98b..e9d621f 100644 --- a/matmul_utils_4bit.py +++ b/matmul_utils_4bit.py @@ -106,7 +106,7 @@ def matmul4bit(x, qweight, scales, zeros, groupsize=-1): if use_new: if auto_switch: if np.prod(x.shape[:-1]) > auto_switch_thd: - output = _matmul4bit_v1_recons(x, qweight, scales, zeros) + output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales, zeros) else: output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float()) else: @@ -116,7 +116,7 @@ def matmul4bit(x, qweight, scales, zeros, groupsize=-1): if use_new: if auto_switch: if np.prod(x.shape[:-1]) > auto_switch_thd: - output = _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize) + output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, groupsize) else: output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize) else: