Update autograd_4bit.py
This commit is contained in:
parent
6f4bbb40a9
commit
2b84b32fbe
|
|
@@ -17,7 +17,6 @@ def matmul4bit(x, qweight, scales, zeros):
|
|||
assert qweight.shape[0] * 8 == x.shape[-1]
|
||||
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
assert x.shape[0] % 256 == 0
|
||||
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
|
|
@@ -39,7 +38,6 @@ def matmul4bit_transpose(x, qweight, scales, zeros):
|
|||
assert qweight.shape[1] == x.shape[-1]
|
||||
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
assert x.shape[0] % 256 == 0
|
||||
y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=torch.float32, device=x.device)
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
|
|
|
|||
Loading…
Reference in New Issue