#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
// CAS-based emulation of atomicAdd on half, for the architecture range guarded above.
__device__ __forceinline__ void atomicAddHalf(__half* address, c10::Half val) {
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    c10::Half hsum;
    hsum.x = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum = hsum + val;  // add on the half value, not on its raw bit pattern
    old = reinterpret_cast<size_t>(address) & 2
            ? (old & 0xffff) | (hsum.x << 16)
            : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}
#endif
#endif

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

// BLOCKHEIGHTn = BLOCKWIDTH * n / 32: packed int32 rows covering BLOCKWIDTH
// unpacked rows at n bits per weight.
const int BLOCKWIDTH   = 256;
const int BLOCKHEIGHT2 =  16;
const int BLOCKHEIGHT3 =  24;
const int BLOCKHEIGHT4 =  32;
const int BLOCKHEIGHT8 =  64;

__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

void vecquant2matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT2) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];
  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >>  0) & 0x3) - zero) * blockvec[k +  0];
    res += (scale * scalar_t((tmp >>  2) & 0x3) - zero) * blockvec[k +  1];
    res += (scale * scalar_t((tmp >>  4) & 0x3) - zero) * blockvec[k +  2];
    res += (scale * scalar_t((tmp >>  6) & 0x3) - zero) * blockvec[k +  3];
    res += (scale * scalar_t((tmp >>  8) & 0x3) - zero) * blockvec[k +  4];
    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k +  5];
    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k +  6];
    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k +  7];
    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k +  8];
    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k +  9];
    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
    i += width;
    k += 16;
  }

  atomicAdd(&mul[b * width + w], res);
}
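
// Illustrative sketch (not used by the kernels in this file): the 2-bit path above
// packs 16 weights per 32-bit word and reconstructs each weight as scale * q - zero.
// A host-side reference unpacker under that assumption; the function name and
// signature are hypothetical, for documentation only.
inline float dequant2_reference(unsigned int packed, int idx, float scale, float zero) {
  // idx in [0, 16): position of the 2-bit value inside the packed word
  unsigned int q = (packed >> (2 * idx)) & 0x3;
  return scale * static_cast<float>(q) - zero;
}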

void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT3) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];
  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp1 >>  0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >>  3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >>  6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >>  9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp2 >>  0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >>  3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >>  6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >>  9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp1 >>  0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >>  3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >>  6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >>  9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    k += 10;
  }

  atomicAdd(&mul[b * width + w], res);
}
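
// Illustrative sketch (not used above): how the 3-bit layout walked by
// VecQuant3MatMulKernel can be unpacked on the host. Every three consecutive
// 32-bit words hold 32 three-bit values; the 11th and 22nd values straddle a
// word boundary and are stitched together, mirroring the tmp1/tmp2 handling in
// the kernel. Function name and signature are hypothetical, for documentation only.
inline void unpack3_reference(const unsigned int w[3], unsigned int q[32]) {
  for (int j = 0, bit = 0; j < 32; ++j, bit += 3) {
    int word = bit >> 5;   // 32-bit word the value starts in
    int off  = bit & 31;   // bit offset inside that word
    unsigned int v = w[word] >> off;
    if (off > 29)          // value crosses into the next word
      v |= w[word + 1] << (32 - off);
    q[j] = v & 0x7;
  }
}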

void vecquant4matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];
  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >>  0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >>  4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >>  8) & 0xF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
    i += width;
    k += 8;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}
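
// Layout notes for the launchers above, inferred from the kernel indexing (the
// functions do not validate their inputs):
//   * vec is [batch, vec_height], mul is [batch, width], scales and zeros are [width];
//   * mat is [height, width] of packed int32 with height = vec_height * bits / 32;
//   * mul must be zero-initialized by the caller, since every kernel accumulates
//     its partial dot products into mul with atomicAdd.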

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT8) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];
  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >>  0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >>  8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
    i += width;
    k += 4;
  }

  atomicAdd(&mul[b * width + w], res);
}

template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;
  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  int n_cols = b;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[n_cols * vec_height + w + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;
  int j = w;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scales[j] * scalar_t((tmp >> shift) & 0xF) - zeros[j]) * blockvec[k];
    i += 1;
    j += 1;
    k += 1;
  }

  atomicAdd(&mul[n_cols * height * 8 + n_rows], res);
}

void vecquant4transposematmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4transposematmul_cuda", ([&] {
      VecQuant4TransposeMatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}
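
// Note on the transpose path above: VecQuant4TransposeMatMulKernel assigns one
// thread per unpacked weight row (n_rows = 8 * packed_row + nibble index) and
// accumulates dot products against vec along the width dimension, i.e. it
// computes mul[b, r] += sum_j dequant(mat)[r, j] * vec[b, j]. The output mul is
// therefore [batch, height * 8] (the unpacked row count) and, as with the other
// kernels, must be zero-initialized by the caller.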

template <typename scalar_t>
__global__ void VecQuant4MatMulHalfKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]);
  __syncthreads();

  __half scale = __half(scales[w]);
  __half zero = __half(zeros[w]);
  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >>  0) & 0xF)), zero), blockvec[k + 0]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >>  4) & 0xF)), zero), blockvec[k + 1]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >>  8) & 0xF)), zero), blockvec[k + 2]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 12) & 0xF)), zero), blockvec[k + 3]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 16) & 0xF)), zero), blockvec[k + 4]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 20) & 0xF)), zero), blockvec[k + 5]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 24) & 0xF)), zero), blockvec[k + 6]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 28) & 0xF)), zero), blockvec[k + 7]));
    i += width;
    k += 8;
  }

  __half* mul2 = (__half*)mul;
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
  atomicAddHalf(&mul2[b * width + w], res);
#else
  atomicAdd(&mul2[b * width + w], res);
#endif
#endif
}

void vecquant4matmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4matmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4MatMulHalfKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}

template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulHalfKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;
  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  int n_cols = b;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[n_cols * vec_height + w + threadIdx.x]);
  __syncthreads();

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;
  int j = w;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    __half zero = __half(zeros[j]);
    __half scale = __half(scales[j]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
    i += 1;
    j += 1;
    k += 1;
  }

  __half* mul2 = (__half*)mul;
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
  atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res);
#else
  atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
#endif
#endif
}

void vecquant4transposematmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4transposematmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4TransposeMatMulHalfKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}
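
// Note on the half-precision path: the *HalfKernel variants above accumulate in
// __half and write back with an atomicAdd on __half. On the architecture range
// covered by the guard at the top of this file (__CUDA_ARCH__ strictly between
// 600 and 700) they use the CAS-based atomicAddHalf emulation instead of the
// native half atomicAdd. The half launchers dispatch only on at::ScalarType::Half,
// so vec, mul, scales, and zeros are all expected to be half tensors.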

template <typename scalar_t>
__global__ void VecQuant4ReconsKernel(
    const       int* __restrict__ mat,
           scalar_t* __restrict__ res,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
  int n_rows = h * 8 + b;
  int n_cols = w;
  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];
  int i = width * h + width * (b / 8) + w;
  int shift = b % 8 * 4;
  unsigned int tmp = as_unsigned(mat[i]);
  scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero);
  res[n_rows * width + n_cols] = result;
}

void vecquant4recons_cuda(
  torch::Tensor mat,
  torch::Tensor res,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = BLOCKWIDTH;
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    scales.type(), "vecquant4recons_cuda", ([&] {
      VecQuant4ReconsKernel<scalar_t><<<blocks, threads>>>(
        mat.data<int>(), res.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        height, width
      );
    })
  );
}
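
// How these launchers are typically exposed to Python: through a small C++ binding
// translation unit compiled alongside this file. The sketch below is illustrative
// only; the module name, wrapper names, and file layout are assumptions, not part
// of this file, so it is left as a comment rather than live code.
//
//   // quant_cuda.cpp (hypothetical companion file)
//   #include <torch/extension.h>
//
//   void vecquant4matmul_cuda(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
//                             torch::Tensor scales, torch::Tensor zeros);
//
//   void vecquant4matmul(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
//                        torch::Tensor scales, torch::Tensor zeros) {
//     vecquant4matmul_cuda(vec, mat, mul, scales, zeros);
//   }
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("vecquant4matmul", &vecquant4matmul, "4-bit quantized matmul (CUDA)");
//   }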