add patch for gptq and peft

This commit is contained in:
John Smith 2023-03-18 13:31:48 +08:00 committed by GitHub
parent 326bc9214a
commit 551f62a0e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 1403 additions and 0 deletions

View File

@ -0,0 +1,144 @@
import quant
import torch
import numpy as np
import torch.nn as nn
def matmul4bit(x, qweight, scales, zeros):
"""
input x: (n, m)
qweight: (j, k)
where m == j*8
perform x @ qweight
return y:
"""
assert qweight.shape[0] * 8 == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[1]])
x = x.reshape(-1, x.shape[-1])
assert x.shape[0] % 256 == 0
y = torch.zeros((x.shape[0], qweight.shape[-1]), dtype=torch.float32, device=x.device)
dtype = x.dtype
x = x.float()
quant.quant_cuda.vecquant4matmul(x, qweight, y, scales, zeros)
y = y.to(dtype)
return y.reshape(outshape)
def matmul4bit_transpose(x, qweight, scales, zeros):
"""
input x: (n, m)
qweight: (j, k)
where m == k
perform qweight @ x.T
return y:
"""
assert qweight.shape[1] == x.shape[-1]
outshape = tuple(list(x.shape[:-1]) + [qweight.shape[0] * 8])
x = x.reshape(-1, x.shape[-1])
assert x.shape[0] % 256 == 0
y = torch.zeros((qweight.shape[0] * 8, x.shape[0]), dtype=torch.float32, device=x.device)
dtype = x.dtype
x = x.float()
quant.quant_cuda.vecquant4transposematmul(x, qweight, y, scales, zeros)
y = y.to(dtype)
return y.reshape(outshape)
class AutogradMatmul4bit(torch.autograd.Function):
@staticmethod
def forward(ctx, x, qweight, scales, zeros):
ctx.save_for_backward(x, qweight, scales, zeros)
output = matmul4bit(x, qweight, scales, zeros).clone()
return output # equals to torch.matmul(x, qweight)
@staticmethod
def backward(ctx, grad_output):
x, qweight, scales, zeros = ctx.saved_tensors
# print(grad_output.shape, A.shape, B.shape)
# compute x @ qweight.T = (qweight @ x.T).T = f(x, qweight).T
grad1 = matmul4bit_transpose(grad_output, qweight, scales, zeros)
grad2 = torch.matmul(x.transpose(-1, -2), grad_output)
return grad1, grad2, None, None
# Assumes layer is perfectly divisible into 256 * 256 blocks
class Autograd4bitQuantLinear(nn.Module):
def __init__(self, infeatures, outfeatures):
super().__init__()
bits = 4
self.in_features = infeatures
self.out_features = outfeatures
self.bits = bits
self.register_buffer('zeros', torch.empty((outfeatures, 1)))
self.register_buffer('scales', torch.empty((outfeatures, 1)))
self.register_buffer('bias', torch.empty(outfeatures))
self.register_buffer(
'qweight', torch.empty((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
)
def forward(self, x):
out = AutogradMatmul4bit.apply(x, self.qweight, self.scales, self.zeros)
out += self.bias
return out
def make_quant_for_4bit_autograd(module, names, name=''):
if isinstance(module, Autograd4bitQuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr
if name1 in names:
setattr(
module, attr, Autograd4bitQuantLinear(tmp.in_features, tmp.out_features)
)
for name1, child in module.named_children():
make_quant_for_4bit_autograd(child, names, name + '.' + name1 if name != '' else name1)
def load_llama_model_4bit_low_ram(config_path, model_path):
import transformers
import accelerate
from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer
from modelutils import find_layers
print("Loading Model ...")
t0 = time.time()
with accelerate.init_empty_weights():
config = LLaMAConfig.from_pretrained(config_path)
def noop(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = LLaMAForCausalLM(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
for name in ['lm_head']:
if name in layers:
del layers[name]
make_quant_for_4bit_autograd(model, layers)
model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
model.cuda()
model.seqlen = 2048
tokenizer = LLaMATokenizer.from_pretrained(config_path)
tokenizer.truncation_side = 'left'
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer

View File

@ -0,0 +1,76 @@
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>
void vecquant2matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
);
void vecquant2matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant2matmul_cuda(vec, mat, mul, scales, zeros);
}
void vecquant3matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
);
void vecquant3matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant3matmul_cuda(vec, mat, mul, scales, zeros);
}
void vecquant4matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
);
void vecquant4matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4matmul_cuda(vec, mat, mul, scales, zeros);
}
void vecquant8matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
);
void vecquant8matmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant8matmul_cuda(vec, mat, mul, scales, zeros);
}
void vecquant4transposematmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
);
void vecquant4transposematmul(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vecquant4transposematmul_cuda(vec, mat, mul, scales, zeros);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
m.def("vecquant4transposematmul", &vecquant4transposematmul, "Vector 4-bit Transpose Quantized Matrix Multiplication (CUDA)");
}

