Fix cuda kernel for Pascal & Cuda 6/6.1
When I left the other functions to use normal atomic add it seemed like a small speedup. 4.79 it/s vs 5.23 it/s
This commit is contained in:
parent
60b227d0ba
commit
58998acc9f
|
|
@ -4,8 +4,10 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
|
||||
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
|
||||
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
|
||||
__device__ __forceinline__ void atomicAddHalf(__half* address, c10::Half val) {
|
||||
unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
|
@ -22,6 +24,8 @@ __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
|
|||
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
||||
} while (assumed != old);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant2MatMulKernel(
|
||||
|
|
@ -543,7 +547,14 @@ __global__ void VecQuant4MatMulHalfKernel(
|
|||
}
|
||||
|
||||
__half* mul2 = (__half*)mul;
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
|
||||
atomicAddHalf(&mul2[b * width + w], res);
|
||||
#else
|
||||
atomicAdd(&mul2[b * width + w], res);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
void vecquant4matmul_half_cuda(
|
||||
|
|
@ -616,7 +627,13 @@ __global__ void VecQuant4TransposeMatMulHalfKernel(
|
|||
}
|
||||
|
||||
__half* mul2 = (__half*)mul;
|
||||
atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
|
||||
atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res);
|
||||
#else
|
||||
atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
void vecquant4transposematmul_half_cuda(
|
||||
|
|
|
|||
Loading…
Reference in New Issue