reduced memory usage by a little
commit 04f5575a23 (parent 2b84b32fbe)
@@ -50,20 +50,20 @@ class AutogradMatmul4bit(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros):
-        ctx.save_for_backward(x, qweight, scales, zeros)
+        ctx.save_for_backward(qweight, scales, zeros)
         output = matmul4bit(x, qweight, scales, zeros).clone()
         return output # equals to torch.matmul(x, qweight)
 
     @staticmethod
     def backward(ctx, grad_output):
-        x, qweight, scales, zeros = ctx.saved_tensors
+        qweight, scales, zeros = ctx.saved_tensors
         # print(grad_output.shape, A.shape, B.shape)
 
         # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
         grad1 = matmul4bit_transpose(grad_output, qweight, scales, zeros)
-        grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
+        # grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
 
-        return grad1, grad2, None, None
+        return grad1, None, None, None
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
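Why this reduces memory: backward now returns a gradient only for x (grad1), while the gradients for qweight, scales, and zeros stay None. Since x is no longer needed to compute any of the returned gradients, it is dropped from save_for_backward, so the activation can be freed after forward instead of being held in the autograd context until backward. Below is a minimal, self-contained sketch of the same pattern using a plain dense weight in place of the 4-bit kernels; matmul4bit / matmul4bit_transpose are not reproduced here, and the name FrozenLinearMatmul is illustrative, not part of this repository.

import torch

class FrozenLinearMatmul(torch.autograd.Function):
    """Toy analogue of AutogradMatmul4bit: the weight is frozen, so only
    the tensors actually needed by backward are saved."""

    @staticmethod
    def forward(ctx, x, weight):
        # Only the weight is needed to compute the gradient wrt x;
        # x itself is deliberately not saved.
        ctx.save_for_backward(weight)
        return torch.matmul(x, weight)

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # d(x @ W)/dx = grad_output @ W.T; the frozen weight gets no gradient.
        grad_x = torch.matmul(grad_output, weight.transpose(-1, -2))
        return grad_x, None

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 16)
FrozenLinearMatmul.apply(x, w).sum().backward()
print(x.grad.shape)  # torch.Size([4, 8])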