optimized groupsize backward for performance
parent 5986649b37
commit 0fdae9224c
@@ -3,15 +3,14 @@ import torch
 import torch.nn as nn
 import time
 import math
-from safetensors import safe_open
-import numpy as np


 class AutogradMatmul4bit(torch.autograd.Function):

     @staticmethod
     def forward(ctx, x, qweight, scales, zeros, groupsize=-1):
-        ctx.save_for_backward(qweight, scales, zeros, torch.from_numpy(np.array([groupsize])).cuda())
+        ctx.save_for_backward(qweight, scales, zeros)
+        ctx.groupsize = groupsize
         if groupsize == -1:
             output = mm4b._matmul4bit_v1_recons(x, qweight, scales, zeros)
         else:
@@ -21,8 +20,8 @@ class AutogradMatmul4bit(torch.autograd.Function):

     @staticmethod
     def backward(ctx, grad_output):
-        qweight, scales, zeros, groupsize = ctx.saved_tensors
-        groupsize = groupsize.cpu().numpy()[0]
+        qweight, scales, zeros = ctx.saved_tensors
+        groupsize = ctx.groupsize
         if groupsize == -1:
             grad = mm4b._matmul4bit_v1_recons(grad_output, qweight, scales, zeros, transpose=True)
         else:
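The change above removes a needless tensor round trip: the old forward packed the Python int groupsize into a NumPy array, wrapped it in a tensor, and copied it to the GPU just so save_for_backward could carry it, and the old backward then synced it back to the host with .cpu().numpy()[0]. save_for_backward is only needed for tensors that participate in autograd; plain Python values can be stashed directly as attributes on ctx. A minimal sketch of that pattern follows; ScaledMatmul and its arguments are hypothetical names for illustration, not part of this repository.

import torch

class ScaledMatmul(torch.autograd.Function):
    """Illustrative Function using the ctx-attribute pattern."""

    @staticmethod
    def forward(ctx, x, weight, scale=1.0):
        # Tensors needed in backward go through save_for_backward,
        # which tracks them correctly for autograd.
        ctx.save_for_backward(weight)
        # Plain Python values (ints, floats, flags) are stored directly
        # on ctx; no tensor allocation or host/device copies are needed.
        ctx.scale = scale
        return (x @ weight) * scale

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        scale = ctx.scale  # read the plain attribute back, no .cpu() sync
        grad_x = (grad_output * scale) @ weight.t()
        # No gradients for weight (treated as a constant here) or scale.
        return grad_x, None, None

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 3)
y = ScaledMatmul.apply(x, w, 0.5)
y.sum().backward()

Besides skipping the host-to-device copy in forward, this also avoids the device-to-host synchronization point in backward, which is where the performance win in this commit comes from.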