From 58998acc9fd50b1e3e8068f4afb867467dc84727 Mon Sep 17 00:00:00 2001
From: Forkoz <59298527+Ph0rk0z@users.noreply.github.com>
Date: Thu, 23 Mar 2023 07:33:57 -0500
Subject: [PATCH] Fix CUDA kernel for Pascal & compute capability 6/6.1

Leaving the other functions on the normal atomicAdd seemed to give a small
speedup: 4.79 it/s vs. 5.23 it/s.
---
 GPTQ-for-LLaMa/quant_cuda_kernel.cu | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/GPTQ-for-LLaMa/quant_cuda_kernel.cu b/GPTQ-for-LLaMa/quant_cuda_kernel.cu
index 0077650..de0c0d6 100644
--- a/GPTQ-for-LLaMa/quant_cuda_kernel.cu
+++ b/GPTQ-for-LLaMa/quant_cuda_kernel.cu
@@ -4,8 +4,10 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
 // adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
-__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
+__device__ __forceinline__ void atomicAddHalf(__half* address, c10::Half val) {
   unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
   unsigned int old = *address_as_ui;
   unsigned int assumed;
@@ -22,6 +24,8 @@ __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
     // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
   } while (assumed != old);
 }
+#endif
+#endif
 
 template <typename scalar_t>
 __global__ void VecQuant2MatMulKernel(
@@ -543,7 +547,14 @@ __global__ void VecQuant4MatMulHalfKernel(
   }
 
   __half* mul2 = (__half*)mul;
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
+  atomicAddHalf(&mul2[b * width + w], res);
+#else
   atomicAdd(&mul2[b * width + w], res);
+#endif
+#endif
+
 }
 
 void vecquant4matmul_half_cuda(
@@ -616,7 +627,13 @@ __global__ void VecQuant4TransposeMatMulHalfKernel(
   }
 
   __half* mul2 = (__half*)mul;
-  atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
+  atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res);
+#else
+  atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
+#endif
+#endif
 }
 
 void vecquant4transposematmul_half_cuda(
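
For reference, below is a minimal self-contained sketch of the technique the patch applies: a CAS-based emulation of atomicAdd on a __half, compiled only where the native atomicAdd(__half*, __half) overload is unavailable (it requires compute capability 7.0+), with a __CUDA_ARCH__ guard choosing between the two at the call site. This is not part of the patch; the names atomic_add_half_fallback and accumulate_kernel are illustrative only and do not appear in quant_cuda_kernel.cu.

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Emulate atomicAdd on a 16-bit __half by doing a 32-bit atomicCAS on the
// aligned word that contains it. Needed because atomicAdd(__half*, __half)
// only exists natively on compute capability 7.0 and newer.
__device__ __forceinline__ void atomic_add_half_fallback(__half* address, __half val) {
  // Round the address down to the containing 4-byte word.
  unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(
      reinterpret_cast<char*>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    // Extract the 16 bits holding our half, add in fp32, splice the result back.
    unsigned short raw = (reinterpret_cast<size_t>(address) & 2) ? (old >> 16)
                                                                 : (old & 0xffff);
    float sum = __half2float(__ushort_as_half(raw)) + __half2float(val);
    unsigned short raw_sum = __half_as_ushort(__float2half(sum));
    old = (reinterpret_cast<size_t>(address) & 2)
              ? (old & 0x0000ffffu) | (static_cast<unsigned int>(raw_sum) << 16)
              : (old & 0xffff0000u) | raw_sum;
    // Integer compare-and-swap; comparing raw bits avoids hanging on NaN.
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}

// Each thread atomically adds 1.0 to a single __half accumulator.
__global__ void accumulate_kernel(__half* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
  atomic_add_half_fallback(out, __float2half(1.0f));  // Pascal and older: CAS emulation
#else
  atomicAdd(out, __float2half(1.0f));                 // sm_70+: native half atomicAdd
#endif
}

int main() {
  __half* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(__half));
  cudaMemset(d_out, 0, sizeof(__half));  // 0x0000 is +0.0 in fp16
  accumulate_kernel<<<4, 256>>>(d_out, 1024);
  __half h_out;
  cudaMemcpy(&h_out, d_out, sizeof(__half), cudaMemcpyDeviceToHost);
  printf("sum = %f (expected 1024)\n", __half2float(h_out));
  cudaFree(d_out);
  return 0;
}

Compiled with something like "nvcc -arch=sm_61 half_atomic_sketch.cu" (filename arbitrary) the fallback path is taken; with -arch=sm_70 or newer the native atomicAdd overload is used instead, which is the same selection the patch performs inside the quantized matmul kernels.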