add half (fp16) support to the CUDA kernels
This commit is contained in:
parent 5c1411ff18
commit 5b64833390
@@ -67,10 +67,38 @@ void vecquant4transposematmul(
  vecquant4transposematmul_cuda(vec, mat, mul, scales, zeros);
}

void vecquant4matmul_half_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant4matmul_half(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  // Route the launch to vec's device, then dispatch to the fp16 CUDA kernel.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4matmul_half_cuda(vec, mat, mul, scales, zeros);
}

void vecquant4transposematmul_half_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant4transposematmul_half(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4transposematmul_half_cuda(vec, mat, mul, scales, zeros);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4transposematmul", &vecquant4transposematmul, "Vector 4-bit Transpose Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4matmul_half", &vecquant4matmul_half, "Vector 4-bit Half Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4transposematmul_half", &vecquant4transposematmul_half, "Vector 4-bit Half Transpose Quantized Matrix Multiplication (CUDA)");
}
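
For reference, a minimal host-side smoke test of the new binding could look like this. This is a sketch, not part of the commit: the function names come from the declarations above, while the tile sizes and libtorch calls are assumptions (in particular it presumes BLOCKWIDTH = 256 so one block covers the whole input).

#include <torch/torch.h>

// Hypothetical smoke test for the fp16 path; shapes follow the kernel's
// packing: mat has vec_height / 8 rows of int32, each packing 8 nibbles.
void smoke_test_vecquant4matmul_half() {
  const int batch = 1, vec_height = 256;
  const int height = vec_height / 8, width = 256;
  auto half_opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto int_opts  = torch::dtype(torch::kInt).device(torch::kCUDA);

  auto vec    = torch::randn({batch, vec_height}, half_opts);
  auto mat    = torch::randint(0, 1 << 30, {height, width}, int_opts);
  auto mul    = torch::zeros({batch, width}, half_opts);  // kernel accumulates into mul
  auto scales = torch::ones({width}, half_opts);
  auto zeros  = torch::zeros({width}, half_opts);

  vecquant4matmul_half(vec, mat, mul, scales, zeros);
  torch::cuda::synchronize();
}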
@@ -2,6 +2,26 @@
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
  // The 16-bit half lives inside a 32-bit word; CAS on the aligned word.
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    // Extract the 16 bits that hold *address and add in half precision,
    // reinterpreting the bits as a half rather than converting the integer
    // value (the latter would silently produce garbage sums).
    c10::Half hsum;
    hsum.x = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum = hsum + val;
    old = reinterpret_cast<size_t>(address) & 2
        ? (old & 0xffff) | (static_cast<unsigned int>(hsum.x) << 16)
        : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}

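On compute capability 7.0 and newer, CUDA exposes a native half-precision atomicAdd, so the CAS loop above is strictly a portability fallback. A sketch of how one might gate on it (the wrapper name is hypothetical):

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
// Native fp16 atomic on Volta and newer: one hardware instruction, no retry loop.
__device__ __forceinline__ void atomic_add_half(__half* address, __half val) {
  atomicAdd(address, val);
}
#endif
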
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
@@ -478,3 +498,153 @@ void vecquant4transposematmul_cuda(
  })
  );
}

template <typename scalar_t>
__global__ void VecQuant4MatMulHalfKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage this block's slice of the activation vector in shared memory.
  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]);
  __syncthreads();

  __half scale = __half(scales[w]);
  __half zero = __half(zeros[w]);

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;

  unsigned int tmp;

  // Each int32 packs eight 4-bit weights; dequantize as scale * q - zero
  // and accumulate against eight consecutive activations.
  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >>  0) & 0xF)), zero), blockvec[k + 0]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >>  4) & 0xF)), zero), blockvec[k + 1]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >>  8) & 0xF)), zero), blockvec[k + 2]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 12) & 0xF)), zero), blockvec[k + 3]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 16) & 0xF)), zero), blockvec[k + 4]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 20) & 0xF)), zero), blockvec[k + 5]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 24) & 0xF)), zero), blockvec[k + 6]));
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 28) & 0xF)), zero), blockvec[k + 7]));
    i += width;
    k += 8;
  }

  // Blocks along the height dimension race on the same output element,
  // hence the fp16 atomicAdd defined above.
  __half* mul2 = (__half*)mul;
  atomicAdd(&mul2[b * width + w], res);
}
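
The unrolled loop above is easier to audit against a scalar reference. A CPU sketch of what one packed word contributes (not in the commit; float accumulation and the helper name are illustrative):

// Dequantize-and-dot for one int32 word: nibble n holds weight q_n, and the
// kernel computes sum_n (scale * q_n - zero) * v_n in half precision.
static inline float dequant_dot8(unsigned int word, float scale, float zero,
                                 const float* v /* 8 activations */) {
  float acc = 0.0f;
  for (int n = 0; n < 8; ++n) {
    int q = (int)((word >> (4 * n)) & 0xF);
    acc += (scale * (float)q - zero) * v[n];
  }
  return acc;
}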

void vecquant4matmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4matmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4MatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}
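
A worked example of the launch geometry, assuming the tile sizes used elsewhere in this file are BLOCKWIDTH = 256 and BLOCKHEIGHT4 = 32 (an assumption, not confirmed by this hunk): for vec of shape (4, 4096), mat is (512, 11008) since 512 = 4096 / 8, giving blocks = dim3(16, 43, 4) with 256 threads per block.

// Compile-time sanity check of the example above, under the assumed tile sizes.
constexpr int kBlockWidth = 256, kBlockHeight4 = 32;
static_assert((4096 / 8) / kBlockHeight4 == 16, "grid.x for vec_height = 4096");
static_assert((11008 + kBlockWidth - 1) / kBlockWidth == 43, "grid.y for width = 11008");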

template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulHalfKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const scalar_t* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width
) {
  int b = blockIdx.z;
  // Groups of 8 threads share one packed row; each thread owns one nibble.
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
  int w = BLOCKWIDTH * blockIdx.y;

  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;  // unpacked output row
  int n_cols = b;

  __shared__ __half blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = __half(vec[n_cols * vec_height + w + threadIdx.x]);
  __syncthreads();

  __half res = __float2half(0.0f);
  int i = width * h + w;
  int k = 0;
  int j = w;
  unsigned int tmp;
  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    __half zero = __half(zeros[j]);
    __half scale = __half(scales[j]);
    res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
    i += 1;
    j += 1;
    k += 1;
  }

  // Output is (batch, height * 8): the unpacked row count of mat.
  __half* mul2 = (__half*)mul;
  atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
}
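
Because the transpose path writes to the unpacked row index, a CPU reference makes the index mapping explicit. A sketch (hypothetical helper, float accumulation; not part of the commit):

// Reference for the transpose path: row r of the logical (height*8, width)
// weight matrix lives in packed row r / 8, nibble r % 8.
static void ref_vecquant4transposematmul(
    const float* vec,     // (batch, width)
    const int* mat,       // (height, width), packed along dim 0
    float* out,           // (batch, height * 8), zero-initialized
    const float* scales, const float* zeros,  // (width)
    int batch, int height, int width) {
  for (int b = 0; b < batch; ++b)
    for (int r = 0; r < height * 8; ++r) {
      float acc = 0.0f;
      for (int c = 0; c < width; ++c) {
        unsigned int word = (unsigned int)mat[(r / 8) * width + c];
        int q = (int)((word >> ((r % 8) * 4)) & 0xF);
        acc += (scales[c] * (float)q - zeros[c]) * vec[b * width + c];
      }
      out[b * (height * 8) + r] = acc;
    }
}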

void vecquant4transposematmul_half_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_SWITCH(vec.type(), "vecquant4transposematmul_half_cuda",
    AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
      VecQuant4TransposeMatMulHalfKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<scalar_t>(),
        batch, vec_height, height, width
      );
    })
  ));
}