add fast_4bit_matmul and auto-switch between the two methods according to the bottleneck

John Smith 2023-03-21 08:43:07 +00:00
parent dd0d5a31f7
commit 3471be4e56
3 changed files with 110 additions and 15 deletions
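
Summary of the change: the new fast_4bit_forward reconstructs the fp16 weight matrix into a cached buffer and runs a dense torch.matmul whenever x.shape[1] exceeds auto_switch_thd (16); otherwise it keeps the existing quantized matmul kernel. A minimal usage sketch of the new module-level switches, assuming the changed Python module below imports as autograd_4bit (the file name is not shown in this view):

import autograd_4bit as a4b  # hypothetical import path for the Python file changed below

# Default behaviour after this commit: pick the kernel per call.
a4b.auto_switch = True
a4b.auto_switch_thd = 16     # dense path is taken when x.shape[1] > 16

# Or pin one path explicitly.
a4b.auto_switch = False
a4b.use_new = True           # always reconstruct fp16 weights and use torch.matmul
# a4b.use_new = False        # always use the original fused 4-bit matmul kernel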

View File

@@ -5,6 +5,19 @@ import torch.nn as nn
import time
# Global Buffer
buffer_mat_dic = {}
use_new = True
auto_switch = True
auto_switch_thd = 16
def get_buffer(shape_of_qweight, dtype=torch.float16, device='cuda'):
if shape_of_qweight not in buffer_mat_dic.keys():
buffer_mat_dic[shape_of_qweight] = torch.zeros((shape_of_qweight[0] * 8, shape_of_qweight[1]), dtype=dtype, device=device)
return buffer_mat_dic[shape_of_qweight]
def matmul4bit(x, qweight, scales, zeros):
"""
input x: (n, m)
@@ -87,6 +100,23 @@ def matmul4bit_transpose_half(x, qweight, scales, zeros):
return y.reshape(outshape)
def fast_4bit_forward(x, qweight, scales, zeros, bias):
use_new_flag = use_new
if auto_switch:
if x.shape[1] > auto_switch_thd:
use_new_flag = True
else:
use_new_flag = False
if use_new_flag:
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
output = torch.matmul(x, buffer)
else:
output = matmul4bit(x, qweight, scales.float(), zeros.float())
output += bias
return output
class AutogradMatmul4bit(torch.autograd.Function):
@staticmethod
@@ -94,12 +124,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
ctx.save_for_backward(qweight, scales, zeros)
# equivalent to torch.matmul(x, qweight)
if x.dtype == torch.float32:
output = matmul4bit(x, qweight, scales, zeros).clone()
elif x.dtype == torch.float16:
output = matmul4bit_half(x, qweight, scales, zeros).clone()
else:
raise ValueError('Only float and half are supported.')
output = matmul4bit(x, qweight, scales, zeros).clone()
return output
@@ -108,12 +133,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
qweight, scales, zeros = ctx.saved_tensors
# compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
if grad_output.dtype == torch.float32:
grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
elif grad_output.dtype == torch.float16:
grad = matmul4bit_transpose_half(grad_output, qweight, scales, zeros)
else:
raise ValueError('Only float and half are supported.')
grad = matmul4bit_transpose(grad_output, qweight, scales, zeros)
return grad, None, None, None
@@ -135,8 +155,7 @@ class Autograd4bitQuantLinear(nn.Module):
)
def forward(self, x):
out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
out += self.bias
out = fast_4bit_forward(x, self.qweight, self.scales, self.zeros, self.bias)
return out
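
For reference, a minimal pure-PyTorch sketch of what the new dense path computes: unpack the (height, width) int32 qweight into a (height * 8, width) matrix using the per-column scale * q - zero formula (mirroring the VecQuant4ReconsKernel added further below), then apply a plain matmul plus bias. The function names here are illustrative, not part of the commit:

import torch

def dequant4bit_reference(qweight, scales, zeros):
    # qweight: (height, width) int32, eight 4-bit values packed per int32 along dim 0
    # scales, zeros: (width,) per-output-column quantization parameters
    height, width = qweight.shape
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=qweight.device).view(1, 8, 1)
    nibbles = (qweight.unsqueeze(1) >> shifts) & 0xF        # (height, 8, width)
    q = nibbles.reshape(height * 8, width).to(scales.dtype)
    return scales * q - zeros                               # same formula as the CUDA kernel

def fast_4bit_forward_reference(x, qweight, scales, zeros, bias):
    # Equivalent of the use_new branch above: dense matmul against reconstructed weights.
    return torch.matmul(x, dequant4bit_reference(qweight, scales, zeros)) + bias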

View File

@@ -1,6 +1,7 @@
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>
void vecquant2matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
@@ -93,6 +94,28 @@ void vecquant4transposematmul_half(
vecquant4transposematmul_half_cuda(vec, mat, mul, scales, zeros);
}
void vecquant4recons_cuda(
torch::Tensor mat, torch::Tensor res,
torch::Tensor scales, torch::Tensor zeros
);
void vecquant4recons(
torch::Tensor mat, torch::Tensor res,
torch::Tensor scales, torch::Tensor zeros
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(mat));
vecquant4recons_cuda(mat, res, scales, zeros);
}
at::Tensor fast_4bit_forward(
torch::Tensor x, torch::Tensor mat, torch::Tensor buffer,
torch::Tensor scales, torch::Tensor zeros, torch::Tensor bias) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(mat));
vecquant4recons_cuda(mat, buffer, scales, zeros);
auto result = at::addmm(bias, x, buffer);
return result;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
@@ -101,4 +124,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("vecquant4transposematmul", &vecquant4transposematmul, "Vector 4-bit Transpose Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4matmul_half", &vecquant4matmul_half, "Vector 4-bit Half Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4transposematmul_half", &vecquant4transposematmul_half, "Vector 4-bit Half Transpose Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4recons", &vecquant4recons, "Vector 4-bit Matrix Reconstruction (CUDA)");
m.def("fast_4bit_forward", &fast_4bit_forward, "Vector 4-bit Fast Forward");
}
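
A sketch of how the two new bindings might be called directly, assuming the compiled extension imports as quant_cuda (the Python side above reaches it as quant.quant_cuda) and using dummy tensors with the expected shapes; the final check only confirms that the fused C++ path matches the two-step path up to fp16 rounding:

import torch
import quant_cuda  # compiled pybind11 extension defined above; import path may differ

height, width = 64, 256    # packed weights: 8 dense rows per int32 along dim 0
qweight = torch.randint(0, 2**31 - 1, (height, width), dtype=torch.int32, device='cuda')
scales = torch.rand(width, dtype=torch.float16, device='cuda')
zeros = torch.rand(width, dtype=torch.float16, device='cuda')
x = torch.rand(4, height * 8, dtype=torch.float16, device='cuda')
bias = torch.zeros(width, dtype=torch.float16, device='cuda')

# Two-step path, as fast_4bit_forward does on the Python side.
buffer = torch.zeros(height * 8, width, dtype=torch.float16, device='cuda')
quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)   # dequantize into the buffer
out = torch.matmul(x, buffer) + bias

# Fused C++ path: reconstruction plus at::addmm in one call.
out_fused = quant_cuda.fast_4bit_forward(x, qweight, buffer, scales, zeros, bias)
print(torch.allclose(out, out_fused, rtol=1e-3, atol=1e-3))  # expected: True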

View File

@@ -647,4 +647,55 @@
);
})
));
}
}
template <typename scalar_t>
__global__ void VecQuant4ReconsKernel(
const int* __restrict__ mat,
scalar_t* __restrict__ res,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int height,
int width
) {
int b = blockIdx.z;
int h = BLOCKHEIGHT4 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
int n_rows = h * 8 + b;
int n_cols = w;
scalar_t scale = scales[w];
scalar_t zero = zeros[w];
int i = width * h + width * (b / 8) + w;
int shift = b % 8 * 4;
unsigned int tmp = as_unsigned(mat[i]);
scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero);
res[n_rows * width + n_cols] = result;
}
void vecquant4recons_cuda(
torch::Tensor mat,
torch::Tensor res,
torch::Tensor scales,
torch::Tensor zeros
) {
int batch = BLOCKWIDTH;
int height = mat.size(0);
int width = mat.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
batch
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
scales.type(), "vecquant4recons_cuda", ([&] {
VecQuant4ReconsKernel<<<blocks, threads>>>(
mat.data<int>(), res.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<scalar_t>(),
height, width
);
})
);
}
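
To spell out the indexing in VecQuant4ReconsKernel: each thread writes one element of the dense (height * 8, width) result, and dense row r is built from packed row r // 8 of mat, nibble r % 8 (shift = 4 * (r % 8)), scaled per output column. A loop-level sketch of the same computation (illustrative only, not how the kernel is actually launched):

import torch

def vecquant4recons_loops(mat, scales, zeros):
    # mat: (height, width) int32 packed weights; scales, zeros: (width,) per-column params
    height, width = mat.shape
    res = torch.empty(height * 8, width, dtype=scales.dtype, device=mat.device)
    for r in range(height * 8):
        packed = mat[r // 8]                 # int32 row holding nibbles for 8 dense rows
        shift = 4 * (r % 8)                  # matches: shift = b % 8 * 4
        nibble = (packed >> shift) & 0xF
        res[r] = scales * nibble.to(scales.dtype) - zeros   # scale * q - zero, per column
    return res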