add fast_4bit_matmul and auto-switch between the two methods according to the bottleneck

commit 3471be4e56 (parent dd0d5a31f7)
@@ -5,6 +5,19 @@ import torch.nn as nn
 import time


+# Global Buffer
+buffer_mat_dic = {}
+use_new = True
+auto_switch = True
+auto_switch_thd = 16
+
+
+def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
+    if shape_of_qweight not in buffer_mat_dic.keys():
+        buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
+    return buffer_mat_dic[shape_of_qweight]
+
+
 def matmul4bit(x, qweight, scales, zeros):
     """
     input x: (n, m)
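Reconstructing the fp16 weight needs a scratch matrix of shape (qweight.shape[0] * 8, qweight.shape[1]), so the commit caches one buffer per packed-weight shape in a module-level dict and reuses it on every forward. Below is a minimal, CPU-only sketch of the same caching idea; the names are invented for illustration, and the real get_buffer allocates on CUDA in the scales' dtype.

```python
import torch

# Hypothetical stand-in for the module-level cache above; keyed by the packed
# weight's shape so each distinct layer size allocates its buffer exactly once.
_buffers = {}

def get_recons_buffer(packed_shape, dtype=torch.float16, device='cpu'):
    # Eight 4-bit values are packed into each int32 of qweight, so the
    # reconstructed matrix has 8x as many rows as the packed one.
    if packed_shape not in _buffers:
        rows, cols = packed_shape
        _buffers[packed_shape] = torch.zeros((rows * 8, cols), dtype=dtype, device=device)
    return _buffers[packed_shape]

# Repeated calls with the same packed shape reuse the same allocation.
a = get_recons_buffer((128, 4096))
b = get_recons_buffer((128, 4096))
assert a.data_ptr() == b.data_ptr() and a.shape == (1024, 4096)
```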
@ -87,6 +100,23 @@ def matmul4bit_transpose_half(x, qweight, scales, zeros):
|
|||
return y.reshape(outshape)
|
||||
|
||||
|
||||
def fast_4bit_forward(x, qweight, scales, zeros, bias):
|
||||
use_new_flag = use_new
|
||||
if auto_switch:
|
||||
if x.shape[1] > auto_switch_thd:
|
||||
use_new_flag = True
|
||||
else:
|
||||
use_new_flag = False
|
||||
if use_new_flag:
|
||||
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
|
||||
quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
|
||||
output = torch.matmul(x, buffer)
|
||||
else:
|
||||
output = matmul4bit(x, qweight, scales.float(), zeros.float())
|
||||
output += bias
|
||||
return output
|
||||
|
||||
|
||||
class AutogradMatmul4bit(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
|
|
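fast_4bit_forward picks a path by comparing x.shape[1] (the sequence length for the usual batch x seq x hidden activations) against auto_switch_thd: long inputs such as prompt processing go through reconstruct-then-torch.matmul, which is presumably compute-bound and benefits from a dense GEMM, while short single-token decode steps stay on the fused quantized kernel, which avoids materializing the fp16 weight. A dependency-free sketch of that dispatch follows; the two paths are stand-in lambdas here, not the real kernels, and both return the same values since the paths differ only in speed.

```python
import torch

AUTO_SWITCH_THD = 16  # same default threshold as auto_switch_thd above

def dispatch_4bit_matmul(x, dequant_matmul, fused_matmul, threshold=AUTO_SWITCH_THD):
    """Illustrative dispatcher: dequant_matmul / fused_matmul stand in for the
    reconstruct-then-GEMM path and the fused quantized kernel, respectively."""
    if x.shape[1] > threshold:
        # Many tokens at once (e.g. prompt processing): GEMM on reconstructed
        # fp16 weights keeps the GPU busy with dense math.
        return dequant_matmul(x)
    # Few tokens (e.g. single-token decoding): the fused kernel skips the
    # full-weight reconstruction, which dominates at small sequence lengths.
    return fused_matmul(x)

# Toy usage with a plain fp32 weight standing in for the quantized layer.
w = torch.randn(64, 32)
dequant_path = lambda t: t @ w   # stands in for: vecquant4recons + torch.matmul
fused_path = lambda t: t @ w     # stands in for: matmul4bit's fused kernel

prompt = torch.randn(1, 128, 64)   # seq_len 128 > 16 -> dequant path
decode = torch.randn(1, 1, 64)     # seq_len 1 <= 16  -> fused path
out_prompt = dispatch_4bit_matmul(prompt, dequant_path, fused_path)
out_decode = dispatch_4bit_matmul(decode, dequant_path, fused_path)
```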
@@ -94,12 +124,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
         ctx.save_for_backward(qweight, scales, zeros)

         # equals to torch.matmul(x, qweight)
-        if x.dtype == torch.float32:
-            output = matmul4bit(x, qweight, scales, zeros).clone()
-        elif x.dtype == torch.float16:
-            output = matmul4bit_half(x, qweight, scales, zeros).clone()
-        else:
-            raise ValueError('Only float and half are supportted.')
+        output = matmul4bit(x, qweight, scales, zeros).clone()

         return output

@@ -108,12 +133,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
         qweight, scales, zeros = ctx.saved_tensors

         # compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
-        if grad_output.dtype == torch.float32:
-            grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
-        elif grad_output.dtype == torch.float16:
-            grad = matmul4bit_transpose_half(grad_output, qweight, scales, zeros)
-        else:
-            raise ValueError('Only float and half are supportted.')
+        grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)

         return grad, None, None, None

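The backward keeps only the identity stated in the comment: grad_x = grad_output @ W.T can be computed as (W @ grad_output.T).T, which is what the transpose-matmul kernel evaluates without ever building W.T explicitly. A quick dense check of that identity:

```python
import torch

# Dense stand-ins: W plays the role of the dequantized weight, g of grad_output.
g = torch.randn(4, 8)    # (batch, out_features)
W = torch.randn(16, 8)   # (in_features, out_features), i.e. forward is x @ W

grad_x_direct = g @ W.t()                # what the backward needs: (batch, in_features)
grad_x_via_transpose = (W @ g.t()).t()   # same result, computed the kernel's way

assert torch.allclose(grad_x_direct, grad_x_via_transpose, atol=1e-5)
```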
@@ -135,8 +155,7 @@ class Autograd4bitQuantLinear(nn.Module):
         )

     def forward(self, x):
-        out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
-        out += self.bias
+        out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
         return out


@@ -1,6 +1,7 @@
 #include <torch/all.h>
 #include <torch/python.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <vector>

 void vecquant2matmul_cuda(
   torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
@@ -93,6 +94,28 @@ void vecquant4transposematmul_half(
   vecquant4transposematmul_half_cuda(vec, mat, mul, scales, zeros);
 }

+void vecquant4recons_cuda(
+  torch::Tensor mat, torch::Tensor res,
+  torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant4recons(
+  torch::Tensor mat, torch::Tensor res,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(mat));
+  vecquant4recons_cuda(mat, res, scales, zeros);
+}
+
+at::Tensor fast_4bit_forward(
+  torch::Tensor x, torch::Tensor mat, torch::Tensor buffer,
+  torch::Tensor scales, torch::Tensor zeros, torch::Tensor bias) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(mat));
+  vecquant4recons_cuda(mat, buffer, scales, zeros);
+  auto result = at::addmm(bias, x, buffer);
+  return result;
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
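The C++ fast_4bit_forward binding fuses the whole layer forward: it reconstructs the fp16 weight into the caller-supplied buffer, then uses at::addmm so the bias add rides along with the GEMM. In Python terms, a dense sketch (recons is an ordinary float weight standing in for the buffer filled by vecquant4recons):

```python
import torch

batch, in_f, out_f = 4, 64, 32
x = torch.randn(batch, in_f)
recons = torch.randn(in_f, out_f)   # stand-in for the reconstructed weight buffer
bias = torch.randn(out_f)

# What the binding computes: addmm(bias, x, recons) == x @ recons + bias.
fused = torch.addmm(bias, x, recons)
separate = x @ recons + bias
assert torch.allclose(fused, separate, atol=1e-5)
```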
@@ -101,4 +124,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("vecquant4transposematmul", &vecquant4transposematmul, "Vector 4-bit Transpose Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant4matmul_half", &vecquant4matmul_half, "Vector 4-bit Half Quantized Matrix Multiplication (CUDA)");
   m.def("vecquant4transposematmul_half", &vecquant4transposematmul_half, "Vector 4-bit Half Transpose Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant4recons", &vecquant4recons, "Vector 4-bit Matrix Reconstruction (CUDA)");
+  m.def("fast_4bit_forward", &fast_4bit_forward, "Vector 4-bit Fast Forward");
 }
@@ -647,4 +647,55 @@ void vecquant4transposematmul_half_cuda(
       );
     })
   ));
   }
 }
+
+template <typename scalar_t>
+__global__ void VecQuant4ReconsKernel(
+    const int* __restrict__ mat,
+    scalar_t* __restrict__ res,
+    const scalar_t* __restrict__ scales,
+    const scalar_t* __restrict__ zeros,
+    int height,
+    int width
+) {
+  int b = blockIdx.z;
+  int h = BLOCKHEIGHT4 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+  int n_rows = h * 8 + b;
+  int n_cols = w;
+  scalar_t scale = scales[w];
+  scalar_t zero = zeros[w];
+  int i = width * h + width * (b / 8) + w;
+  int shift = b % 8 * 4;
+  unsigned int tmp = as_unsigned(mat[i]);
+  scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero);
+  res[n_rows * width + n_cols] = result;
+}
+
+void vecquant4recons_cuda(
+  torch::Tensor mat,
+  torch::Tensor res,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int batch = BLOCKWIDTH;
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    batch
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    scales.type(), "vecquant4recons_cuda", ([&] {
+      VecQuant4ReconsKernel<<<blocks, threads>>>(
+        mat.data<int>(), res.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<scalar_t>(),
+        height, width
+      );
+    })
+  );
+}
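The reconstruction kernel writes one element per thread: packed row h + b/8, nibble b % 8, column w, with value scale[w] * q - zero[w] (the zeros appear to be stored pre-multiplied by the scale, as in the GPTQ layout this code derives from; that part is an assumption). A CPU-side PyTorch sketch of the same unpacking, using int64 to sidestep the int32 sign handling the kernel does with as_unsigned:

```python
import torch

def recons_4bit_reference(qweight, scales, zeros):
    """CPU reference for what the kernel computes.
    qweight: (H, W) integer tensor with eight 4-bit codes packed per element;
    scales, zeros: (W,) per-column parameters, zeros assumed pre-scaled."""
    H, W = qweight.shape
    out = torch.empty((H * 8, W), dtype=scales.dtype)
    for b in range(8):                  # which nibble inside each packed value
        q = (qweight >> (b * 4)) & 0xF  # (H, W) integer codes in [0, 15]
        out[b::8, :] = scales * q.to(scales.dtype) - zeros
    return out

# Tiny round-trip check: pack known 4-bit codes, then reconstruct them.
codes = torch.randint(0, 16, (16, 3), dtype=torch.int64)   # 16 rows -> 2 packed rows
packed = torch.zeros((2, 3), dtype=torch.int64)
for r in range(16):
    packed[r // 8] |= codes[r] << ((r % 8) * 4)
scales = torch.rand(3) + 0.5
zeros = torch.rand(3)
expected = scales * codes.to(scales.dtype) - zeros
assert torch.allclose(recons_4bit_reference(packed, scales, zeros), expected)
```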