This commit is contained in:
John Smith 2023-03-22 00:18:24 +08:00 committed by GitHub
parent 467849d13a
commit a955a1c2a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -131,6 +131,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
def backward(ctx, grad_output):
qweight, scales, zeros = ctx.saved_tensors
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
grad = torch.matmul(grad_output, buffer.T)
return grad, None, None, None
@ -229,4 +230,4 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer