Fix repos.
parent 8e705eddcb
commit 17f3da744c

@@ -1,3 +1,5 @@
alpaca_lora/
repository/
__pycache__/
llama-13b-4bit
llama-13b-4bit.pt

@@ -4,3 +4,4 @@
[submodule "repos/GPTQ-for-LLaMa"]
	path = repos/GPTQ-for-LLaMa
	url = https://github.com/sterlind/GPTQ-for-LLaMa.git
	branch = lora_4bit

@@ -1,129 +0,0 @@
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>

void vecquant2matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant2matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant2matmul_cuda(vec, mat, mul, scales, zeros);
}

void vecquant3matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant3matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant3matmul_cuda(vec, mat, mul, scales, zeros);
}

void vecquant4matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant4matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4matmul_cuda(vec, mat, mul, scales, zeros);
}

void vecquant8matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant8matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant8matmul_cuda(vec, mat, mul, scales, zeros);
}

void vecquant4transposematmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant4transposematmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4transposematmul_cuda(vec, mat, mul, scales, zeros);
}

void vecquant4matmul_half_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant4matmul_half(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4matmul_half_cuda(vec, mat, mul, scales, zeros);
}

void vecquant4transposematmul_half_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant4transposematmul_half(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4transposematmul_half_cuda(vec, mat, mul, scales, zeros);
}

void vecquant4recons_cuda(
  torch::Tensor mat, torch::Tensor res,
  torch::Tensor scales, torch::Tensor zeros
);

void vecquant4recons(
  torch::Tensor mat, torch::Tensor res,
  torch::Tensor scales, torch::Tensor zeros
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(mat));
  vecquant4recons_cuda(mat, res, scales, zeros);
}

at::Tensor fast_4bit_forward(
  torch::Tensor x, torch::Tensor mat, torch::Tensor buffer,
  torch::Tensor scales, torch::Tensor zeros, torch::Tensor bias) {
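  // Dequantize the packed 4-bit weights into `buffer`, then compute the whole
  // layer as one GEMM: result = bias + x @ buffer.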
  const at::cuda::OptionalCUDAGuard device_guard(device_of(mat));
  vecquant4recons_cuda(mat, buffer, scales, zeros);
  auto result = at::addmm(bias, x, buffer);
  return result;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4transposematmul", &vecquant4transposematmul, "Vector 4-bit Transpose Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4matmul_half", &vecquant4matmul_half, "Vector 4-bit Half Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4transposematmul_half", &vecquant4transposematmul_half, "Vector 4-bit Half Transpose Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4recons", &vecquant4recons, "Vector 4-bit Matrix Reconstruction (CUDA)");
  m.def("fast_4bit_forward", &fast_4bit_forward, "Vector 4-bit Fast Forward");
}
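For context, a minimal sketch of how these bindings were typically driven from Python before their removal. The module name quant_cuda and all shapes are assumptions rather than something this commit specifies; the output buffer is pre-zeroed because the kernels accumulate into it with atomicAdd:

import torch
import quant_cuda  # assumed extension name; the actual name comes from setup.py

batch, in_features, out_features = 1, 4096, 4096
x = torch.randn(batch, in_features, device="cuda")
# 4-bit weights: eight values are packed into each int32, hence in_features // 8 rows
qweight = torch.randint(0, 2**31 - 1, (in_features // 8, out_features),
                        dtype=torch.int32, device="cuda")
scales = torch.randn(out_features, device="cuda")
zeros = torch.randn(out_features, device="cuda")

y = torch.zeros(batch, out_features, device="cuda")  # accumulated in place
quant_cuda.vecquant4matmul(x, qweight, y, scales, zeros)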

@@ -1,718 +0,0 @@
#include <torch/all.h>
|
||||
#include <torch/python.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
|
||||
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
|
||||
__device__ __forceinline__ void atomicAddHalf(__half* address, c10::Half val) {
|
||||
unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do {
|
||||
assumed = old;
|
||||
unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
|
||||
hsum += val;
|
||||
old = reinterpret_cast<size_t>(address) & 2
|
||||
? (old & 0xffff) | (hsum << 16)
|
||||
: (old & 0xffff0000) | hsum;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
|
||||
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
||||
} while (assumed != old);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant2MatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
);
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant3MatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
);
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant4MatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
);
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant8MatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
);
|
||||
|
||||
const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;
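// Each int32 of `mat` packs 32 / bits quantized values, so a block of
// BLOCKWIDTH = 256 inputs spans BLOCKWIDTH * bits / 32 packed rows:
// 16 at 2 bits, 24 at 3 bits, 32 at 4 bits and 64 at 8 bits, which is
// what the BLOCKHEIGHT* constants above encode.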
|
||||
|
||||
__device__ inline unsigned int as_unsigned(int i) {
|
||||
return *reinterpret_cast<unsigned int*>(&i);
|
||||
}
|
||||
|
||||
void vecquant2matmul_cuda(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor zeros
|
||||
) {
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
batch
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
vec.type(), "vecquant2matmul_cuda", ([&] {
|
||||
VecQuant2MatMulKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<scalar_t>(),
|
||||
batch, vec_height, height, width
|
||||
);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant2MatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
) {
|
||||
int b = blockIdx.z;
|
||||
int h = BLOCKHEIGHT2 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
__shared__ scalar_t blockvec[BLOCKWIDTH];
|
||||
blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT2) * BLOCKWIDTH + threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
scalar_t scale = scales[w];
|
||||
scalar_t zero = zeros[w];
|
||||
|
||||
scalar_t res = 0;
|
||||
int i = width * h + w;
|
||||
int k = 0;
|
||||
|
||||
unsigned int tmp;
|
||||
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp = as_unsigned(mat[i]);
|
||||
res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
|
||||
res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
|
||||
res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
|
||||
res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
|
||||
res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
|
||||
res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
|
||||
res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
|
||||
res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
|
||||
res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
|
||||
res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
|
||||
res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
|
||||
res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
|
||||
res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
|
||||
res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
|
||||
i += width;
|
||||
k += 16;
|
||||
}
|
||||
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
}
|
||||
|
||||
void vecquant3matmul_cuda(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor zeros
|
||||
) {
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
batch
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
vec.type(), "vecquant3matmul_cuda", ([&] {
|
||||
VecQuant3MatMulKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<scalar_t>(),
|
||||
batch, vec_height, height, width
|
||||
);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant3MatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
) {
|
||||
int b = blockIdx.z;
|
||||
int h = BLOCKHEIGHT3 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
__shared__ scalar_t blockvec[BLOCKWIDTH];
|
||||
blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT3) * BLOCKWIDTH + threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
scalar_t scale = scales[w];
|
||||
scalar_t zero = zeros[w];
|
||||
|
||||
scalar_t res = 0;
|
||||
int i = width * h + w;
|
||||
int k = 0;
|
||||
|
||||
unsigned int tmp1;
|
||||
unsigned int tmp2;
|
||||
unsigned int tmp;
|
||||
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
|
||||
res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
|
||||
res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
|
||||
res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
|
||||
res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
|
||||
res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
|
||||
res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
|
||||
res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
|
||||
i += width;
|
||||
tmp2 = as_unsigned(mat[i]);
|
||||
tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
|
||||
tmp2 >>= 1;
|
||||
res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
|
||||
k += 11;
|
||||
res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
|
||||
res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
|
||||
res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
|
||||
res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
|
||||
res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
|
||||
res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
|
||||
res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
|
||||
res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
|
||||
i += width;
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
|
||||
tmp1 >>= 2;
|
||||
res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
|
||||
k += 11;
|
||||
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
|
||||
res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
|
||||
res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
|
||||
res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
|
||||
res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
|
||||
res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
|
||||
res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
|
||||
res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
|
||||
i += width;
|
||||
k += 10;
|
||||
}
|
||||
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
}
|
||||
|
||||
void vecquant4matmul_cuda(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor zeros
|
||||
) {
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
batch
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
vec.type(), "vecquant4matmul_cuda", ([&] {
|
||||
VecQuant4MatMulKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<scalar_t>(),
|
||||
batch, vec_height, height, width
|
||||
);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant4MatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
) {
|
||||
int b = blockIdx.z;
|
||||
int h = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
__shared__ scalar_t blockvec[BLOCKWIDTH];
|
||||
blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
scalar_t scale = scales[w];
|
||||
scalar_t zero = zeros[w];
|
||||
|
||||
scalar_t res = 0;
|
||||
int i = width * h + w;
|
||||
int k = 0;
|
||||
|
||||
unsigned int tmp;
|
||||
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp = as_unsigned(mat[i]);
|
||||
res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
|
||||
res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
|
||||
res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
|
||||
res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
|
||||
res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
|
||||
res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
|
||||
i += width;
|
||||
k += 8;
|
||||
}
|
||||
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
}
|
||||
|
||||
void vecquant8matmul_cuda(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor zeros
|
||||
) {
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
batch
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
vec.type(), "vecquant8matmul_cuda", ([&] {
|
||||
VecQuant8MatMulKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<scalar_t>(),
|
||||
batch, vec_height, height, width
|
||||
);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant8MatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
) {
|
||||
int b = blockIdx.z;
|
||||
int h = BLOCKHEIGHT8 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
__shared__ scalar_t blockvec[BLOCKWIDTH];
|
||||
blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT8) * BLOCKWIDTH + threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
scalar_t scale = scales[w];
|
||||
scalar_t zero = zeros[w];
|
||||
|
||||
scalar_t res = 0;
|
||||
int i = width * h + w;
|
||||
int k = 0;
|
||||
|
||||
unsigned int tmp;
|
||||
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp = as_unsigned(mat[i]);
|
||||
res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
|
||||
res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
|
||||
i += width;
|
||||
k += 4;
|
||||
}
|
||||
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant4TransposeMatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
) {
|
||||
int b = blockIdx.z;
|
||||
int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
|
||||
unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
|
||||
int w = BLOCKWIDTH * blockIdx.y;
|
||||
|
||||
int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
|
||||
int n_cols = b;
|
||||
|
||||
__shared__ scalar_t blockvec[BLOCKWIDTH];
|
||||
blockvec[threadIdx.x] = vec[n_cols * vec_height + w + threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
scalar_t res = 0;
|
||||
int i = width * h + w;
|
||||
int k = 0;
|
||||
int j = w;
|
||||
unsigned int tmp;
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp = as_unsigned(mat[i]);
|
||||
res += (scales[j] * scalar_t((tmp >> shift) & 0xF) - zeros[j]) * blockvec[k];
|
||||
i += 1;
|
||||
j += 1;
|
||||
k += 1;
|
||||
}
|
||||
|
||||
atomicAdd(&mul[n_cols * height * 8 + n_rows], res);
|
||||
}
|
||||
|
||||
void vecquant4transposematmul_cuda(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor zeros
|
||||
) {
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
batch
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
vec.type(), "vecquant4transposematmul_cuda", ([&] {
|
||||
VecQuant4TransposeMatMulKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<scalar_t>(),
|
||||
batch, vec_height, height, width
|
||||
);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant4MatMulHalfKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
) {
|
||||
int b = blockIdx.z;
|
||||
int h = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
__shared__ __half blockvec[BLOCKWIDTH];
|
||||
blockvec[threadIdx.x] = __half(vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]);
|
||||
__syncthreads();
|
||||
|
||||
__half scale = __half(scales[w]);
|
||||
__half zero = __half(zeros[w]);
|
||||
|
||||
__half res = __float2half(0.0f);
|
||||
int i = width * h + w;
|
||||
int k = 0;
|
||||
|
||||
unsigned int tmp;
|
||||
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp = as_unsigned(mat[i]);
|
||||
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 0) & 0xF)), zero), blockvec[k + 0]));
|
||||
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 4) & 0xF)), zero), blockvec[k + 1]));
|
||||
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 8) & 0xF)), zero), blockvec[k + 2]));
|
||||
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 12) & 0xF)), zero), blockvec[k + 3]));
|
||||
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 16) & 0xF)), zero), blockvec[k + 4]));
|
||||
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 20) & 0xF)), zero), blockvec[k + 5]));
|
||||
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 24) & 0xF)), zero), blockvec[k + 6]));
|
||||
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> 28) & 0xF)), zero), blockvec[k + 7]));
|
||||
i += width;
|
||||
k += 8;
|
||||
}
|
||||
|
||||
__half* mul2 = (__half*)mul;
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
|
||||
atomicAddHalf(&mul2[b * width + w], res);
|
||||
#else
|
||||
atomicAdd(&mul2[b * width + w], res);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
void vecquant4matmul_half_cuda(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor zeros
|
||||
) {
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
batch
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
AT_DISPATCH_SWITCH(vec.type(), "vecquant4matmul_half_cuda",
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
|
||||
VecQuant4MatMulHalfKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<scalar_t>(),
|
||||
batch, vec_height, height, width
|
||||
);
|
||||
})
|
||||
));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant4TransposeMatMulHalfKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ mul,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width
|
||||
) {
|
||||
int b = blockIdx.z;
|
||||
int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
|
||||
unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
|
||||
int w = BLOCKWIDTH * blockIdx.y;
|
||||
|
||||
int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
|
||||
int n_cols = b;
|
||||
|
||||
__shared__ __half blockvec[BLOCKWIDTH];
|
||||
blockvec[threadIdx.x] = __half(vec[n_cols * vec_height + w + threadIdx.x]);
|
||||
__syncthreads();
|
||||
|
||||
__half res = __float2half(0.0f);
|
||||
int i = width * h + w;
|
||||
int k = 0;
|
||||
int j = w;
|
||||
unsigned int tmp;
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp = as_unsigned(mat[i]);
|
||||
__half zero = __half(zeros[j]);
|
||||
__half scale = __half(scales[j]);
|
||||
res = __hadd(res, __hmul(__hsub(__hmul(scale, __int2half_rn((tmp >> shift) & 0xF)), zero), blockvec[k]));
|
||||
i += 1;
|
||||
j += 1;
|
||||
k += 1;
|
||||
}
|
||||
|
||||
__half* mul2 = (__half*)mul;
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ < 700 && __CUDA_ARCH__ > 600
|
||||
atomicAddHalf(&mul2[n_cols * height * 8 + n_rows], res);
|
||||
#else
|
||||
atomicAdd(&mul2[n_cols * height * 8 + n_rows], res);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
void vecquant4transposematmul_half_cuda(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor zeros
|
||||
) {
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
batch
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
AT_DISPATCH_SWITCH(vec.type(), "vecquant4transposematmul_half_cuda",
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, ([&] {
|
||||
VecQuant4TransposeMatMulHalfKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<scalar_t>(),
|
||||
batch, vec_height, height, width
|
||||
);
|
||||
})
|
||||
));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant4ReconsKernel(
|
||||
const int* __restrict__ mat,
|
||||
scalar_t* __restrict__ res,
|
||||
const scalar_t* __restrict__ scales,
|
||||
const scalar_t* __restrict__ zeros,
|
||||
int height,
|
||||
int width
|
||||
) {
|
||||
int b = blockIdx.z;
|
||||
int h = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
int n_rows = h * 8 + b;
|
||||
int n_cols = w;
|
||||
scalar_t scale = scales[w];
|
||||
scalar_t zero = zeros[w];
|
||||
int i = width * h + width * (b / 8) + w;
|
||||
int shift = b % 8 * 4;
|
||||
unsigned int tmp = as_unsigned(mat[i]);
|
||||
scalar_t result = (scale * scalar_t((tmp >> shift) & 0xF) - zero);
|
||||
res[n_rows * width + n_cols] = result;
|
||||
}
|
||||
|
||||
void vecquant4recons_cuda(
|
||||
torch::Tensor mat,
|
||||
torch::Tensor res,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor zeros
|
||||
) {
|
||||
int batch = BLOCKWIDTH;
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
batch
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
scales.type(), "vecquant4recons_cuda", ([&] {
|
||||
VecQuant4ReconsKernel<<<blocks, threads>>>(
|
||||
mat.data<int>(), res.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<scalar_t>(),
|
||||
height, width
|
||||
);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
|
@@ -1,4 +1,4 @@
import quant
from gptq_llama import quant
import torch
import numpy as np
import torch.nn as nn

@@ -201,7 +201,7 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False, device_ma
import transformers
import accelerate
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from modelutils import find_layers
from gptq_llama.modelutils import find_layers

print("Loading Model ...")
t0 = time.time()
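The two hunks above repoint this loader at the installed gptq_llama package instead of local top-level modules. A minimal sketch of the resulting imports, assuming the package built from the repos/GPTQ-for-LLaMa submodule is installed:

# Hypothetical check, not part of the diff: both helpers should now resolve
# from the gptq_llama package rather than from files in this repository.
from gptq_llama import quant
from gptq_llama.modelutils import find_layers

print(quant.__name__, find_layers.__module__)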

@@ -1,697 +0,0 @@
# coding=utf-8
|
||||
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import math
|
||||
import re
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from ..utils import PeftConfig, PeftType, transpose
|
||||
|
||||
|
||||
def is_bnb_available():
|
||||
return importlib.util.find_spec("bitsandbytes") is not None
|
||||
|
||||
|
||||
def is_gptq_available():
|
||||
return importlib.util.find_spec("quant") is not None
|
||||
|
||||
|
||||
if is_bnb_available():
|
||||
import bitsandbytes as bnb
|
||||
|
||||
|
||||
if is_gptq_available():
|
||||
import quant
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraConfig(PeftConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`~peft.Lora`].
|
||||
|
||||
Args:
|
||||
r (`int`): Lora attention dimension
|
||||
target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to.
|
||||
lora_alpha (`float`): The alpha parameter for Lora scaling.
|
||||
lora_dropout (`float`): The dropout probability for Lora layers.
|
||||
merge_weights (`bool`):
|
||||
Whether to merge the weights of the Lora layers with the base transformer model in `eval` mode.
|
||||
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
enable_lora ( `List[bool]`): Used with `lora.MergedLinear`.
|
||||
bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'
|
||||
modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable
|
||||
and saved in the final checkpoint.
|
||||
"""
|
||||
|
||||
r: int = field(default=8, metadata={"help": "Lora attention dimension"})
|
||||
target_modules: Optional[Union[List[str], str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "List of module names or regex expression of the module names to replace with Lora."
|
||||
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
|
||||
},
|
||||
)
|
||||
lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"})
|
||||
lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"})
|
||||
merge_weights: bool = field(
|
||||
default=False, metadata={"help": "Merge weights of the original model and the Lora model"}
|
||||
)
|
||||
fan_in_fan_out: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
|
||||
)
|
||||
enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."})
|
||||
bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
|
||||
modules_to_save: Optional[List[str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
|
||||
"For example, in Sequence Classification or Token Classification tasks, "
|
||||
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.peft_type = PeftType.LORA
|
||||
|
||||
|
||||
class LoraModel(torch.nn.Module):
|
||||
"""
|
||||
Creates Low Rank Adapter (Lora) model from a pretrained transformers model.
|
||||
|
||||
Args:
|
||||
model ([`transformers.PreTrainedModel`]): The model to be adapted.
|
||||
config ([`LoraConfig`]): The configuration of the Lora model.
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: The Lora model.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import AutoModelForSeq2SeqLM, LoraConfig >>> from peft import LoraModel, LoraConfig >>>
|
||||
config = LoraConfig(
|
||||
peft_type="LORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
|
||||
lora_dropout=0.01, )
|
||||
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> lora_model = LoraModel(config, model)
|
||||
|
||||
**Attributes**:
|
||||
- **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
|
||||
- **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
|
||||
"""
|
||||
|
||||
def __init__(self, config, model):
|
||||
super().__init__()
|
||||
self.peft_config = config
|
||||
self.model = model
|
||||
self._find_and_replace()
|
||||
mark_only_lora_as_trainable(self.model, self.peft_config.bias)
|
||||
self.forward = self.model.forward
|
||||
|
||||
def _find_and_replace(self):
|
||||
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
|
||||
if loaded_in_8bit and not is_bnb_available():
|
||||
raise ImportError(
|
||||
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
is_target_modules_in_base_model = False
|
||||
is_hf_device_map_available = hasattr(self.model, "hf_device_map")
|
||||
kwargs = {
|
||||
"r": self.peft_config.r,
|
||||
"lora_alpha": self.peft_config.lora_alpha,
|
||||
"lora_dropout": self.peft_config.lora_dropout,
|
||||
"fan_in_fan_out": self.peft_config.fan_in_fan_out,
|
||||
"merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
|
||||
and not is_hf_device_map_available,
|
||||
}
|
||||
key_list = [key for key, _ in self.model.named_modules()]
|
||||
for key in key_list:
|
||||
if isinstance(self.peft_config.target_modules, str):
|
||||
target_module_found = re.fullmatch(self.peft_config.target_modules, key)
|
||||
else:
|
||||
target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
|
||||
if target_module_found:
|
||||
if not is_target_modules_in_base_model:
|
||||
is_target_modules_in_base_model = True
|
||||
parent, target, target_name = self._get_submodules(key)
|
||||
bias = target.bias is not None
|
||||
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
|
||||
kwargs.update(
|
||||
{
|
||||
"has_fp16_weights": target.state.has_fp16_weights,
|
||||
"memory_efficient_backward": target.state.memory_efficient_backward,
|
||||
"threshold": target.state.threshold,
|
||||
"index": target.index,
|
||||
}
|
||||
)
|
||||
if self.peft_config.enable_lora is None:
|
||||
new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
|
||||
else:
|
||||
kwargs.update({"enable_lora": self.peft_config.enable_lora})
|
||||
new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
|
||||
elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
|
||||
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
|
||||
elif isinstance(target, Autograd4bitQuantLinear) and self.peft_config.enable_lora is None:
|
||||
new_module = Linear4bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
|
||||
elif self.peft_config.enable_lora is not None:
|
||||
kwargs.update({"enable_lora": self.peft_config.enable_lora})
|
||||
if isinstance(target, Conv1D):
|
||||
in_features, out_features = (
|
||||
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
|
||||
)
|
||||
else:
|
||||
in_features, out_features = target.in_features, target.out_features
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is not a Conv1D. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = self.peft_config.fan_in_fan_out = False
|
||||
new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs)
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
if not is_target_modules_in_base_model:
|
||||
raise ValueError(
|
||||
f"Target modules {self.peft_config.target_modules} not found in the base model. "
|
||||
f"Please check the target modules and try again."
|
||||
)
|
||||
|
||||
def _get_submodules(self, key):
|
||||
parent = self.model.get_submodule(".".join(key.split(".")[:-1]))
|
||||
target_name = key.split(".")[-1]
|
||||
target = self.model.get_submodule(key)
|
||||
return parent, target, target_name
|
||||
|
||||
def _replace_module(self, parent_module, child_name, new_module, old_module):
|
||||
setattr(parent_module, child_name, new_module)
|
||||
if isinstance(old_module, Autograd4bitQuantLinear) and isinstance(new_module, Linear4bitLt):
|
||||
new_module.qweight = old_module.qweight
|
||||
new_module.scales = old_module.scales
|
||||
new_module.zeros = old_module.zeros
|
||||
new_module.bias = old_module.bias
|
||||
if getattr(old_module, "state", None) is not None:
|
||||
new_module.state = old_module.state
|
||||
new_module.to(old_module.qweight.device)
|
||||
|
||||
# dispatch to correct device
|
||||
for name, module in new_module.named_modules():
|
||||
if "lora_" in name:
|
||||
module.to(old_module.qweight.device)
|
||||
else:
|
||||
new_module.weight = old_module.weight
|
||||
if old_module.bias is not None:
|
||||
new_module.bias = old_module.bias
|
||||
if getattr(old_module, "state", None) is not None:
|
||||
new_module.state = old_module.state
|
||||
new_module.to(old_module.weight.device)
|
||||
|
||||
# dispatch to correct device
|
||||
for name, module in new_module.named_modules():
|
||||
if "lora_" in name:
|
||||
module.to(old_module.weight.device)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Forward missing attributes to the wrapped module."""
|
||||
try:
|
||||
return super().__getattr__(name) # defer to nn.Module's logic
|
||||
except AttributeError:
|
||||
return getattr(self.model, name)
|
||||
|
||||
@property
|
||||
def modules_to_save(self):
|
||||
return None
|
||||
|
||||
def get_peft_config_as_dict(self, inference: bool = False):
|
||||
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
|
||||
if inference:
|
||||
config["inference_mode"] = True
|
||||
return config
|
||||
|
||||
def _set_adapter_layers(self, enabled=True):
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
module.disable_adapters = False if enabled else True
|
||||
|
||||
def enable_adapter_layers(self):
|
||||
self._set_adapter_layers(enabled=True)
|
||||
|
||||
def disable_adapter_layers(self):
|
||||
self._set_adapter_layers(enabled=False)
|
||||
|
||||
|
||||
# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
||||
# and modified to work with PyTorch FSDP
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# had to adapt it for `lora_only` to work
|
||||
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
|
||||
for n, p in model.named_parameters():
|
||||
if "lora_" not in n:
|
||||
p.requires_grad = False
|
||||
if bias == "none":
|
||||
return
|
||||
elif bias == "all":
|
||||
for n, p in model.named_parameters():
|
||||
if "bias" in n:
|
||||
p.requires_grad = True
|
||||
elif bias == "lora_only":
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
|
||||
m.bias.requires_grad = True
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LoraLayer:
|
||||
def __init__(
|
||||
self,
|
||||
r: int,
|
||||
lora_alpha: int,
|
||||
lora_dropout: float,
|
||||
merge_weights: bool,
|
||||
):
|
||||
self.r = r
|
||||
self.lora_alpha = lora_alpha
|
||||
# Optional dropout
|
||||
if lora_dropout > 0.0:
|
||||
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
||||
else:
|
||||
self.lora_dropout = lambda x: x
|
||||
# Mark the weight as unmerged
|
||||
self.merged = False
|
||||
self.merge_weights = merge_weights
|
||||
self.disable_adapters = False
|
||||
|
||||
|
||||
class Linear(nn.Linear, LoraLayer):
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
merge_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
||||
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Linear(in_features, r, bias=False)
|
||||
self.lora_B = nn.Linear(r, out_features, bias=False)
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
self.reset_parameters()
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.Linear.reset_parameters(self)
|
||||
if hasattr(self, "lora_A"):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B.weight)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
nn.Linear.train(self, mode)
|
||||
self.lora_A.train(mode)
|
||||
self.lora_B.train(mode)
|
||||
if not mode and self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += (
|
||||
transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
|
||||
)
|
||||
self.merged = True
|
||||
elif self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0:
|
||||
self.weight.data -= (
|
||||
transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
|
||||
)
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
nn.Linear.eval(self)
|
||||
self.lora_A.eval()
|
||||
self.lora_B.eval()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.disable_adapters:
|
||||
if self.r > 0 and self.merged:
|
||||
self.weight.data -= (
|
||||
transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
|
||||
)
|
||||
self.merged = False
|
||||
|
||||
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
elif self.r > 0 and not self.merged:
|
||||
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
if self.r > 0:
|
||||
result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
|
||||
return result
|
||||
else:
|
||||
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
|
||||
|
||||
class MergedLinear(nn.Linear, LoraLayer):
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
enable_lora: List[bool] = [False],
|
||||
fan_in_fan_out: bool = False,
|
||||
merge_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
||||
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
|
||||
if out_features % len(enable_lora) != 0:
|
||||
raise ValueError("The length of enable_lora must divide out_features")
|
||||
self.enable_lora = enable_lora
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
# Actual trainable parameters
|
||||
if r > 0 and any(enable_lora):
|
||||
self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
|
||||
self.lora_B = nn.Conv1d(
|
||||
r * sum(enable_lora),
|
||||
out_features // len(enable_lora) * sum(enable_lora),
|
||||
kernel_size=1,
|
||||
groups=2,
|
||||
bias=False,
|
||||
)
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
# Compute the indices
|
||||
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
|
||||
self.lora_ind[enable_lora, :] = True
|
||||
self.lora_ind = self.lora_ind.view(-1)
|
||||
self.reset_parameters()
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.Linear.reset_parameters(self)
|
||||
if hasattr(self, "lora_A"):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B.weight)
|
||||
|
||||
def zero_pad(self, x):
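        # Scatter the LoRA output back into a full out_features-wide tensor;
        # output positions whose group is disabled in enable_lora stay zero.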
|
||||
result = x.new_zeros((*x.shape[:-1], self.out_features))
|
||||
result = result.view(-1, self.out_features)
|
||||
result[:, self.lora_ind] = x.reshape(-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora))
|
||||
return result.view((*x.shape[:-1], self.out_features))
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
nn.Linear.train(self, mode)
|
||||
self.lora_A.train(mode)
|
||||
self.lora_B.train(mode)
|
||||
if not mode and self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0 and any(self.enable_lora):
|
||||
delta_w = (
|
||||
F.conv1d(
|
||||
self.lora_A.weight.data.unsqueeze(0),
|
||||
self.lora_B.weight.data,
|
||||
groups=sum(self.enable_lora),
|
||||
)
|
||||
.squeeze(0)
|
||||
.transpose(-2, -1)
|
||||
)
|
||||
self.weight.data += transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
|
||||
self.merged = True
|
||||
elif self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0 and any(self.enable_lora):
|
||||
delta_w = (
|
||||
F.conv1d(
|
||||
self.lora_A.weight.data.unsqueeze(0),
|
||||
self.lora_B.weight.data,
|
||||
groups=sum(self.enable_lora),
|
||||
)
|
||||
.squeeze(0)
|
||||
.transpose(-2, -1)
|
||||
)
|
||||
self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
nn.Linear.eval(self)
|
||||
self.lora_A.eval()
|
||||
self.lora_B.eval()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.disable_adapters:
|
||||
if self.r > 0 and self.merged and any(self.enable_lora):
|
||||
delta_w = (
|
||||
F.conv1d(
|
||||
self.lora_A.weight.data.unsqueeze(0),
|
||||
self.lora_B.weight.data,
|
||||
groups=sum(self.enable_lora),
|
||||
)
|
||||
.squeeze(0)
|
||||
.transpose(-2, -1)
|
||||
)
|
||||
self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
|
||||
self.merged = False
|
||||
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
elif self.merged:
|
||||
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
else:
|
||||
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
if self.r > 0:
|
||||
after_A = self.lora_A(self.lora_dropout(x))
|
||||
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
|
||||
result += self.zero_pad(after_B) * self.scaling
|
||||
return result
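The enable_lora bookkeeping above is easy to check in isolation; a small sketch (values chosen only for illustration) of how lora_ind selects the output slices that receive a LoRA update:

import torch

out_features, enable_lora = 6, [True, False, True]
lora_ind = torch.zeros(out_features, dtype=torch.bool).view(len(enable_lora), -1)
lora_ind[enable_lora, :] = True
lora_ind = lora_ind.view(-1)
print(lora_ind)  # tensor([ True,  True, False, False,  True,  True])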
|
||||
|
||||
|
||||
if is_bnb_available():
|
||||
|
||||
class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
bnb.nn.Linear8bitLt.__init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
bias=kwargs.get("bias", True),
|
||||
has_fp16_weights=kwargs.get("has_fp16_weights", True),
|
||||
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
|
||||
threshold=kwargs.get("threshold", 0.0),
|
||||
index=kwargs.get("index", None),
|
||||
)
|
||||
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Linear(in_features, r, bias=False)
|
||||
self.lora_B = nn.Linear(r, out_features, bias=False)
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if hasattr(self, "lora_A"):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
result = super().forward(x)
|
||||
|
||||
if self.disable_adapters:
|
||||
return result
|
||||
elif self.r > 0:
|
||||
if not torch.is_autocast_enabled():
|
||||
expected_dtype = result.dtype
|
||||
|
||||
if x.dtype != torch.float32:
|
||||
x = x.float()
|
||||
output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
|
||||
result += output
|
||||
else:
|
||||
output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
|
||||
result += output
|
||||
return result
|
||||
|
||||
class MergedLinear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
enable_lora: List[bool] = [False],
|
||||
**kwargs,
|
||||
):
|
||||
bnb.nn.Linear8bitLt.__init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
bias=kwargs.get("bias", True),
|
||||
has_fp16_weights=kwargs.get("has_fp16_weights", True),
|
||||
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
|
||||
threshold=kwargs.get("threshold", 0.0),
|
||||
index=kwargs.get("index", None),
|
||||
)
|
||||
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
|
||||
if out_features % len(enable_lora) != 0:
|
||||
raise ValueError("The length of enable_lora must divide out_features")
|
||||
self.enable_lora = enable_lora
|
||||
# Actual trainable parameters
|
||||
if r > 0 and any(enable_lora):
|
||||
self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
|
||||
self.lora_B = nn.Conv1d(
|
||||
r * sum(enable_lora),
|
||||
out_features // len(enable_lora) * sum(enable_lora),
|
||||
kernel_size=1,
|
||||
groups=2,
|
||||
bias=False,
|
||||
)
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
# Compute the indices
|
||||
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
|
||||
self.lora_ind[enable_lora, :] = True
|
||||
self.lora_ind = self.lora_ind.view(-1)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if hasattr(self, "lora_A"):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B.weight)
|
||||
|
||||
def zero_pad(self, x):
|
||||
result = x.new_zeros((*x.shape[:-1], self.out_features))
|
||||
result = result.view(-1, self.out_features)
|
||||
result[:, self.lora_ind] = x.reshape(
|
||||
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
|
||||
)
|
||||
return result.view((*x.shape[:-1], self.out_features))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
result = super().forward(x)
|
||||
if self.disable_adapters:
|
||||
return result
|
||||
elif self.r > 0:
|
||||
if not torch.is_autocast_enabled():
|
||||
expected_dtype = result.dtype
|
||||
if x.dtype != torch.float32:
|
||||
x = x.float()
|
||||
after_A = self.lora_A(self.lora_dropout(x))
|
||||
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
|
||||
output = self.zero_pad(after_B).to(expected_dtype) * self.scaling
|
||||
result += output
|
||||
else:
|
||||
after_A = self.lora_A(self.lora_dropout(x))
|
||||
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
|
||||
output = self.zero_pad(after_B) * self.scaling
|
||||
result += output
|
||||
return result
|
||||
|
||||
if is_gptq_available():
|
||||
|
||||
from autograd_4bit import Autograd4bitQuantLinear
|
||||
|
||||
class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer):
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
Autograd4bitQuantLinear.__init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features
|
||||
)
|
||||
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Linear(in_features, r, bias=False)
|
||||
self.lora_B = nn.Linear(r, out_features, bias=False)
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.qweight.requires_grad = False
|
||||
self.scales.requires_grad = False
|
||||
self.zeros.requires_grad = False
|
||||
self.bias.requires_grad = False
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if hasattr(self, "lora_A"):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
result = super().forward(x)
|
||||
|
||||
if self.disable_adapters:
|
||||
return result
|
||||
elif self.r > 0:
|
||||
if not torch.is_autocast_enabled():
|
||||
expected_dtype = result.dtype
|
||||
|
||||
if x.dtype != torch.float32:
|
||||
x = x.float()
|
||||
output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
|
||||
result += output
|
||||
else:
|
||||
output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
|
||||
result += output
|
||||
return result