diff --git a/autograd_4bit.py b/autograd_4bit.py index 77befb7..d89f116 100644 --- a/autograd_4bit.py +++ b/autograd_4bit.py @@ -3,15 +3,14 @@ import torch import torch.nn as nn import time import math -from safetensors import safe_open -import numpy as np class AutogradMatmul4bit(torch.autograd.Function): @staticmethod def forward(ctx, x, qweight, scales, zeros, groupsize=-1): - ctx.save_for_backward(qweight, scales, zeros, torch.from_numpy(np.array([groupsize])).cuda()) + ctx.save_for_backward(qweight, scales, zeros) + ctx.groupsize = groupsize if groupsize == -1: output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros) else: @@ -21,8 +20,8 @@ class AutogradMatmul4bit(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): - qweight, scales, zeros, groupsize = ctx.saved_tensors - groupsize = groupsize.cpu().numpy()[0] + qweight, scales, zeros = ctx.saved_tensors + groupsize = ctx.groupsize if groupsize == -1: grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True) else: