optimized matmul for v2 model
This commit is contained in:
parent
9fe5ab3642
commit
82bbea2729
|
|
@ -107,6 +107,7 @@ class Autograd4bitQuantLinear(nn.Module):
|
|||
groupsize = groupsize if groupsize != -1 else in_features
|
||||
self.groupsize = groupsize
|
||||
self.is_v1_model = is_v1_model
|
||||
self.disable_bias = True
|
||||
if is_v1_model:
|
||||
self.register_buffer('zeros', torch.empty((out_features, 1)))
|
||||
self.register_buffer('scales', torch.empty((out_features, 1)))
|
||||
|
|
@ -132,7 +133,8 @@ class Autograd4bitQuantLinear(nn.Module):
|
|||
out = matmul4bit_with_backend(x, self.qweight, self.scales,
|
||||
self.qzeros if not self.is_v1_model else self.zeros,
|
||||
self.g_idx, self.bits, self.maxq)
|
||||
out += self.bias
|
||||
if not self.disable_bias:
|
||||
out += self.bias
|
||||
return out
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, g_idx):
|
|||
assert qweight.shape[0] * 8 == x.shape[-1]
|
||||
outshape = x.shape[:-1] + (qweight.shape[1],)
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
|
||||
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
|
||||
dtype = x.dtype
|
||||
if faster:
|
||||
x = x.half()
|
||||
|
|
@ -114,11 +114,11 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
|
|||
if use_new:
|
||||
if auto_switch:
|
||||
if np.prod(x.shape[:-1]) > auto_switch_thd:
|
||||
output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales, zeros)
|
||||
output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales.half(), zeros.half())
|
||||
else:
|
||||
output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
|
||||
output = _matmul4bit_v1(x, qweight, scales, zeros)
|
||||
else:
|
||||
output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
|
||||
output = _matmul4bit_v1(x, qweight, scales, zeros)
|
||||
else:
|
||||
if g_idx is None:
|
||||
g_idx = torch.zeros(qweight.shape[0] * 8, dtype=torch.int32, device=x.device)
|
||||
|
|
@ -126,11 +126,11 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
|
|||
if use_new:
|
||||
if auto_switch:
|
||||
if np.prod(x.shape[:-1]) > auto_switch_thd:
|
||||
output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, g_idx)
|
||||
output = _matmul4bit_v2_recons(x.half(), qweight, scales.half(), zeros, g_idx)
|
||||
else:
|
||||
output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
|
||||
output = _matmul4bit_v2(x, qweight, scales, zeros, g_idx)
|
||||
else:
|
||||
output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
|
||||
output = _matmul4bit_v2(x, qweight, scales, zeros, g_idx)
|
||||
return output
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue