add half support on cuda kernel
parent 5c1411ff18
commit 5b64833390
@@ -67,10 +67,38 @@ void vecquant4transposematmul(
  vecquant4transposematmul_cuda(vec, mat, mul, scales, zeros);
}

void vecquant4matmul_half_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant4matmul_half(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4matmul_half_cuda(vec, mat, mul, scales, zeros);
}

void vecquant4transposematmul_half_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant4transposematmul_half(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4transposematmul_half_cuda(vec, mat, mul, scales, zeros);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4transposematmul", &vecquant4transposematmul, "Vector 4-bit Transpose Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4matmul_half", &vecquant4matmul_half, "Vector 4-bit Half Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4transposematmul_half", &vecquant4transposematmul_half, "Vector 4-bit Half Transpose Quantized Matrix Multiplication (CUDA)");
}
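
As a point of reference, here is a minimal host-side sketch of how the new half entry points exposed above are expected to be driven. It is not part of the commit: the shapes are illustrative, mul is assumed to be zero-initialized because the kernels accumulate into it with atomicAdd, and mat is assumed to hold eight 4-bit weights per int32, as the kernels below unpack it.

// Illustrative only (not from the commit): exercise the new fp16 binding.
#include <torch/torch.h>

void example_vecquant4matmul_half() {
  const int64_t batch = 1, infeatures = 4096, outfeatures = 4096;  // hypothetical sizes
  auto half_opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto int_opts  = torch::dtype(torch::kInt32).device(torch::kCUDA);

  torch::Tensor vec    = torch::randn({batch, infeatures}, half_opts);
  torch::Tensor mat    = torch::zeros({infeatures / 8, outfeatures}, int_opts); // packed 4-bit weights
  torch::Tensor scales = torch::ones({outfeatures}, half_opts);
  torch::Tensor zeros  = torch::zeros({outfeatures}, half_opts);
  torch::Tensor mul    = torch::zeros({batch, outfeatures}, half_opts); // kernels accumulate into this

  vecquant4matmul_half(vec, mat, mul, scales, zeros);  // declared above; also exposed to Python via m.def
}
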
@@ -2,6 +2,26 @@
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    // pick out the 16 bits that hold the target half and add val in half precision
    c10::Half hsum;
    hsum.x = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum = hsum + val;
    old = reinterpret_cast<size_t>(address) & 2
        ? (old & 0xffff) | (hsum.x << 16)
        : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(

@@ -478,3 +498,153 @@ void vecquant4transposematmul_cuda(
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant4MatMulHalfKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]);
  __syncthreads();

  __half scale = __half(scales[w]);
  __half zero = __half(zeros[w]);

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 0) & 0xF)), zero), blockvec[k + 0]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 4) & 0xF)), zero), blockvec[k + 1]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 8) & 0xF)), zero), blockvec[k + 2]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 12) & 0xF)), zero), blockvec[k + 3]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 16) & 0xF)), zero), blockvec[k + 4]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 20) & 0xF)), zero), blockvec[k + 5]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 24) & 0xF)), zero), blockvec[k + 6]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 28) & 0xF)), zero), blockvec[k + 7]));
    i += width;
    k += 8;
  }

  __half* mul2 = (__half*)mul;
  atomicAdd(&mul2[b * width + w], res);
}

void vecquant4matmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4matmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4MatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}

template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulHalfKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;

  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
  int n_cols = b;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[n_cols * vec_height + w + threadIdx.x]);
  __syncthreads();

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;
  int j = w;
  unsigned int tmp;
  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    __half zero = __half(zeros[j]);
    __half scale = __half(scales[j]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
    i += 1;
    j += 1;
    k += 1;
  }

  __half* mul2 = (__half*)mul;
  atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
}

void vecquant4transposematmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4transposematmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4TransposeMatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}
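
For completeness, a rough CPU-side sketch of the layout these half kernels assume for mat. It mirrors the unpacking above (eight 4-bit weights per 32-bit word, lowest nibble first); the helper name and loop are illustrative and not part of the commit.

// Illustrative packing sketch (not from the commit): row r, column c of the
// quantized weights lands in word r/8 of column c, at nibble r%8, matching
// the (tmp >> 4*n) & 0xF unpacking performed by the kernels above.
#include <cstdint>
#include <vector>

std::vector<uint32_t> pack_4bit(const std::vector<int>& q, int rows, int cols) {
  // q is row-major, rows is a multiple of 8, every value already quantized to 0..15;
  // the result is bit-identical to the int32 tensor the kernels read via as_unsigned().
  std::vector<uint32_t> packed((rows / 8) * cols, 0u);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      packed[(r / 8) * cols + c] |= (static_cast<uint32_t>(q[r * cols + c]) & 0xFu) << ((r % 8) * 4);
  return packed;
}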