From 3471be4e56546e8a6e0df3f8bfa93411eb7050f0 Mon Sep 17 00:00:00 2001 From: John Smith Date: Tue, 21 Mar 2023 08:43:07 +0000 Subject: [PATCH] add fast_4bit_matmul and auto switch 2 methods according to bottleneck --- GPTQ-for-LLaMa/autograd_4bit.py | 47 +++++++++++++++++-------- GPTQ-for-LLaMa/quant_cuda.cpp | 25 ++++++++++++++ GPTQ-for-LLaMa/quant_cuda_kernel.cu | 53 ++++++++++++++++++++++++++++- 3 files changed, 110 insertions(+), 15 deletions(-) diff --git a/GPTQ-for-LLaMa/autograd_4bit.py b/GPTQ-for-LLaMa/autograd_4bit.py index 09202f7..a246519 100644 --- a/GPTQ-for-LLaMa/autograd_4bit.py +++ b/GPTQ-for-LLaMa/autograd_4bit.py @@ -5,6 +5,19 @@ import torch.nn as nn import time +# Global Buffer +buffer_mat_dic = {} +use_new = True +auto_switch = True +auto_switch_thd = 16 + + +def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'): + if shape_of_qweight not in buffer_mat_dic.keys(): + buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device) + return buffer_mat_dic[shape_of_qweight] + + def matmul4bit(x, qweight, scales, zeros): """ input x: (n, m) @@ -87,6 +100,23 @@ def matmul4bit_transpose_half(x, qweight, scales, zeros): return y.reshape(outshape) +def fast_4bit_forward(x, qweight, scales, zeros, bias): + use_new_flag = use_new + if auto_switch: + if x.shape[1] > auto_switch_thd: + use_new_flag = True + else: + use_new_flag = False + if use_new_flag: + buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device) + quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros) + output = torch.matmul(x, buffer) + else: + output = matmul4bit(x, qweight, scales.float(), zeros.float()) + output += bias + return output + + class AutogradMatmul4bit(torch.autograd.Function): @staticmethod @@ -94,12 +124,7 @@ class AutogradMatmul4bit(torch.autograd.Function): ctx.save_for_backward(qweight, scales, zeros) # equals to torch.matmul(x, qweight) - if x.dtype == torch.float32: - output = matmul4bit(x, qweight, scales, zeros).clone() - elif x.dtype == torch.float16: - output = matmul4bit_half(x, qweight, scales, zeros).clone() - else: - raise ValueError('Only float and half are supportted.') + output = matmul4bit(x, qweight, scales, zeros).clone() return output @@ -108,12 +133,7 @@ class AutogradMatmul4bit(torch.autograd.Function): qweight, scales, zeros = ctx.saved_tensors # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T - if grad_output.dtype == torch.float32: - grad = matmul4bit_transpose(grad_output, qweight, scales, zeros) - elif grad_output.dtype == torch.float16: - grad = matmul4bit_transpose_half(grad_output, qweight, scales, zeros) - else: - raise ValueError('Only float and half are supportted.') + grad = matmul4bit_transpose(grad_output, qweight, scales, zeros) return grad, None, None, None @@ -135,8 +155,7 @@ class Autograd4bitQuantLinear(nn.Module): ) def forward(self, x): - out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros) - out += self.bias + out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias) return out diff --git a/GPTQ-for-LLaMa/quant_cuda.cpp b/GPTQ-for-LLaMa/quant_cuda.cpp index a6524b6..c7086c3 100644 --- a/GPTQ-for-LLaMa/quant_cuda.cpp +++ b/GPTQ-for-LLaMa/quant_cuda.cpp @@ -1,6 +1,7 @@ #include #include #include +#include void vecquant2matmul_cuda( torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, @@ -93,6 +94,28 @@ void vecquant4transposematmul_half( vecquant4transposematmul_half_cuda(vec, mat, mul, scales, 
zeros);
 }
 
+void vecquant4recons_cuda(
+  torch::Tensor mat, torch::Tensor res,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4recons(
+  torch::Tensor mat, torch::Tensor res,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(mat));
+  vecquant4recons_cuda(mat, res, scales, zeros);
+}
+
+at::Tensor fast_4bit_forward(
+  torch::Tensor x, torch::Tensor mat, torch::Tensor buffer,
+  torch::Tensor scales, torch::Tensor zeros, torch::Tensor bias) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(mat));
+  vecquant4recons_cuda(mat, buffer, scales, zeros);
+  auto result = at::addmm(bias, x, buffer);
+  return result;
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
@@ -101,4 +124,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("vecquant4transposematmul", &vecquant4transposematmul, "Vector 4-bit Transpose Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant4matmul_half", &vecquant4matmul_half, "Vector 4-bit Half Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant4transposematmul_half", &vecquant4transposematmul_half, "Vector 4-bit Half Transpose Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant4recons", &vecquant4recons, "Vector 4-bit Matrix Reconstruction (CUDA)");
+  m.def("fast_4bit_forward", &fast_4bit_forward, "Vector 4-bit Fast Forward");
 }
diff --git a/GPTQ-for-LLaMa/quant_cuda_kernel.cu b/GPTQ-for-LLaMa/quant_cuda_kernel.cu
index d82e1d0..0077650 100644
--- a/GPTQ-for-LLaMa/quant_cuda_kernel.cu
+++ b/GPTQ-for-LLaMa/quant_cuda_kernel.cu
@@ -647,4 +647,55 @@ void vecquant4transposematmul_half_cuda(
       );
     })
   ));
-}
\ No newline at end of file
+}
+
+template <typename scalar_t>
+__global__ void VecQuant4ReconsKernel(
+    const int* __restrict__ mat,
+    scalar_t* __restrict__ res,
+    const scalar_t* __restrict__ scales,
+    const scalar_t* __restrict__ zeros,
+    int height,
+    int width
+) {
+  int b = blockIdx.z;
+  int h = BLOCKHEIGHT4 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+  int n_rows = h * 8 + b;
+  int n_cols = w;
+  scalar_t scale = scales[w];
+  scalar_t zero = zeros[w];
+  int i = width * h + width * (b / 8) + w;
+  int shift = b % 8 * 4;
+  unsigned int tmp = as_unsigned(mat[i]);
+  scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero);
+  res[n_rows * width + n_cols] = result;
+}
+
+void vecquant4recons_cuda(
+  torch::Tensor mat,
+  torch::Tensor res,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = BLOCKWIDTH;
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    batch
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    scales.type(), "vecquant4recons_cuda", ([&] {
+      VecQuant4ReconsKernel<<<blocks, threads>>>(
+        mat.data<int>(), res.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<scalar_t>(),
+        height, width
+      );
+    })
+  );
+}
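
Note on the reconstruction kernel: VecQuant4ReconsKernel expands the packed (in_features/8, out_features) int32 qweight into the full (in_features, out_features) buffer returned by get_buffer(), computing scale * q - zero per output column, so the zeros tensor is expected to be pre-scaled. The sketch below is a pure-PyTorch reference for that expansion, useful as a sanity check against the buffer filled by quant.quant_cuda.vecquant4recons; the helper name dequant_4bit_reference and the low-nibble-first unpacking order are inferred from the kernel's shift logic and are not part of this patch.

import torch

def dequant_4bit_reference(qweight, scales, zeros):
    # qweight: (in_features // 8, out_features) int32, eight 4-bit values per int32,
    # lowest nibble first (shift = (row % 8) * 4, mirroring VecQuant4ReconsKernel).
    # scales, zeros: (out_features,); zeros are pre-scaled, so w = scale * q - zero.
    shifts = torch.arange(0, 32, 4, device=qweight.device, dtype=torch.int32)
    q = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF        # (H, 8, W)
    q = q.reshape(qweight.shape[0] * 8, qweight.shape[1])            # (in_features, out_features)
    return scales.unsqueeze(0) * q.to(scales.dtype) - zeros.unsqueeze(0)

If those assumptions hold, torch.allclose(dequant_4bit_reference(qweight, scales, zeros), buffer, atol=1e-3) should pass for float16 scales once vecquant4recons has filled the buffer.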
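
Note on the auto switch: with auto_switch enabled, fast_4bit_forward only takes the reconstruct-then-torch.matmul path when x.shape[1] > auto_switch_thd, i.e. when enough tokens are processed per call to amortize materializing the full fp16 weight; for single-token decoding the original vecquant4matmul kernel path is kept. The default threshold of 16 is a heuristic and the break-even point varies by GPU. The sketch below is one way to time both paths and pick a value; tune_switch_threshold is a hypothetical helper, it assumes the patched autograd_4bit module (and the compiled quant_cuda extension) is importable, and it takes in_features explicitly rather than reading it off the layer.

import time
import torch
import autograd_4bit  # assumes GPTQ-for-LLaMa/ is on sys.path

def tune_switch_threshold(layer, in_features, seq_lens=(1, 2, 4, 8, 16, 32, 64, 128)):
    # Time the original kernel path (use_new=False) against the
    # reconstruct-then-matmul path (use_new=True) for several token counts.
    autograd_4bit.auto_switch = False              # pick the path manually
    results = {}
    for n in seq_lens:
        x = torch.randn(1, n, in_features, dtype=torch.float16, device='cuda')
        timings = {}
        for flag, name in ((False, 'kernel'), (True, 'recons+matmul')):
            autograd_4bit.use_new = flag
            torch.cuda.synchronize()
            start = time.time()
            for _ in range(50):
                layer(x)
            torch.cuda.synchronize()
            timings[name] = (time.time() - start) / 50
        results[n] = timings
    autograd_4bit.auto_switch = True               # restore the default behaviour
    return results

A reasonable auto_switch_thd is then the largest token count for which the kernel path is still faster in the returned timings.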