fix bug
parent 1043ded7d9
commit 1719bd0ce3
@@ -106,7 +106,7 @@ def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v1_recons(x, qweight, scales, zeros)
+                    output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales, zeros)
                 else:
                     output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
         else:
@@ -116,7 +116,7 @@ def matmul4bit(x, qweight, scales, zeros, groupsize=-1):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v2_recons(x, qweight, scales, zeros, groupsize)
+                    output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, groupsize)
                 else:
                     output = _matmul4bit_v2(x, qweight, scales.float(), zeros, groupsize)
         else:
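Both hunks make the same change: cast x to scales.dtype before calling the reconstruction-based kernels, so the activation dtype matches the dtype of the dequantized weight. A minimal sketch of the failure mode this likely addresses, assuming the _recons path dequantizes the 4-bit weight into scales.dtype and finishes with torch.matmul; recons_matmul below is a hypothetical stand-in, not the project's API:

import torch

def recons_matmul(x, w_recons):
    # Stand-in for _matmul4bit_v1_recons / _matmul4bit_v2_recons after the
    # weight has been dequantized into scales.dtype.
    return torch.matmul(x, w_recons)

scales = torch.randn(64, dtype=torch.float16)        # fp16 scales -> fp16 reconstruction
w_recons = torch.randn(128, 64, dtype=scales.dtype)  # dequantized weight
x = torch.randn(4, 128, dtype=torch.float32)         # fp32 activations

# torch.matmul(x, w_recons) with fp32 @ fp16 operands typically raises a
# dtype-mismatch RuntimeError; casting x first, as the commit does, avoids it.
y = recons_matmul(x.to(scales.dtype), w_recons)
print(y.dtype)  # torch.float16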