optimized groupsize backward for performance
This commit is contained in:
parent
5986649b37
commit
0fdae9224c
|
|
@ -3,15 +3,14 @@ import torch
|
|||
import torch.nn as nn
|
||||
import time
|
||||
import math
|
||||
from safetensors import safe_open
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AutogradMatmul4bit(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
|
||||
ctx.save_for_backward(qweight, scales, zeros, torch.from_numpy(np.array([groupsize])).cuda())
|
||||
ctx.save_for_backward(qweight, scales, zeros)
|
||||
ctx.groupsize = groupsize
|
||||
if groupsize == -1:
|
||||
output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
|
||||
else:
|
||||
|
|
@ -21,8 +20,8 @@ class AutogradMatmul4bit(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
qweight, scales, zeros, groupsize = ctx.saved_tensors
|
||||
groupsize = groupsize.cpu().numpy()[0]
|
||||
qweight, scales, zeros = ctx.saved_tensors
|
||||
groupsize = ctx.groupsize
|
||||
if groupsize == -1:
|
||||
grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue