fix bug
This commit is contained in:
parent
467849d13a
commit
a955a1c2a5
|
|
@ -131,6 +131,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
qweight, scales, zeros = ctx.saved_tensors
|
qweight, scales, zeros = ctx.saved_tensors
|
||||||
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
|
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
|
||||||
|
quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
|
||||||
grad = torch.matmul(grad_output, buffer.T)
|
grad = torch.matmul(grad_output, buffer.T)
|
||||||
return grad, None, None, None
|
return grad, None, None, None
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue