Optimized groupsize backward for performance

This commit is contained in:
John Smith 2023-03-29 17:44:51 +08:00
parent 5986649b37
commit 0fdae9224c
1 changed file with 4 additions and 5 deletions

View File

@ -3,15 +3,14 @@ import torch
import torch.nn as nn
import time
import math
from safetensors import safe_open
import numpy as np
class AutogradMatmul4bit(torch.autograd.Function):
@staticmethod
def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
ctx.save_for_backward(qweight, scales, zeros, torch.from_numpy(np.array([groupsize])).cuda())
ctx.save_for_backward(qweight, scales, zeros)
ctx.groupsize = groupsize
if groupsize == -1:
output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
else:
@ -21,8 +20,8 @@ class AutogradMatmul4bit(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
qweight, scales, zeros, groupsize = ctx.saved_tensors
groupsize = groupsize.cpu().numpy()[0]
qweight, scales, zeros = ctx.saved_tensors
groupsize = ctx.groupsize
if groupsize == -1:
grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
else: