From 5b648333901bebca981352a72f58f8459bba0445 Mon Sep 17 00:00:00 2001
From: John Smith
Date: Mon, 20 Mar 2023 09:19:05 +0000
Subject: [PATCH] add half support on cuda kernel

---
 GPTQ-for-LLaMa/quant_cuda.cpp       |  28 +++++
 GPTQ-for-LLaMa/quant_cuda_kernel.cu | 170 ++++++++++++++++++++++++++++
 2 files changed, 198 insertions(+)

diff --git a/GPTQ-for-LLaMa/quant_cuda.cpp b/GPTQ-for-LLaMa/quant_cuda.cpp
index af4e28f..a6524b6 100644
--- a/GPTQ-for-LLaMa/quant_cuda.cpp
+++ b/GPTQ-for-LLaMa/quant_cuda.cpp
@@ -67,10 +67,38 @@ void vecquant4transposematmul(
   vecquant4transposematmul_cuda(vec, mat, mul, scales, zeros);
 }
 
+void vecquant4matmul_half_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4matmul_half(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_half_cuda(vec, mat, mul, scales, zeros);
+}
+
+void vecquant4transposematmul_half_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4transposematmul_half(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4transposematmul_half_cuda(vec, mat, mul, scales, zeros);
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant4transposematmul", &vecquant4transposematmul, "Vector 4-bit Transpose Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant4matmul_half", &vecquant4matmul_half, "Vector 4-bit Half Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant4transposematmul_half", &vecquant4transposematmul_half, "Vector 4-bit Half Transpose Quantized Matrix Multiplication (CUDA)");
 }
diff --git a/GPTQ-for-LLaMa/quant_cuda_kernel.cu b/GPTQ-for-LLaMa/quant_cuda_kernel.cu
index 769fe6f..d82e1d0 100644
--- a/GPTQ-for-LLaMa/quant_cuda_kernel.cu
+++ b/GPTQ-for-LLaMa/quant_cuda_kernel.cu
@@ -2,6 +2,26 @@
 #include <torch/python.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
+__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
+  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
+  unsigned int old = *address_as_ui;
+  unsigned int assumed;
+
+  do {
+    assumed = old;
+    unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
+    hsum = __half_as_ushort(__hadd(__ushort_as_half(hsum), __float2half(static_cast<float>(val))));  // add in half precision on the stored bits
+    old = reinterpret_cast<size_t>(address) & 2
+         ? (old & 0xffff) | (hsum << 16)
+         : (old & 0xffff0000) | hsum;
+    old = atomicCAS(address_as_ui, assumed, old);
+
+  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+  } while (assumed != old);
+}
 
 template <typename scalar_t>
 __global__ void VecQuant2MatMulKernel(
@@ -478,3 +498,153 @@ void vecquant4transposematmul_cuda(
     })
   );
 }
+
+template <typename scalar_t>
+__global__ void VecQuant4MatMulHalfKernel(
+    const scalar_t* __restrict__ vec,
+    const int* __restrict__ mat,
+    scalar_t* __restrict__ mul,
+    const scalar_t* __restrict__ scales,
+    const scalar_t* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width
+) {
+  int b = blockIdx.z;
+  int h = BLOCKHEIGHT4 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ __half blockvec[BLOCKWIDTH];
+  blockvec[threadIdx.x] = __half(vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]);
+  __syncthreads();
+
+  __half scale = __half(scales[w]);
+  __half zero = __half(zeros[w]);
+
+  __half res = __float2half(0.0f);
+  int i = width * h + w;
+  int k = 0;
+
+  unsigned int tmp;
+
+  while (k < BLOCKWIDTH) {
+    tmp = as_unsigned(mat[i]);
+    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 0) & 0xF)), zero), blockvec[k + 0]));
+    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 4) & 0xF)), zero), blockvec[k + 1]));
+    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 8) & 0xF)), zero), blockvec[k + 2]));
+    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 12) & 0xF)), zero), blockvec[k + 3]));
+    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 16) & 0xF)), zero), blockvec[k + 4]));
+    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 20) & 0xF)), zero), blockvec[k + 5]));
+    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 24) & 0xF)), zero), blockvec[k + 6]));
+    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 28) & 0xF)), zero), blockvec[k + 7]));
+    i += width;
+    k += 8;
+  }
+
+  __half* mul2 = (__half*)mul;
+  atomicAdd(&mul2[b * width + w], res);
+}
+
+void vecquant4matmul_half_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    batch
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_SWITCH(vec.type(), "vecquant4matmul_half_cuda",
+    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
+      VecQuant4MatMulHalfKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<scalar_t>(),
+        batch, vec_height, height, width
+      );
+    })
+  ));
+}
+
+template <typename scalar_t>
+__global__ void VecQuant4TransposeMatMulHalfKernel(
+    const scalar_t* __restrict__ vec,
+    const int* __restrict__ mat,
+    scalar_t* __restrict__ mul,
+    const scalar_t* __restrict__ scales,
+    const scalar_t* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width
+) {
+  int b = blockIdx.z;
+  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
+  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
+  int w = BLOCKWIDTH * blockIdx.y;
+
+  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
+  int n_cols = b;
+
+  __shared__ __half blockvec[BLOCKWIDTH];
+  blockvec[threadIdx.x] = __half(vec[n_cols * vec_height + w + threadIdx.x]);
+  __syncthreads();
+
+  __half res = __float2half(0.0f);
+  int i = width * h + w;
+  int k = 0;
+  int j = w;
+  unsigned int tmp;
+  while (k < BLOCKWIDTH) {
+    tmp = as_unsigned(mat[i]);
+    __half zero = __half(zeros[j]);
+    __half scale = __half(scales[j]);
+    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
+    i += 1;
+    j += 1;
+    k += 1;
+  }
+
+  __half* mul2 = (__half*)mul;
+  atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
+}
+
+void vecquant4transposematmul_half_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    batch
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_SWITCH(vec.type(), "vecquant4transposematmul_half_cuda",
+    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
+      VecQuant4TransposeMatMulHalfKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<scalar_t>(),
+        batch, vec_height, height, width
+      );
+    })
+  ));
+}
\ No newline at end of file
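
Usage note (not part of the patch): both new bindings dispatch only on
at::ScalarType::Half and accumulate into `mul` with atomicAdd, so all
floating-point inputs must be fp16 CUDA tensors and `mul` must be
zero-initialized before each call. The sketch below is a hypothetical smoke
test, assuming the extension is built under the name `quant_cuda` (the name
used by the repo's setup script) and using placeholder packed weights; in
practice `mat`, `scales`, and `zeros` come from the existing 4-bit quantizer.
Shapes are inferred from the kernel indexing: each int32 row of `mat` packs
8 nibbles, so `vec` has 8 * mat.size(0) features and feature counts should be
multiples of BLOCKWIDTH.

    import torch
    import quant_cuda  # assumed extension name; adjust to your build

    batch = 4
    infeatures, outfeatures = 4096, 4096

    vec = torch.randn(batch, infeatures, dtype=torch.float16, device="cuda")
    # Placeholder quantized data for illustration only.
    mat = torch.randint(0, 2**31 - 1, (infeatures // 8, outfeatures),
                        dtype=torch.int32, device="cuda")
    scales = torch.rand(outfeatures, dtype=torch.float16, device="cuda")
    zeros = torch.rand(outfeatures, dtype=torch.float16, device="cuda")

    # Output must be pre-zeroed: the kernel adds into it atomically.
    mul = torch.zeros(batch, outfeatures, dtype=torch.float16, device="cuda")
    quant_cuda.vecquant4matmul_half(vec, mat, mul, scales, zeros)

    # Transpose variant (e.g. for a backward pass): output has 8 * mat.size(0) columns.
    grad = torch.randn(batch, outfeatures, dtype=torch.float16, device="cuda")
    mul_t = torch.zeros(batch, infeatures, dtype=torch.float16, device="cuda")
    quant_cuda.vecquant4transposematmul_half(grad, mat, mul_t, scales, zeros)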