update autograd

John Smith 2023-03-21 09:41:18 +00:00
parent 3471be4e56
commit ef0a326cec
1 changed file with 10 additions and 9 deletions


@@ -122,19 +122,16 @@ class AutogradMatmul4bit(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros):
         ctx.save_for_backward(qweight, scales, zeros)
-        # equals to torch.matmul(x, qweight)
-        output = matmul4bit(x, qweight, scales, zeros).clone()
+        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
+        output = torch.matmul(x, buffer).clone()
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros = ctx.saved_tensors
-        # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
-        grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
+        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+        grad = torch.matmul(grad_output, buffer.T)
         return grad, None, None, None
@@ -155,7 +152,11 @@ class Autograd4bitQuantLinear(nn.Module):
         )
 
     def forward(self, x):
-        out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
+        if torch.is_grad_enabled():
+            out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
+            out += self.bias
+        else:
+            out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
         return out
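The second hunk makes the layer grad-aware: when gradients are enabled it routes through the autograd Function (adding the bias outside the Function, since AutogradMatmul4bit only handles the matmul), and otherwise it falls back to the fused fast_4bit_forward inference path. A hedged sketch of that dispatch shape, with a plain float weight standing in for both 4-bit paths:

import torch
import torch.nn as nn

class DispatchLinearSketch(nn.Module):
    # Illustrative only: mimics the torch.is_grad_enabled() dispatch of
    # Autograd4bitQuantLinear with an ordinary dense weight.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        if torch.is_grad_enabled():
            # Training path: differentiable matmul, bias added separately
            # (mirrors AutogradMatmul4bit.apply(...) followed by += bias).
            out = x @ self.weight.t()
            out += self.bias
        else:
            # Inference path: stands in for fast_4bit_forward, which takes
            # the bias as an argument and fuses it into one call.
            out = torch.addmm(self.bias, x, self.weight.t())
        return out

# Under torch.no_grad() the fused branch runs; otherwise the autograd branch.
layer = DispatchLinearSketch(8, 4)
x = torch.randn(2, 8)
with torch.no_grad():
    y_infer = layer(x)
y_train = layer(x)

One consequence of this design is that inference speed is unchanged, while training gains a correct backward pass through the quantized weights without materializing them as trainable parameters.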