From 82bbea27294ec3aba419bd0daeb3dadc0639b87c Mon Sep 17 00:00:00 2001
From: John Smith
Date: Sat, 22 Apr 2023 23:01:39 +0800
Subject: [PATCH] optimized matmul for v2 model

---
 autograd_4bit.py     |  4 +++-
 matmul_utils_4bit.py | 14 +++++++-------
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/autograd_4bit.py b/autograd_4bit.py
index 708e134..ceb693c 100644
--- a/autograd_4bit.py
+++ b/autograd_4bit.py
@@ -107,6 +107,7 @@ class Autograd4bitQuantLinear(nn.Module):
         groupsize = groupsize if groupsize != -1 else in_features
         self.groupsize = groupsize
         self.is_v1_model = is_v1_model
+        self.disable_bias = True
         if is_v1_model:
             self.register_buffer('zeros', torch.empty((out_features, 1)))
             self.register_buffer('scales', torch.empty((out_features, 1)))
@@ -132,7 +133,8 @@ class Autograd4bitQuantLinear(nn.Module):
         out = matmul4bit_with_backend(x, self.qweight, self.scales,
                                       self.qzeros if not self.is_v1_model else self.zeros,
                                       self.g_idx, self.bits, self.maxq)
-        out += self.bias
+        if not self.disable_bias:
+            out += self.bias
         return out
 
 
diff --git a/matmul_utils_4bit.py b/matmul_utils_4bit.py
index 16575c6..e010bcc 100644
--- a/matmul_utils_4bit.py
+++ b/matmul_utils_4bit.py
@@ -63,7 +63,7 @@ def _matmul4bit_v2(x, qweight, scales, zeros, g_idx):
     assert qweight.shape[0] * 8 == x.shape[-1]
     outshape = x.shape[:-1] + (qweight.shape[1],)
     x = x.reshape(-1, x.shape[-1])
-    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
+    y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device)
     dtype = x.dtype
     if faster:
         x = x.half()
@@ -114,11 +114,11 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales, zeros)
+                    output = _matmul4bit_v1_recons(x.to(scales.dtype), qweight, scales.half(), zeros.half())
                 else:
-                    output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
+                    output = _matmul4bit_v1(x, qweight, scales, zeros)
         else:
-            output = _matmul4bit_v1(x, qweight, scales.float(), zeros.float())
+            output = _matmul4bit_v1(x, qweight, scales, zeros)
     else:
         if g_idx is None:
             g_idx = torch.zeros(qweight.shape[0] * 8, dtype=torch.int32, device=x.device)
@@ -126,11 +126,11 @@ def matmul4bit(x, qweight, scales, zeros, g_idx=None):
         if use_new:
             if auto_switch:
                 if np.prod(x.shape[:-1]) > auto_switch_thd:
-                    output = _matmul4bit_v2_recons(x.to(scales.dtype), qweight, scales, zeros, g_idx)
+                    output = _matmul4bit_v2_recons(x.half(), qweight, scales.half(), zeros, g_idx)
                 else:
-                    output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
+                    output = _matmul4bit_v2(x, qweight, scales, zeros, g_idx)
         else:
-            output = _matmul4bit_v2(x, qweight, scales.float(), zeros, g_idx)
+            output = _matmul4bit_v2(x, qweight, scales, zeros, g_idx)
     return output
 
 