optimized matmul for v2 model

John Smith 2023-04-22 23:01:39 +08:00
parent 9fe5ab3642
commit 82bbea2729
2 changed files with 10 additions and 8 deletions

@@ -107,6 +107,7 @@ class Autograd4bitQuantLinear(nn.Module):
         groupsize = groupsize if groupsize != -1 else in_features
         self.groupsize = groupsize
         self.is_v1_model = is_v1_model
+        self.disable_bias = True
         if is_v1_model:
             self.register_buffer('zeros', torch.empty((out_features, 1)))
             self.register_buffer('scales', torch.empty((out_features, 1)))
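
With groupsize left at the -1 default it collapses to in_features, i.e. a single quantization group per output channel, which is exactly what the (out_features, 1) v1 buffers above encode. A toy sketch of that relationship (shapes illustrative, not from the repo):

# Toy sketch: groupsize == in_features (the -1 default) means one
# quantization group per output channel, matching the (out_features, 1)
# buffers registered for the v1 format above.
import torch

in_features, out_features = 4096, 4096
groupsize = -1
groupsize = groupsize if groupsize != -1 else in_features
n_groups = in_features // groupsize             # == 1
scales = torch.empty((out_features, n_groups))  # same shape as the v1 buffer
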
@@ -132,7 +133,8 @@ class Autograd4bitQuantLinear(nn.Module):
         out = matmul4bit_with_backend(x, self.qweight, self.scales,
                                       self.qzeros if not self.is_v1_model else self.zeros,
                                       self.g_idx, self.bits, self.maxq)
-        out += self.bias
+        if not self.disable_bias:
+            out += self.bias
         return out
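
Since disable_bias defaults to True, forward() now skips the bias add entirely; callers have to opt back in. A hedged sketch of how loading code might do that for checkpoints that actually carry a bias; model and state_dict are hypothetical names, not part of this commit:

# Hypothetical loader snippet: re-enable the bias add only where the
# checkpoint actually provides a bias tensor.
import torch

for name, module in model.named_modules():
    if isinstance(module, Autograd4bitQuantLinear):
        key = name + '.bias'
        if key in state_dict:
            module.bias = torch.nn.Parameter(state_dict[key])
            module.disable_bias = False  # forward() adds the bias again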

@@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, g_idx):
     assert qweight.shape[0] * 8 == x.shape[-1]
     outshape = x.shape[:-1] + (qweight.shape[1],)
     x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
     dtype = x.dtype
     if faster:
         x = x.half()
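
Allocating the accumulator in fp16 halves the write traffic and drops the fp32-to-fp16 cast the fast path previously needed on the way out; the likely trade-off is reduced accumulation precision over long reductions. A rough sketch of the buffer sizes involved, with shapes chosen purely for illustration:

# Rough sketch: the fp32 accumulator costs twice the memory traffic of fp16
# for the same (tokens, out_features) output, plus a cast before returning.
import torch

tokens, out_features = 2048, 4096
y32 = torch.zeros((tokens, out_features), dtype=torch.float32)
y16 = torch.zeros((tokens, out_features), dtype=torch.float16)
print(y32.nelement() * y32.element_size())  # 33554432 bytes (32 MiB)
print(y16.nelement() * y16.element_size())  # 16777216 bytes (16 MiB)
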
@@ -114,11 +114,11 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales, zeros)
+                    output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales.half(), zeros.half())
                 else:
-                    output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
+                    output = _matmul4bit_v1(x, qweight, scales, zeros)
             else:
-                output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
+                output = _matmul4bit_v1(x, qweight, scales, zeros)
     else:
         if g_idx is None:
             g_idx = torch.zeros(qweight.shape[0] * 8, dtype=torch.int32, device=x.device)
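
Both v1 branches drop the per-call .float() casts, which were materializing fresh fp32 copies of scales and zeros on every forward pass, while the reconstruction branch now casts to half explicitly. A small sketch of the copy being avoided, assuming the buffers are stored in fp16 (which the explicit .half() calls suggest):

# Sketch of the per-call copy the dropped .float() casts were paying for.
import torch

scales = torch.empty((4096, 1), dtype=torch.float16)
s32 = scales.float()                        # fresh fp32 tensor on every call
assert s32.data_ptr() != scales.data_ptr()  # a real copy, not a view
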
@@ -126,11 +126,11 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, g_idx)
+                    output = _matmul4bit_v2_recons(x.half(), qweight, scales.half(), zeros, g_idx)
                 else:
-                    output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
+                    output = _matmul4bit_v2(x, qweight, scales, zeros, g_idx)
             else:
-                output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
+                output = _matmul4bit_v2(x, qweight, scales, zeros, g_idx)
     return output
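
Taken together with the fp16 accumulator in _matmul4bit_v2, the v2 hot path now stays in half precision end to end: input, scales, kernel output. A hedged smoke test, with the tensors assumed to come from a loaded v2 layer as in the context above:

# Hedged smoke test: after this commit a v2 matmul should come back in fp16
# with no fp32 round-trips. x, qweight, scales, qzeros, g_idx are assumed to
# be a loaded v2 layer's tensors; shapes are illustrative, not verified here.
out = matmul4bit(x.half(), qweight, scales.half(), qzeros, g_idx=g_idx)
assert out.dtype == torch.float16
assert out.shape == x.shape[:-1] + (qweight.shape[1],)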