View File

@ -0,0 +1,480 @@
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
);
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
);
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
);
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
);
const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;
__device__ inline unsigned int as_unsigned(int i) {
return *reinterpret_cast<unsigned int*>(&i);
}
void vecquant2matmul_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
dim3 blocks(
(height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
batch
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES(
vec.type(), "vecquant2matmul_cuda", ([&] {
VecQuant2MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<scalar_t>(),
batch, vec_height, height, width
);
})
);
}
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
) {
int b = blockIdx.z;
int h = BLOCKHEIGHT2 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH];
blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT2) * BLOCKWIDTH + threadIdx.x];
__syncthreads();
scalar_t scale = scales[w];
scalar_t zero = zeros[w];
scalar_t res = 0;
int i = width * h + w;
int k = 0;
unsigned int tmp;
while (k < BLOCKWIDTH) {
tmp = as_unsigned(mat[i]);
res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
i += width;
k += 16;
}
atomicAdd(&mul[b * width + w], res);
}
void vecquant3matmul_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
dim3 blocks(
(height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
batch
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES(
vec.type(), "vecquant3matmul_cuda", ([&] {
VecQuant3MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<scalar_t>(),
batch, vec_height, height, width
);
})
);
}
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
) {
int b = blockIdx.z;
int h = BLOCKHEIGHT3 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH];
blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT3) * BLOCKWIDTH + threadIdx.x];
__syncthreads();
scalar_t scale = scales[w];
scalar_t zero = zeros[w];
scalar_t res = 0;
int i = width * h + w;
int k = 0;
unsigned int tmp1;
unsigned int tmp2;
unsigned int tmp;
while (k < BLOCKWIDTH) {
tmp1 = as_unsigned(mat[i]);
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
i += width;
tmp2 = as_unsigned(mat[i]);
tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
tmp2 >>= 1;
res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
k += 11;
res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
i += width;
tmp1 = as_unsigned(mat[i]);
tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
tmp1 >>= 2;
res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
k += 11;
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
i += width;
k += 10;
}
atomicAdd(&mul[b * width + w], res);
}
void vecquant4matmul_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
batch
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES(
vec.type(), "vecquant4matmul_cuda", ([&] {
VecQuant4MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<scalar_t>(),
batch, vec_height, height, width
);
})
);
}
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
) {
int b = blockIdx.z;
int h = BLOCKHEIGHT4 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH];
blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x];
__syncthreads();
scalar_t scale = scales[w];
scalar_t zero = zeros[w];
scalar_t res = 0;
int i = width * h + w;
int k = 0;
unsigned int tmp;
while (k < BLOCKWIDTH) {
tmp = as_unsigned(mat[i]);
res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
i += width;
k += 8;
}
atomicAdd(&mul[b * width + w], res);
}
void vecquant8matmul_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
dim3 blocks(
(height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
batch
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES(
vec.type(), "vecquant8matmul_cuda", ([&] {
VecQuant8MatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<scalar_t>(),
batch, vec_height, height, width
);
})
);
}
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
) {
int b = blockIdx.z;
int h = BLOCKHEIGHT8 * blockIdx.x;
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
__shared__ scalar_t blockvec[BLOCKWIDTH];
blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT8) * BLOCKWIDTH + threadIdx.x];
__syncthreads();
scalar_t scale = scales[w];
scalar_t zero = zeros[w];
scalar_t res = 0;
int i = width * h + w;
int k = 0;
unsigned int tmp;
while (k < BLOCKWIDTH) {
tmp = as_unsigned(mat[i]);
res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
i += width;
k += 4;
}
atomicAdd(&mul[b * width + w], res);
}
template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const scalar_t* __restrict__ zeros,
int batch,
int vec_height,
int height,
int width
) {
int b = blockIdx.z;
int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;
unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);
int w = BLOCKWIDTH * blockIdx.y;
int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;
int n_cols = b;
__shared__ scalar_t blockvec[BLOCKWIDTH];
blockvec[threadIdx.x] = vec[n_cols * vec_height + w + threadIdx.x];
__syncthreads();
scalar_t res = 0;
int i = width * h + w;
int k = 0;
int j = w;
unsigned int tmp;
while (k < BLOCKWIDTH) {
tmp = as_unsigned(mat[i]);
res += (scales[j] * scalar_t((tmp >> shift) & 0xF) - zeros[j]) * blockvec[k];
i += 1;
j += 1;
k += 1;
}
atomicAdd(&mul[n_cols * height * 8 + n_rows], res);
}
void vecquant4transposematmul_cuda(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros
) {
int batch = vec.size(0);
int vec_height = vec.size(1);
int height = mat.size(0);
int width = mat.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
batch
);
dim3 threads(BLOCKWIDTH);
AT_DISPATCH_FLOATING_TYPES(
vec.type(), "vecquant4transposematmul_cuda", ([&] {
VecQuant4TransposeMatMulKernel<<<blocks, threads>>>(
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
scales.data<scalar_t>(), zeros.data<scalar_t>(),
batch, vec_height, height, width
);
})
);
}

