From 234004ceb5135e092bc9a08a9dbb75eff61f8fd9 Mon Sep 17 00:00:00 2001 From: John Smith Date: Tue, 28 Mar 2023 22:05:18 +0800 Subject: [PATCH] fix bug --- matmul_utils_4bit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py index b897f70..b476812 100644 --- a/matmul_utils_4bit.py +++ b/matmul_utils_4bit.py @@ -108,9 +108,9 @@ def matmul4bit(x, qweight, scales, zeros, groupsize=-1): if np.prod(x.shape[:-1]) > auto_switch_thd: output = _matmul4bit_v1_recons(x, qweight, scales, zeros) else: - output = _matmul4bit_v1(x, qweight, scales, zeros) + output = _matmul4bit_v1(x, qweight, scales.float(), zeros) else: - output = _matmul4bit_v1(x, qweight, scales, zeros) + output = _matmul4bit_v1(x, qweight, scales.float(), zeros) else: # use v2 if use_new: @@ -118,9 +118,9 @@ def matmul4bit(x, qweight, scales, zeros, groupsize=-1): if np.prod(x.shape[:-1]) > auto_switch_thd: output = _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize) else: - output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize) + output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize) else: - output = _matmul4bit_v2(x, qweight, scales, zeros, groupsize) + output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize) return output