diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py index b476812..3d1d60b 100644 --- a/matmul_utils_4bit.py +++ b/matmul_utils_4bit.py @@ -108,9 +108,9 @@ def matmul4bit(x, qweight, scales, zeros, groupsize=-1): if np.prod(x.shape[:-1]) > auto_switch_thd: output = _matmul4bit_v1_recons(x, qweight, scales, zeros) else: - output = _matmul4bit_v1(x, qweight, scales.float(), zeros) + output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float()) else: - output = _matmul4bit_v1(x, qweight, scales.float(), zeros) + output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float()) else: # use v2 if use_new: