add half support on cuda kernel

This commit is contained in:
John Smith 2023-03-20 09:19:05 +00:00
parent 5c1411ff18
commit 5b64833390
2 changed files with 198 additions and 0 deletions

View File

@ -67,10 +67,38 @@ void vecquant4transposematmul(
vecquant4transposematmul_cuda(vec, mat, mul, scales, zeros);
}
void vecquant4matmul_half_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
);
void vecquant4matmul_half(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4matmul_half_cuda(vec, mat, mul, scales, zeros);
}
void vecquant4transposematmul_half_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
);
void vecquant4transposematmul_half(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4transposematmul_half_cuda(vec, mat, mul, scales, zeros);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4transposematmul", &vecquant4transposematmul, "Vector 4-bit Transpose Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4matmul_half", &vecquant4matmul_half, "Vector 4-bit Half Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4transposematmul_half", &vecquant4transposematmul_half, "Vector 4-bit Half Transpose Quantized Matrix Multiplication (CUDA)");
}

View File

@ -2,6 +2,26 @@
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do {
assumed = old;
unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
hsum += val;
old = reinterpret_cast<size_t>(address) & 2
? (old & 0xffff) | (hsum << 16)
: (old & 0xffff0000) | hsum;
old = atomicCAS(address_as_ui, assumed, old);
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
} while (assumed != old);
}
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
@ -478,3 +498,153 @@ void vecquant4transposematmul_cuda(
})
);
}
template <typename scalar_t>
__global__ void VecQuant4MatMulHalfKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
) {
int b = blockIdx.z;
int h = BLOCKHEIGHT4 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ __half blockvec[BLOCKWIDTH];
blockvec[threadIdx.x] = __half(vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]);
__syncthreads();
__half scale = __half(scales[w]);
__half zero = __half(zeros[w]);
__half res = __float2half(0.0f);
int i = width * h + w;
int k = 0;
unsigned int tmp;
while (k < BLOCKWIDTH) {
tmp = as_unsigned(mat[i]);
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 0) & 0xF)), zero), blockvec[k + 0]));
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 4) & 0xF)), zero), blockvec[k + 1]));
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 8) & 0xF)), zero), blockvec[k + 2]));
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 12) & 0xF)), zero), blockvec[k + 3]));
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 16) & 0xF)), zero), blockvec[k + 4]));
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 20) & 0xF)), zero), blockvec[k + 5]));
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 24) & 0xF)), zero), blockvec[k + 6]));
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 28) & 0xF)), zero), blockvec[k + 7]));
i += width;
k += 8;
}
__half* mul2 = (__half*)mul;
atomicAdd(&mul2[b * width + w], res);
}
void vecquant4matmul_half_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
batch
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_SWITCH(vec.type(), "vecquant4matmul_half_cuda",
AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
VecQuant4MatMulHalfKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<scalar_t>(),
batch, vec_height, height, width
);
})
));
}
template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulHalfKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
) {
int b = blockIdx.z;
int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
int w = BLOCKWIDTH * blockIdx.y;
int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
int n_cols = b;
__shared__ __half blockvec[BLOCKWIDTH];
blockvec[threadIdx.x] = __half(vec[n_cols * vec_height + w + threadIdx.x]);
__syncthreads();
__half res = __float2half(0.0f);
int i = width * h + w;
int k = 0;
int j = w;
unsigned int tmp;
while (k < BLOCKWIDTH) {
tmp = as_unsigned(mat[i]);
__half zero = __half(zeros[j]);
__half scale = __half(scales[j]);
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
i += 1;
j += 1;
k += 1;
}
__half* mul2 = (__half*)mul;
atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
}
void vecquant4transposematmul_half_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
batch
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_SWITCH(vec.type(), "vecquant4transposematmul_half_cuda",
AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
VecQuant4TransposeMatMulHalfKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<scalar_t>(),
batch, vec_height, height, width
);
})
));
}