reduced memory usage slightly

John Smith 2023-03-20 00:51:52 +08:00 committed by GitHub
parent 2b84b32fbe
commit 04f5575a23
1 changed file with 4 additions and 4 deletions

@@ -50,20 +50,20 @@ class AutogradMatmul4bit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, qweight, scales, zeros):
-        ctx.save_for_backward(x, qweight, scales, zeros)
+        ctx.save_for_backward(qweight, scales, zeros)
         output = matmul4bit(x, qweight, scales, zeros).clone()
         return output # equals to torch.matmul(x, qweight)
 
     @staticmethod
     def backward(ctx, grad_output):
-        x, qweight, scales, zeros = ctx.saved_tensors
+        qweight, scales, zeros = ctx.saved_tensors
         # print(grad_output.shape, A.shape, B.shape)
         # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
         grad1 = matmul4bit_transpose(grad_output, qweight, scales, zeros)
-        grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
-        return grad1, grad2, None, None
+        # grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
+        return grad1, None, None, None
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
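The change works because qweight, scales, and zeros describe a frozen quantized weight: autograd never needs a gradient for them, so grad2 was dead code and x no longer has to be kept alive between forward and backward. The sketch below illustrates that pattern with a plain float weight; FrozenLinearFn, its shapes, and the weight layout are illustrative stand-ins, not the repository's 4-bit kernels.

import torch

class FrozenLinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        # Only the weight is needed to form dL/dx; leaving x out of
        # save_for_backward is what frees the activation memory.
        ctx.save_for_backward(weight)
        return x @ weight

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        grad_x = grad_output @ weight.transpose(-1, -2)
        # No gradient for the frozen weight, mirroring
        # `return grad1, None, None, None` in the commit.
        return grad_x, None

x = torch.randn(2, 8, requires_grad=True)
w = torch.randn(8, 4)  # frozen: requires_grad defaults to False
y = FrozenLinearFn.apply(x, w)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])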