From ef0a326cec099e3065ea3009850b683c5ebbcc33 Mon Sep 17 00:00:00 2001
From: John Smith
Date: Tue, 21 Mar 2023 09:41:18 +0000
Subject: [PATCH] update autograd

---
 GPTQ-for-LLaMa/autograd_4bit.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/GPTQ-for-LLaMa/autograd_4bit.py b/GPTQ-for-LLaMa/autograd_4bit.py
index a246519..a609ccc 100644
--- a/GPTQ-for-LLaMa/autograd_4bit.py
+++ b/GPTQ-for-LLaMa/autograd_4bit.py
@@ -122,19 +122,16 @@ class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros):
         ctx.save_for_backward(qweight, scales, zeros)
-
-        # equals to torch.matmul(x, qweight)
-        output = matmul4bit(x, qweight, scales, zeros).clone()
-
+        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
+        output = torch.matmul(x, buffer).clone()
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros = ctx.saved_tensors
-
-        # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
-        grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
-
+        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+        grad = torch.matmul(grad_output, buffer.T)
         return grad, None, None, None
 
@@ -155,7 +152,11 @@ class Autograd4bitQuantLinear(nn.Module):
         )
 
     def forward(self, x):
-        out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
+        if torch.is_grad_enabled():
+            out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
+            out += self.bias
+        else:
+            out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
         return out
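
Notes on the change: the patch swaps the fused matmul4bit / matmul4bit_transpose kernels for a reconstruct-then-matmul scheme. vecquant4recons dequantizes the packed 4-bit qweight into a reusable scratch buffer, and the product is then an ordinary torch.matmul, so both forward and backward reduce to dense GEMMs against the reconstructed weights. Note that backward() fetches the buffer but does not refill it, so it appears to assume the buffer still holds this layer's dequantized weights from the forward pass. The sketch below shows the same autograd pattern in pure PyTorch, re-dequantizing in backward so it is self-contained; dequantize() is a hypothetical stand-in for quant.quant_cuda.vecquant4recons that ignores 4-bit packing and buffer caching, so treat it as an illustration of the math, not the repo's implementation.

    import torch

    def dequantize(qweight, scales, zeros):
        # Hypothetical stand-in: treats qweight as one integer per weight.
        # The real CUDA kernel unpacks eight 4-bit values per int32 instead.
        return (qweight.float() - zeros) * scales

    class DequantMatmul(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, qweight, scales, zeros):
            ctx.save_for_backward(qweight, scales, zeros)
            w = dequantize(qweight, scales, zeros)  # reconstruct dense weights
            return x @ w                            # plain matmul, as in the patch

        @staticmethod
        def backward(ctx, grad_output):
            qweight, scales, zeros = ctx.saved_tensors
            w = dequantize(qweight, scales, zeros)  # rebuilt here, not assumed cached
            # d(x @ w)/dx = grad_output @ w.T; qweight, scales and zeros are
            # treated as constants, hence the three Nones.
            return grad_output @ w.T, None, None, None

    # Usage: only x receives a gradient, matching "return grad, None, None, None".
    x = torch.randn(2, 4, requires_grad=True)
    qw = torch.randint(0, 16, (4, 3))
    y = DequantMatmul.apply(x, qw, torch.ones(3), torch.zeros(3))
    y.sum().backward()  # x.grad now holds ones(2, 3) @ w.T

The torch.is_grad_enabled() dispatch in Autograd4bitQuantLinear.forward keeps the custom Function (and its full weight reconstruction) on the training path only, while inference still goes through the fused fast_4bit_forward kernel.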