#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAddHalf(__half* address, c10::Half val) {
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    // Extract the 16 bits that hold this half, reinterpret them as a half value,
    // add val, then splice the resulting bits back into the containing 32-bit word.
    unsigned short hbits = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    c10::Half hsum(hbits, c10::Half::from_bits());
    hsum = hsum + val;
    old = reinterpret_cast<size_t>(address) & 2
            ? (old & 0xffff) | (static_cast<unsigned int>(hsum.x) << 16)
            : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}
#endif
#endif
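
// Native atomicAdd on __half requires compute capability 7.0 or higher, so the
// CAS-based fallback above is compiled for 6.1/6.2 (Pascal) devices; the half
// kernels below choose between atomicAddHalf and the hardware atomicAdd at
// compile time with the same __CUDA_ARCH__ guard.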

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
);
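
// One thread block covers BLOCKWIDTH unpacked rows of the weight matrix; since
// b-bit values are packed 32 / b per int32, the number of packed rows handled
// per block is BLOCKHEIGHT{b} = BLOCKWIDTH * b / 32 (e.g. 256 * 4 / 32 = 32 for 4-bit).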
const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;

__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}
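
// Grid layout shared by all launchers below: blockIdx.x tiles the packed weight
// rows, blockIdx.y tiles the output columns, blockIdx.z is the batch index, and
// each block stages BLOCKWIDTH input activations in shared memory before unpacking.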
void vecquant2matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT2) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
    i += width;
    k += 16;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT3) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    k += 10;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant4matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}
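
// Illustrative usage sketch, not part of the original kernels: it shows the
// tensor shapes vecquant4matmul_cuda expects for the 4-bit packing used above.
// The helper name and the 4096x4096 / batch-1 sizes are made up for the example;
// note that `mul` must be zero-initialized because the kernel accumulates into
// it with atomicAdd.
void example_vecquant4matmul_usage() {
  const int batch = 1;
  const int in_features = 4096;   // rows of the unpacked weight matrix
  const int out_features = 4096;  // output columns
  auto fp = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto ip = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);

  torch::Tensor vec    = torch::randn({batch, in_features}, fp);
  // Eight 4-bit values per int32 along dim 0, so the packed height is in_features / 8.
  torch::Tensor mat    = torch::randint(0, 1 << 30, {in_features / 8, out_features}, ip);
  torch::Tensor scales = torch::ones({out_features}, fp);
  torch::Tensor zeros  = torch::zeros({out_features}, fp);
  torch::Tensor mul    = torch::zeros({batch, out_features}, fp);

  vecquant4matmul_cuda(vec, mat, mul, scales, zeros);
  // mul[b][col] now holds sum_row vec[b][row] * (scales[col] * Q[row][col] - zeros[col]).
}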

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
    i += width;
    k += 8;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT8) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
    i += width;
    k += 4;
  }

  atomicAdd(&mul[b * width + w], res);
}
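
// The transpose variants below multiply by the quantized matrix transposed: each
// thread owns one unpacked weight row (threadIdx.x selects the nibble within a
// packed int32) and reduces over BLOCKWIDTH columns of mat, so the output `mul`
// has shape [batch, height * 8] and scales/zeros are indexed per column of mat
// (the reduction axis here).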
template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;

  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  int n_cols = b;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[n_cols * vec_height + w + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int k = 0;
  int j = w;
  unsigned int tmp;
  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res += (scales[j] * scalar_t((tmp >> shift) & 0xF) - zeros[j]) * blockvec[k];
    i += 1;
    j += 1;
    k += 1;
  }

  atomicAdd(&mul[n_cols * height * 8 + n_rows], res);
}

void vecquant4transposematmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4transposematmul_cuda", ([&] {
      VecQuant4TransposeMatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  );
}
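
// Half-precision variants: the same math as the fp32 kernels above, carried out
// in __half with the __hmul/__hadd/__hsub intrinsics, and using the atomicAddHalf
// fallback defined at the top of the file on pre-Volta architectures.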
template <typename scalar_t>
__global__ void VecQuant4MatMulHalfKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]);
  __syncthreads();

  __half scale = __half(scales[w]);
  __half zero = __half(zeros[w]);

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 0) & 0xF)), zero), blockvec[k + 0]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 4) & 0xF)), zero), blockvec[k + 1]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 8) & 0xF)), zero), blockvec[k + 2]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 12) & 0xF)), zero), blockvec[k + 3]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 16) & 0xF)), zero), blockvec[k + 4]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 20) & 0xF)), zero), blockvec[k + 5]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 24) & 0xF)), zero), blockvec[k + 6]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 28) & 0xF)), zero), blockvec[k + 7]));
    i += width;
    k += 8;
  }

  __half* mul2 = (__half*)mul;
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
  atomicAddHalf(&mul2[b * width + w], res);
#else
  atomicAdd(&mul2[b * width + w], res);
#endif
#endif
}

void vecquant4matmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4matmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4MatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}

template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulHalfKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;

  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  int n_cols = b;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[n_cols * vec_height + w + threadIdx.x]);
  __syncthreads();

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;
  int j = w;
  unsigned int tmp;
  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    __half zero = __half(zeros[j]);
    __half scale = __half(scales[j]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
    i += 1;
    j += 1;
    k += 1;
  }

  __half* mul2 = (__half*)mul;
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
  atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res);
#else
  atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
#endif
#endif
}

void vecquant4transposematmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4transposematmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4TransposeMatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}
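
// VecQuant4ReconsKernel dequantizes the packed matrix back to floating point:
// blockIdx.z enumerates the 8 * BLOCKHEIGHT4 = 256 unpacked rows covered by each
// packed chunk (b / 8 selects the int32, b % 8 selects the nibble), producing a
// dense matrix of shape [height * 8, width].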
template <typename scalar_t>
__global__ void VecQuant4ReconsKernel(
    const int* __restrict__ mat,
    scalar_t* __restrict__ res,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
  int n_rows = h * 8 + b;
  int n_cols = w;
  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];
  int i = width * h + width * (b / 8) + w;
  int shift = b % 8 * 4;
  unsigned int tmp = as_unsigned(mat[i]);
  scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero);
  res[n_rows * width + n_cols] = result;
}

void vecquant4recons_cuda(
  torch::Tensor mat,
  torch::Tensor res,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = BLOCKWIDTH;
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    scales.type(), "vecquant4recons_cuda", ([&] {
      VecQuant4ReconsKernel<<<blocks, threads>>>(
        mat.data<int>(), res.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        height, width
      );
    })
  );
}
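
// Illustrative only, not part of the original file: reconstructing a dense fp16
// weight matrix from the packed 4-bit representation. The helper name and sizes
// are hypothetical; shapes follow vecquant4recons_cuda above.
void example_vecquant4recons_usage() {
  const int in_features = 4096;
  const int out_features = 4096;
  auto hp = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);
  auto ip = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);

  torch::Tensor mat    = torch::randint(0, 1 << 30, {in_features / 8, out_features}, ip);
  torch::Tensor scales = torch::ones({out_features}, hp);
  torch::Tensor zeros  = torch::zeros({out_features}, hp);
  torch::Tensor res    = torch::empty({in_features, out_features}, hp);

  vecquant4recons_cuda(mat, res, scales, zeros);
  // res[row][col] == scales[col] * Q[row][col] - zeros[col]
}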