Update autograd_4bit.py

This commit is contained in:
John Smith 2023-03-18 22:13:11 +08:00 committed by GitHub
parent 6f4bbb40a9
commit 2b84b32fbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 0 additions and 2 deletions

View File

@ -17,7 +17,6 @@ def matmul4bit(x, qweight, scales, zeros):
assert qweight.shape[0] * 8 == x.shape[-1] assert qweight.shape[0] * 8 == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]]) outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
x = x.reshape(-1, x.shape[-1]) x = x.reshape(-1, x.shape[-1])
assert x.shape[0] % 256 == 0
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device) y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
dtype = x.dtype dtype = x.dtype
x = x.float() x = x.float()
@ -39,7 +38,6 @@ def matmul4bit_transpose(x, qweight, scales, zeros):
assert qweight.shape[1] == x.shape[-1] assert qweight.shape[1] == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8]) outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
x = x.reshape(-1, x.shape[-1]) x = x.reshape(-1, x.shape[-1])
assert x.shape[0] % 256 == 0
y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=torch.float32, device=x.device) y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=torch.float32, device=x.device)
dtype = x.dtype dtype = x.dtype
x = x.float() x = x.float()