// alpaca_lora_4bit/GPTQ-for-LLaMa/quant_cuda_kernel.cu

#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
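
// Batched quantized matrix-vector product kernels for GPTQ-packed weights.
// Shared conventions, as read from the kernel bodies below:
//   vec    : (batch, vec_height) activations.
//   mat    : (height, width) int32 words. Weights are packed along the row
//            (input) dimension, 32/bits values per word for 2/4/8-bit; the
//            3-bit layout lets values straddle word boundaries.
//   mul    : (batch, width) output for the VecQuant{2,3,4,8}MatMulKernel
//            family, accumulated with atomicAdd because blockIdx.x splits the
//            reduction over rows. The transpose/recons variants use different
//            output layouts, documented at their definitions.
//   scales : per-output-column scale.
//   zeros  : per-output-column offset; a packed value q dequantizes as
//            scale * q - zero, so zeros is expected to be stored already
//            multiplied by the scale.
// There are no bounds checks: the launch code assumes tensor dimensions are
// multiples of the BLOCKWIDTH / BLOCKHEIGHT* tile sizes defined below.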
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= 600 && __CUDA_ARCH__ < 700
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
// CAS-based atomicAdd for __half: the native __half overload of atomicAdd()
// requires sm_70, so pre-Volta devices fall back to this emulation.
__device__ __forceinline__ void atomicAddHalf(__half* address, c10::Half val) {
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    // Extract the 16 bits holding the target half, add in fp32, and repack.
    __half_raw hraw;
    hraw.x = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    float sum = __half2float(__half(hraw)) + static_cast<float>(val);
    hraw = __half_raw(__float2half(sum));
    old = reinterpret_cast<size_t>(address) & 2
      ? (old & 0xffff) | (hraw.x << 16)
      : (old & 0xffff0000) | hraw.x;
    old = atomicCAS(address_as_ui, assumed, old);
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}
#endif
#endif
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
);
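
// Tile sizes. A thread block reduces over BLOCKWIDTH consecutive input
// elements, which occupy BLOCKHEIGHT{b} packed int32 rows of `mat` for a
// b-bit kernel, since BLOCKHEIGHT{b} = BLOCKWIDTH * b / 32
// (e.g. 256 * 4 / 32 = 32 for the 4-bit kernels).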
const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;

// Reinterpret a packed int32 as unsigned so the right shifts below are
// logical shifts with no sign extension.
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}
void vecquant2matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}
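
// 2-bit kernel: one thread per (batch row b, output column w). Each int32 of
// `mat` packs 16 two-bit weights for 16 consecutive inputs of that column,
// unpacked as q_j = (word >> (2 * j)) & 0x3 for j = 0..15, so the loop walks
// BLOCKHEIGHT2 = 16 words to cover the BLOCKWIDTH = 256 inputs staged in
// shared memory.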
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT2) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
    i += width;
    k += 16;
  }

  atomicAdd(&mul[b * width + w], res);
}
void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}
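
// 3-bit kernel: 3-bit values do not align to 32-bit words, so each loop
// iteration consumes three words (32 weights): 10 whole values from the
// first word, one value split across the word boundary, 10 from the second
// word, another split value, then 10 from the third.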
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT3) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    tmp2 = as_unsigned(mat[i]);
    // Value k + 10 straddles the boundary: top 2 bits of tmp1, low bit of tmp2.
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    tmp1 = as_unsigned(mat[i]);
    // Value k + 10 straddles the boundary: top bit of tmp2, low 2 bits of tmp1.
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    k += 10;
  }

  atomicAdd(&mul[b * width + w], res);
}
void vecquant4matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}
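
// 4-bit kernel: each int32 packs 8 nibbles, q_j = (word >> (4 * j)) & 0xF,
// so the loop unpacks 8 weights per word across BLOCKHEIGHT4 = 32 words.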
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
    i += width;
    k += 8;
  }

  atomicAdd(&mul[b * width + w], res);
}
void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}
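
// 8-bit kernel: each int32 packs 4 bytes, q_j = (word >> (8 * j)) & 0xFF,
// so the loop unpacks 4 weights per word across BLOCKHEIGHT8 = 64 words.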
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT8) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
    i += width;
    k += 4;
  }

  atomicAdd(&mul[b * width + w], res);
}
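
// Transposed 4-bit kernel: computes vec @ dequant(mat)^T. Each thread owns
// one logical (unpacked) weight row n_rows, i.e. packed row h and nibble
// threadIdx.x % 8, and reduces over a BLOCKWIDTH chunk of columns, so scales
// and zeros are re-read per column j inside the loop. Results accumulate
// into mul laid out as (batch, height * 8).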
template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;
  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  int n_cols = b;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[n_cols * vec_height + w + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;
  int j = w;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scales[j] * scalar_t((tmp >> shift) & 0xF) - zeros[j]) * blockvec[k];
    i += 1;
    j += 1;
    k += 1;
  }

  atomicAdd(&mul[n_cols * height * 8 + n_rows], res);
}
void vecquant4transposematmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4transposematmul_cuda", ([&] {
      VecQuant4TransposeMatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}
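
// Half-precision variant of VecQuant4MatMulKernel: the dot product is done
// in __half with half intrinsics, and the final accumulation uses the
// CAS-based atomicAddHalf above on pre-sm_70 devices, where atomicAdd() has
// no __half overload.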
template <typename scalar_t>
__global__ void VecQuant4MatMulHalfKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]);
  __syncthreads();

  __half scale = __half(scales[w]);
  __half zero = __half(zeros[w]);

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 0) & 0xF)), zero), blockvec[k + 0]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 4) & 0xF)), zero), blockvec[k + 1]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 8) & 0xF)), zero), blockvec[k + 2]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 12) & 0xF)), zero), blockvec[k + 3]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 16) & 0xF)), zero), blockvec[k + 4]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 20) & 0xF)), zero), blockvec[k + 5]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 24) & 0xF)), zero), blockvec[k + 6]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 28) & 0xF)), zero), blockvec[k + 7]));
    i += width;
    k += 8;
  }

  __half* mul2 = (__half*)mul;
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= 600 && __CUDA_ARCH__ < 700
  // No native __half atomicAdd below sm_70; use the CAS-based fallback.
  atomicAddHalf(&mul2[b * width + w], res);
#else
  atomicAdd(&mul2[b * width + w], res);
#endif
#endif
}
void vecquant4matmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4matmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4MatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}
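
// Half-precision variant of VecQuant4TransposeMatMulKernel, with the same
// row ownership scheme as the float version and the same pre-sm_70
// atomicAdd fallback.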
template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulHalfKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;
  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  int n_cols = b;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[n_cols * vec_height + w + threadIdx.x]);
  __syncthreads();

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;
  int j = w;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    __half zero = __half(zeros[j]);
    __half scale = __half(scales[j]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
    i += 1;
    j += 1;
    k += 1;
  }

  __half* mul2 = (__half*)mul;
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= 600 && __CUDA_ARCH__ < 700
  // No native __half atomicAdd below sm_70; use the CAS-based fallback.
  atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res);
#else
  atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
#endif
#endif
}
void vecquant4transposematmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4transposematmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4TransposeMatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}
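
// Reconstruction kernel: dequantizes the packed 4-bit matrix back to full
// precision. blockIdx.z selects one of the BLOCKHEIGHT4 * 8 = 256 unpacked
// rows covered by a (blockIdx.x, blockIdx.y) tile, and each thread writes a
// single element of the (height * 8, width) result.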
template <typename scalar_t>
__global__ void VecQuant4ReconsKernel(
  const int* __restrict__ mat,
  scalar_t* __restrict__ res,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int height,
  int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
  int n_rows = h * 8 + b;
  int n_cols = w;
  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];
  // Unpacked row n_rows lives in packed word row h + b / 8, nibble b % 8.
  int i = width * h + width * (b / 8) + w;
  int shift = b % 8 * 4;
  unsigned int tmp = as_unsigned(mat[i]);
  scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero);
  res[n_rows * width + n_cols] = result;
}
void vecquant4recons_cuda(
  torch::Tensor mat,
  torch::Tensor res,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = BLOCKWIDTH;
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    scales.type(), "vecquant4recons_cuda", ([&] {
      VecQuant4ReconsKernel<<<blocks, threads>>>(
        mat.data<int>(), res.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        height, width
      );
    })
  );
}