update autograd
commit ef0a326cec
parent 3471be4e56
@@ -122,19 +122,16 @@ class AutogradMatmul4bit(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros):
         ctx.save_for_backward(qweight, scales, zeros)
-        # equals to torch.matmul(x, qweight)
-        output = matmul4bit(x, qweight, scales, zeros).clone()
+        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+        quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
+        output = torch.matmul(x, buffer).clone()
         return output

     @staticmethod
     def backward(ctx, grad_output):
         qweight, scales, zeros = ctx.saved_tensors
-        # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
-        grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
+        buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
+        grad = torch.matmul(grad_output, buffer.T)
         return grad, None, None, None
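get_buffer is not shown in this hunk; from its call sites, it evidently returns a reusable fp16 buffer sized for the dequantized weight, which quant.quant_cuda.vecquant4recons then fills so that forward and backward can run as plain torch.matmul against reconstructed weights. A minimal sketch of what such a helper could look like, assuming a cache keyed by qweight's shape and assuming each packed int32 row holds eight 4-bit values (the names and the 8x unpacking factor are assumptions, not code from this repository):

import torch

_buffer_cache = {}

def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
    # Assumed caching policy: one dequantization buffer per qweight shape.
    # 4-bit values are assumed packed 8-per-int32 along dim 0, so the fp16
    # buffer is 8x taller than the packed qweight.
    if shape_of_qweight not in _buffer_cache:
        _buffer_cache[shape_of_qweight] = torch.zeros(
            (shape_of_qweight[0] * 8, shape_of_qweight[1]),
            dtype=dtype, device=device)
    return _buffer_cache[shape_of_qweight]

Note that the new backward calls get_buffer but not vecquant4recons: it reuses the reconstruction that forward already wrote into the shared buffer for that shape.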
@@ -155,7 +152,11 @@ class Autograd4bitQuantLinear(nn.Module):
         )

     def forward(self, x):
-        out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
+        if torch.is_grad_enabled():
+            out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
+            out += self.bias
+        else:
+            out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
         return out
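Taken together, the two hunks make the layer differentiable when autograd is active while keeping the fused inference path otherwise. A self-contained sketch in plain torch (no dependency on this repository) of the two ideas: dispatching on torch.is_grad_enabled(), and a backward that computes grad_output @ W.T, with a dense W standing in for the dequantized buffer:

import torch

class MatmulRef(torch.autograd.Function):
    # Mirrors AutogradMatmul4bit's structure with a dense weight standing
    # in for the reconstructed 4-bit matrix.
    @staticmethod
    def forward(ctx, x, W):
        ctx.save_for_backward(W)
        return x @ W

    @staticmethod
    def backward(ctx, grad_output):
        (W,) = ctx.saved_tensors
        # For y = x @ W, dL/dx = dL/dy @ W.T; second input gets no grad.
        return grad_output @ W.T, None

W = torch.randn(16, 32)
x = torch.randn(4, 16, requires_grad=True)

if torch.is_grad_enabled():
    y = MatmulRef.apply(x, W)   # training path, as in the new forward()
else:
    y = x @ W                   # stands in for fast_4bit_forward

y.sum().backward()
# loss = y.sum() gives grad_output = ones, so dL/dx = ones @ W.T
assert torch.allclose(x.grad, torch.ones(4, 32) @ W.T)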