diff --git a/GPTQ-for-LLaMa/autograd_4bit.py b/GPTQ-for-LLaMa/autograd_4bit.py
index fd934a1..34bf9d7 100644
--- a/GPTQ-for-LLaMa/autograd_4bit.py
+++ b/GPTQ-for-LLaMa/autograd_4bit.py
@@ -50,20 +50,20 @@ class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros):
-        ctx.save_for_backward(x, qweight, scales, zeros)
+        ctx.save_for_backward(qweight, scales, zeros)
         output = matmul4bit(x, qweight, scales, zeros).clone()
         return output # equals to torch.matmul(x, qweight)
 
     @staticmethod
     def backward(ctx, grad_output):
-        x, qweight, scales, zeros = ctx.saved_tensors
+        qweight, scales, zeros = ctx.saved_tensors
        # print(grad_output.shape, A.shape, B.shape)
        # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
         grad1 = matmul4bit_transpose(grad_output, qweight, scales, zeros)
-        grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
+        # grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
 
-        return grad1, grad2, None, None
+        return grad1, None, None, None
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
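
The patch treats the quantized weight as frozen: since no gradient is returned for qweight, the input activation x never needs to be saved for backward, which cuts activation memory during training. Below is a minimal, self-contained sketch of the same pattern. FrozenWeightMatmul, the plain torch.matmul stand-ins for matmul4bit / matmul4bit_transpose, and the toy tensors are illustrative assumptions, not part of the patch.

```python
import torch

class FrozenWeightMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        # The weight is frozen, so x is not needed in backward and is not
        # saved -- this mirrors dropping x from save_for_backward above.
        ctx.save_for_backward(weight)
        return torch.matmul(x, weight)

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # Only the gradient w.r.t. x is produced; the weight gets None,
        # mirroring `return grad1, None, None, None` in the patch.
        grad_x = torch.matmul(grad_output, weight.transpose(-1, -2))
        return grad_x, None

# Toy usage: x gets a gradient, the frozen weight does not.
x = torch.randn(2, 8, requires_grad=True)
w = torch.randn(8, 4)  # frozen, requires_grad=False
FrozenWeightMatmul.apply(x, w).sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])
```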