697
peft/tuners/lora.py Normal file
View File

@ -0,0 +1,697 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
import re
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D
from ..utils import PeftConfig, PeftType, transpose
def is_bnb_available():
return importlib.util.find_spec("bitsandbytes") is not None
def is_gptq_available():
return importlib.util.find_spec("quant") is not None
if is_bnb_available():
import bitsandbytes as bnb
if is_gptq_available():
import quant
@dataclass
class LoraConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a [`~peft.Lora`].
Args:
r (`int`): Lora attention dimension
target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to.
lora_alpha (`float`): The alpha parameter for Lora scaling.
lora_dropout (`float`): The dropout probability for Lora layers.
merge_weights (`bool`):
Whether to merge the weights of the Lora layers with the base transformer model in `eval` mode.
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out)
enable_lora ( `List[bool]`): Used with `lora.MergedLinear`.
bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'
modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable
and saved in the final checkpoint.
"""
r: int = field(default=8, metadata={"help": "Lora attention dimension"})
target_modules: Optional[Union[List[str], str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to replace with Lora."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)
lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"})
lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"})
merge_weights: bool = field(
default=False, metadata={"help": "Merge weights of the original model and the Lora model"}
)
fan_in_fan_out: bool = field(
default=False,
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
)
enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."})
bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
modules_to_save: Optional[List[str]] = field(
default=None,
metadata={
"help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
"For example, in Sequence Classification or Token Classification tasks, "
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
},
)
def __post_init__(self):
self.peft_type = PeftType.LORA
class LoraModel(torch.nn.Module):
"""
Creates Low Rank Adapter (Lora) model from a pretrained transformers model.
Args:
model ([`transformers.PreTrainedModel`]): The model to be adapted.
config ([`LoraConfig`]): The configuration of the Lora model.
Returns:
`torch.nn.Module`: The Lora model.
Example::
>>> from transformers import AutoModelForSeq2SeqLM, LoraConfig >>> from peft import LoraModel, LoraConfig >>>
config = LoraConfig(
peft_type="LORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
lora_dropout=0.01, )
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> lora_model = LoraModel(config, model)
**Attributes**:
- **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
"""
def __init__(self, config, model):
super().__init__()
self.peft_config = config
self.model = model
self._find_and_replace()
mark_only_lora_as_trainable(self.model, self.peft_config.bias)
self.forward = self.model.forward
def _find_and_replace(self):
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
if loaded_in_8bit and not is_bnb_available():
raise ImportError(
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
)
is_target_modules_in_base_model = False
is_hf_device_map_available = hasattr(self.model, "hf_device_map")
kwargs = {
"r": self.peft_config.r,
"lora_alpha": self.peft_config.lora_alpha,
"lora_dropout": self.peft_config.lora_dropout,
"fan_in_fan_out": self.peft_config.fan_in_fan_out,
"merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
and not is_hf_device_map_available,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
if isinstance(self.peft_config.target_modules, str):
target_module_found = re.fullmatch(self.peft_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = self._get_submodules(key)
bias = target.bias is not None
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
"memory_efficient_backward": target.state.memory_efficient_backward,
"threshold": target.state.threshold,
"index": target.index,
}
)
if self.peft_config.enable_lora is None:
new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
else:
kwargs.update({"enable_lora": self.peft_config.enable_lora})
new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
elif isinstance(target, Autograd4bitQuantLinear) and self.peft_config.enable_lora is None:
new_module = Linear4bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
elif self.peft_config.enable_lora is not None:
kwargs.update({"enable_lora": self.peft_config.enable_lora})
if isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
else:
in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is not a Conv1D. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = self.peft_config.fan_in_fan_out = False
new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {self.peft_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)
def _get_submodules(self, key):
parent = self.model.get_submodule(".".join(key.split(".")[:-1]))
target_name = key.split(".")[-1]
target = self.model.get_submodule(key)
return parent, target, target_name
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
if isinstance(old_module, Autograd4bitQuantLinear) and isinstance(new_module, Linear4bitLt):
new_module.qweight = old_module.qweight
new_module.scales = old_module.scales
new_module.zeros = old_module.zeros
new_module.bias = old_module.bias
if getattr(old_module, "state", None) is not None:
new_module.state = old_module.state
new_module.to(old_module.qweight.device)
# dispatch to correct device
for name, module in new_module.named_modules():
if "lora_" in name:
module.to(old_module.qweight.device)
else:
new_module.weight = old_module.weight
if old_module.bias is not None:
new_module.bias = old_module.bias
if getattr(old_module, "state", None) is not None:
new_module.state = old_module.state
new_module.to(old_module.weight.device)
# dispatch to correct device
for name, module in new_module.named_modules():
if "lora_" in name:
module.to(old_module.weight.device)
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.model, name)
@property
def modules_to_save(self):
return None
def get_peft_config_as_dict(self, inference: bool = False):
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
if inference:
config["inference_mode"] = True
return config
def _set_adapter_layers(self, enabled=True):
for module in self.model.modules():
if isinstance(module, LoraLayer):
module.disable_adapters = False if enabled else True
def enable_adapter_layers(self):
self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)
# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
# had to adapt it for `lora_only` to work
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
for n, p in model.named_parameters():
if "lora_" not in n:
p.requires_grad = False
if bias == "none":
return
elif bias == "all":
for n, p in model.named_parameters():
if "bias" in n:
p.requires_grad = True
elif bias == "lora_only":
for m in model.modules():
if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError
class LoraLayer:
def __init__(
self,
r: int,
lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights
self.disable_adapters = False
class Linear(nn.Linear, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
**kwargs,
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
self.lora_B = nn.Linear(r, out_features, bias=False)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def train(self, mode: bool = True):
nn.Linear.train(self, mode)
self.lora_A.train(mode)
self.lora_B.train(mode)
if not mode and self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += (
transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = True
elif self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= (
transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = False
def eval(self):
nn.Linear.eval(self)
self.lora_A.eval()
self.lora_B.eval()
def forward(self, x: torch.Tensor):
if self.disable_adapters:
if self.r > 0 and self.merged:
self.weight.data -= (
transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = False
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r > 0 and not self.merged:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.r > 0:
result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
return result
else:
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
class MergedLinear(nn.Linear, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
enable_lora: List[bool] = [False],
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs,
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
if out_features % len(enable_lora) != 0:
raise ValueError("The length of enable_lora must divide out_features")
self.enable_lora = enable_lora
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0 and any(enable_lora):
self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
self.lora_B = nn.Conv1d(
r * sum(enable_lora),
out_features // len(enable_lora) * sum(enable_lora),
kernel_size=1,
groups=2,
bias=False,
)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def zero_pad(self, x):
result = x.new_zeros((*x.shape[:-1], self.out_features))
result = result.view(-1, self.out_features)
result[:, self.lora_ind] = x.reshape(-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora))
return result.view((*x.shape[:-1], self.out_features))
def train(self, mode: bool = True):
nn.Linear.train(self, mode)
self.lora_A.train(mode)
self.lora_B.train(mode)
if not mode and self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0 and any(self.enable_lora):
delta_w = (
F.conv1d(
self.lora_A.weight.data.unsqueeze(0),
self.lora_B.weight.data,
groups=sum(self.enable_lora),
)
.squeeze(0)
.transpose(-2, -1)
)
self.weight.data += transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
self.merged = True
elif self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0 and any(self.enable_lora):
delta_w = (
F.conv1d(
self.lora_A.weight.data.unsqueeze(0),
self.lora_B.weight.data,
groups=sum(self.enable_lora),
)
.squeeze(0)
.transpose(-2, -1)
)
self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
self.merged = False
def eval(self):
nn.Linear.eval(self)
self.lora_A.eval()
self.lora_B.eval()
def forward(self, x: torch.Tensor):
if self.disable_adapters:
if self.r > 0 and self.merged and any(self.enable_lora):
delta_w = (
F.conv1d(
self.lora_A.weight.data.unsqueeze(0),
self.lora_B.weight.data,
groups=sum(self.enable_lora),
)
.squeeze(0)
.transpose(-2, -1)
)
self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
self.merged = False
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.merged:
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
else:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.r > 0:
after_A = self.lora_A(self.lora_dropout(x))
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
result += self.zero_pad(after_B) * self.scaling
return result
if is_bnb_available():
class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
in_features,
out_features,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
bnb.nn.Linear8bitLt.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
self.lora_B = nn.Linear(r, out_features, bias=False)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters:
return result
elif self.r > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
result += output
else:
output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
result += output
return result
class MergedLinear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
enable_lora: List[bool] = [False],
**kwargs,
):
bnb.nn.Linear8bitLt.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
if out_features % len(enable_lora) != 0:
raise ValueError("The length of enable_lora must divide out_features")
self.enable_lora = enable_lora
# Actual trainable parameters
if r > 0 and any(enable_lora):
self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
self.lora_B = nn.Conv1d(
r * sum(enable_lora),
out_features // len(enable_lora) * sum(enable_lora),
kernel_size=1,
groups=2,
bias=False,
)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def zero_pad(self, x):
result = x.new_zeros((*x.shape[:-1], self.out_features))
result = result.view(-1, self.out_features)
result[:, self.lora_ind] = x.reshape(
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
)
return result.view((*x.shape[:-1], self.out_features))
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters:
return result
elif self.r > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
after_A = self.lora_A(self.lora_dropout(x))
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
output = self.zero_pad(after_B).to(expected_dtype) * self.scaling
result += output
else:
after_A = self.lora_A(self.lora_dropout(x))
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
output = self.zero_pad(after_B) * self.scaling
result += output
return result
if is_gptq_available():
from autograd_4bit import Autograd4bitQuantLinear
class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
in_features,
out_features,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
Autograd4bitQuantLinear.__init__(
self,
in_features,
out_features
)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
self.lora_B = nn.Linear(r, out_features, bias=False)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.qweight.requires_grad = False
self.scales.requires_grad = False
self.zeros.requires_grad = False
self.bias.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters:
return result
elif self.r > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
result += output
else:
output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
result += output
return result

6
requirements.txt Normal file
View File

@ -0,0 +1,6 @@
torch
accelerate
bitsandbytes
git+https://github.com/zphang/transformers@llama_push
git+https://github.com/qwopqwop200/GPTQ-for-LLaMa.git
git+https://github.com/huggingface/peft.git