diff --git a/GPTQ-for-LLaMa/autograd_4bit.py b/GPTQ-for-LLaMa/autograd_4bit.py index a56bb1c..fd934a1 100644 --- a/GPTQ-for-LLaMa/autograd_4bit.py +++ b/GPTQ-for-LLaMa/autograd_4bit.py @@ -17,7 +17,6 @@ def matmul4bit(x, qweight, scales, zeros): assert qweight.shape[0] * 8 == x.shape[-1] outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]]) x = x.reshape(-1, x.shape[-1]) - assert x.shape[0] % 256 == 0 y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device) dtype = x.dtype x = x.float() @@ -39,7 +38,6 @@ def matmul4bit_transpose(x, qweight, scales, zeros): assert qweight.shape[1] == x.shape[-1] outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8]) x = x.reshape(-1, x.shape[-1]) - assert x.shape[0] % 256 == 0 y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=torch.float32, device=x.device) dtype = x.dtype x = x.float()