diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py index 009093c..2aaa0ad 100644 --- a/matmul_utils_4bit.py +++ b/matmul_utils_4bit.py @@ -102,7 +102,7 @@ def _matmul4bit_v2_recons(x, qweight, scales, zeros, g_idx, transpose=False): def matmul4bit(x, qweight, scales, zeros, g_idx=None): # detect if zeros is int32 - if zeros.dtype == torch.int32: + if zeros.dtype != torch.int32: # use v1 if use_new: if auto_switch: