class AutogradMatmul4bit(torch.autograd.Function):
    """Differentiable wrapper around the 4-bit quantized product ``x @ W``.

    ``W`` is never materialised: it is described by the packed int32 tensor
    ``qweight`` together with ``scales`` and ``zeros`` and is consumed by the
    CUDA kernels behind ``matmul4bit`` / ``matmul4bit_transpose``.  Only the
    activation ``x`` receives a gradient; the quantization tensors are treated
    as constants.
    """

    @staticmethod
    def forward(ctx, x, qweight, scales, zeros):
        """Return ``matmul4bit(x, qweight, scales, zeros)``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            x: activations whose last dimension equals ``qweight.shape[0] * 8``.
            qweight: packed 4-bit weights (int32).
            scales: dequantization scales forwarded to the kernel.
            zeros: dequantization zero offsets forwarded to the kernel.
        """
        # The gradient w.r.t. x does not depend on x itself, so only the
        # quantization tensors have to be kept alive until backward.
        ctx.save_for_backward(qweight, scales, zeros)
        # clone() is kept from the original implementation.
        return matmul4bit(x, qweight, scales, zeros).clone()

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_x, None, None, None)``.

        ``grad_x = grad_output @ W.T`` is computed by ``matmul4bit_transpose``.
        No gradient is produced for ``qweight``: it is an integer tensor and
        cannot require grad, so the dense ``x.T @ grad_output`` product the
        previous version built for it was discarded by autograd while still
        costing a full-size matmul and allocation.
        """
        qweight, scales, zeros = ctx.saved_tensors
        grad_x = None
        if ctx.needs_input_grad[0]:
            # The kernel indexes the buffer as a dense row-major matrix, so
            # make sure the incoming gradient is contiguous first.
            grad_x = matmul4bit_transpose(grad_output.contiguous(), qweight, scales, zeros)
        return grad_x, None, None, None
def load_llama_model_4bit_low_ram(config_path, model_path):
    """Build a 4-bit LLaMA model and its tokenizer without first allocating fp16 weights.

    The model skeleton is created under ``accelerate.init_empty_weights()``,
    every layer reported by ``find_layers`` except ``lm_head`` is swapped for
    an ``Autograd4bitQuantLinear`` by ``make_quant_for_4bit_autograd``, and the
    checkpoint is then loaded with ``device_map='auto'``.

    Args:
        config_path: location passed to ``LLaMAConfig.from_pretrained`` and
            ``LLaMATokenizer.from_pretrained``.
        model_path: checkpoint passed to ``accelerate.load_checkpoint_and_dispatch``.

    Returns:
        ``(model, tokenizer)``; the model is in eval mode with ``seqlen = 2048``
        and the tokenizer truncates on the left.

    Side effects:
        ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are
        replaced by no-ops and ``transformers.modeling_utils._init_weights`` is
        set to ``False`` for the rest of the process; neither is restored.
    """
    # `time` is not imported at module level; without this import the
    # function raised NameError on its first timing call.
    import time

    import accelerate
    import transformers
    from transformers import LLaMAConfig, LLaMAForCausalLM, LLaMATokenizer
    from modelutils import find_layers

    print("Loading Model ...")
    t0 = time.time()

    def noop(*args, **kwargs):
        pass

    with accelerate.init_empty_weights():
        config = LLaMAConfig.from_pretrained(config_path)

        # Skip random weight initialisation: every weight is overwritten by
        # the checkpoint below.
        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop
        transformers.modeling_utils._init_weights = False

        # Build the skeleton in half precision, and always put the default
        # dtype back to float even if construction fails.
        torch.set_default_dtype(torch.half)
        try:
            model = LLaMAForCausalLM(config)
        finally:
            torch.set_default_dtype(torch.float)

        model = model.eval()
        layers = find_layers(model)
        # lm_head stays a regular layer; everything else becomes 4-bit.
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        make_quant_for_4bit_autograd(model, layers)

    model = accelerate.load_checkpoint_and_dispatch(model=model, checkpoint=model_path, device_map='auto')
    # NOTE(review): .cuda() after dispatching with device_map='auto' is kept
    # from the original; confirm it is still wanted for multi-device maps.
    model.cuda()
    model.seqlen = 2048

    tokenizer = LLaMATokenizer.from_pretrained(config_path)
    tokenizer.truncation_side = 'left'

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")

    return model, tokenizer
// Host-side entry point registered with pybind11 under the name
// "vecquant4matmul".
//
// Forwards all five tensors unchanged to vecquant4matmul_cuda, which launches
// the 4-bit kernel. Nothing is returned: the result is accumulated into `mul`.
void vecquant4matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  // Device guard built from the device of `vec`, held for the duration of
  // the call so the launch below targets that device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4matmul_cuda(vec, mat, mul, scales, zeros);
}
// Return the bit pattern of `i` reinterpreted as an unsigned 32-bit value.
//
// The packed weight words are stored as `int`; the kernels need logical
// (zero-filling) right shifts when extracting fields, which requires an
// unsigned operand.
//
// The target type of the cast had been lost (`reinterpret_cast(&i)`), which
// does not compile; it is restored here.
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}
// 2-bit quantized vector-matrix product: mul[b, w] += sum_k vec[b, k] * W[k, w].
//
// Each int32 of `mat` packs sixteen 2-bit weights of one output column `w`.
// A weight q is dequantized as (scale * q - zero) with the per-column
// scales[w] / zeros[w].
//
// Work decomposition (one block = BLOCKWIDTH threads):
//   blockIdx.x  tile of BLOCKHEIGHT2 packed rows (= BLOCKWIDTH input elements)
//   blockIdx.y  tile of BLOCKWIDTH output columns, one column per thread
//   blockIdx.z  row `b` of `vec` / `mul`
//
// Blocks with different blockIdx.x contribute partial sums to the same output
// element, hence the atomicAdd into `mul`.
//
// NOTE(review): there are no bounds checks; the indexing assumes `width` and
// `vec_height` are multiples of BLOCKWIDTH and `height` is a multiple of
// BLOCKHEIGHT2 - confirm the callers guarantee this.
//
// The `template <typename scalar_t>` header had lost its parameter list and
// is restored; the sixteen hand-unrolled lines are folded into an equivalent
// loop that accumulates in the same order.
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage the BLOCKWIDTH input elements that pair with this tile of packed
  // rows; every thread of the block reuses all of them.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT2) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;  // flat index of the first packed word of column w in this tile
  int k = 0;              // position inside blockvec

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    // Sixteen 2-bit fields per word, lowest bits first.
    #pragma unroll
    for (int shift = 0; shift < 32; shift += 2) {
      res += (scale * scalar_t((tmp >> shift) & 0x3) - zero) * blockvec[k];
      k += 1;
    }
    i += width;  // next packed row, same column
  }

  atomicAdd(&mul[b * width + w], res);
}
9) & 0x7) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; + res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; + res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; + i += width; + tmp2 = as_unsigned(mat[i]); + tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4); + tmp2 >>= 1; + res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; + k += 11; + res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5]; + res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6]; + res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7]; + res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8]; + res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9]; + i += width; + tmp1 = as_unsigned(mat[i]); + tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6); + tmp1 >>= 2; + res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; + k += 11; + res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; + res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; + res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; + res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3]; + res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4]; + res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 
// 4-bit quantized vector-matrix product: mul[b, w] += sum_k vec[b, k] * W[k, w].
//
// Each int32 of `mat` packs eight 4-bit weights of one output column `w`.
// A weight q is dequantized as (scale * q - zero) with the per-column
// scales[w] / zeros[w].
//
// Work decomposition (one block = BLOCKWIDTH threads):
//   blockIdx.x  tile of BLOCKHEIGHT4 packed rows (= BLOCKWIDTH input elements)
//   blockIdx.y  tile of BLOCKWIDTH output columns, one column per thread
//   blockIdx.z  row `b` of `vec` / `mul`
//
// Blocks with different blockIdx.x contribute partial sums to the same output
// element, hence the atomicAdd; `mul` is accumulated into, not overwritten.
//
// NOTE(review): there are no bounds checks; the indexing assumes `width` and
// `vec_height` are multiples of BLOCKWIDTH and `height` is a multiple of
// BLOCKHEIGHT4 - confirm the callers guarantee this.
//
// The `template <typename scalar_t>` header had lost its parameter list and
// is restored; the eight hand-unrolled lines are folded into an equivalent
// loop that accumulates in the same order.
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Stage the BLOCKWIDTH input elements that pair with this tile of packed
  // rows; every thread of the block reuses all of them.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + (h / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t scale = scales[w];
  scalar_t zero = zeros[w];

  scalar_t res = 0;
  int i = width * h + w;  // flat index of the first packed word of column w in this tile
  int k = 0;              // position inside blockvec

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);
    // Eight 4-bit fields per word, lowest nibble first.
    #pragma unroll
    for (int shift = 0; shift < 32; shift += 4) {
      res += (scale * scalar_t((tmp >> shift) & 0xF) - zero) * blockvec[k];
      k += 1;
    }
    i += width;  // next packed row, same column
  }

  atomicAdd(&mul[b * width + w], res);
}
// 4-bit quantized product against the transposed weight matrix:
//   mul[b, r] += sum_c W[r, c] * vec[b, c]
// where r runs over the 8 * height unpacked weight rows and c over the
// `width` columns of `mat`.
//
// Work decomposition (one block = BLOCKWIDTH threads):
//   blockIdx.x  tile of BLOCKHEIGHT4 packed rows = 8 * BLOCKHEIGHT4 unpacked rows
//   blockIdx.y  tile of BLOCKWIDTH columns, shared by all threads of the block
//   blockIdx.z  row `b` of `vec` / `mul`
// Thread t owns one unpacked row: packed row h = tile_start + t / 8 and the
// 4-bit field number t % 8 inside every word of that row.
//
// Output layout: element (b, r) is written at mul[b * height * 8 + r], i.e.
// `mul` is addressed as a dense row-major [batch][8 * height] buffer.
// Blocks with different blockIdx.y add partial sums to the same element,
// hence the atomicAdd.
//
// NOTE(review): there are no bounds checks; the indexing assumes `width`
// (== vec_height) is a multiple of BLOCKWIDTH and `height` a multiple of
// BLOCKHEIGHT4 - confirm the callers guarantee this.
//
// The `template <typename scalar_t>` header had lost its parameter list and
// is restored here.
template <typename scalar_t>
__global__ void VecQuant4TransposeMatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const scalar_t* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x + threadIdx.x / 8;            // packed row read by this thread
  unsigned int shift = (unsigned int)((threadIdx.x % 8) * 4);     // bit offset of this thread's field
  int w = BLOCKWIDTH * blockIdx.y;                                // first column of the tile

  int n_rows = 8 * BLOCKHEIGHT4 * blockIdx.x + threadIdx.x;       // unpacked output row
  int n_cols = b;

  // Stage the BLOCKWIDTH input elements of this column tile.
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[n_cols * vec_height + w + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;  // flat index into mat, walks along the row
  int j = w;              // matching column index into scales / zeros
  unsigned int tmp;

  for (int k = 0; k < BLOCKWIDTH; ++k) {
    tmp = as_unsigned(mat[i]);
    res += (scales[j] * scalar_t((tmp >> shift) & 0xF) - zeros[j]) * blockvec[k];
    i += 1;
    j += 1;
  }

  atomicAdd(&mul[n_cols * height * 8 + n_rows], res);
}
def is_gptq_available():
    """Return True when a top-level module named ``quant`` can be found.

    Only the import machinery is queried; the module itself is not imported.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule has been loaded, so import it explicitly before use.
    import importlib.util

    return importlib.util.find_spec("quant") is not None
class LoraModel(torch.nn.Module):
    """
    Creates Low Rank Adapter (Lora) model from a pretrained transformers model.

    Args:
        model ([`transformers.PreTrainedModel`]): The model to be adapted.
        config ([`LoraConfig`]): The configuration of the Lora model.

    Returns:
        `torch.nn.Module`: The Lora model.

    Example::

        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import LoraModel, LoraConfig
        >>> config = LoraConfig(
        ...     peft_type="LORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
        ...     lora_dropout=0.01,
        ... )
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> lora_model = LoraModel(config, model)

    **Attributes**:
        - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
        - **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
    """

    def __init__(self, config, model):
        super().__init__()
        self.peft_config = config
        self.model = model
        self._find_and_replace()
        mark_only_lora_as_trainable(self.model, self.peft_config.bias)
        # Calls go straight to the wrapped model's forward.
        self.forward = self.model.forward

    def _find_and_replace(self):
        """Swap every targeted sub-module of ``self.model`` for its Lora counterpart.

        A module is targeted when its dotted name fully matches
        ``target_modules`` (if that is a string, treated as a regex) or ends
        with one of its entries (if it is a list).

        Raises:
            ImportError: the model is loaded in 8-bit but `bitsandbytes` is missing.
            ValueError: no module matched, or a matched module has a type for
                which no Lora replacement exists under the current config.
        """
        loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
        if loaded_in_8bit and not is_bnb_available():
            raise ImportError(
                "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
                "You can install it with `pip install bitsandbytes`."
            )
        is_target_modules_in_base_model = False
        is_hf_device_map_available = hasattr(self.model, "hf_device_map")
        base_kwargs = {
            "r": self.peft_config.r,
            "lora_alpha": self.peft_config.lora_alpha,
            "lora_dropout": self.peft_config.lora_dropout,
            "fan_in_fan_out": self.peft_config.fan_in_fan_out,
            "merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
            and not is_hf_device_map_available,
        }
        target_modules = self.peft_config.target_modules
        enable_lora = self.peft_config.enable_lora
        # Snapshot the names first: modules are replaced while iterating.
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
            if isinstance(target_modules, str):
                target_module_found = re.fullmatch(target_modules, key)
            else:
                target_module_found = any(key.endswith(target_key) for target_key in target_modules)
            if not target_module_found:
                continue

            is_target_modules_in_base_model = True
            parent, target, target_name = self._get_submodules(key)
            bias = target.bias is not None
            # Fresh copy per target: the 8-bit branch adds keys that other
            # layer constructors do not accept, so they must not leak into
            # the kwargs used for later targets.
            kwargs = dict(base_kwargs)
            if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
                kwargs.update(
                    {
                        "has_fp16_weights": target.state.has_fp16_weights,
                        "memory_efficient_backward": target.state.memory_efficient_backward,
                        "threshold": target.state.threshold,
                        "index": target.index,
                    }
                )
                if enable_lora is None:
                    new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
                else:
                    kwargs.update({"enable_lora": enable_lora})
                    new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
            elif isinstance(target, torch.nn.Linear) and enable_lora is None:
                new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
            elif isinstance(target, Autograd4bitQuantLinear) and enable_lora is None:
                # NOTE(review): Autograd4bitQuantLinear / Linear4bitLt are not
                # imported or defined in the visible part of this file -
                # confirm they are in scope.
                new_module = Linear4bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
            elif enable_lora is not None:
                kwargs.update({"enable_lora": enable_lora})
                if isinstance(target, Conv1D):
                    in_features, out_features = (
                        target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
                    )
                else:
                    in_features, out_features = target.in_features, target.out_features
                    if kwargs["fan_in_fan_out"]:
                        warnings.warn(
                            "fan_in_fan_out is set to True but the target module is not a Conv1D. "
                            "Setting fan_in_fan_out to False."
                        )
                        # Persist the correction in the config and for the
                        # remaining targets, as before.
                        kwargs["fan_in_fan_out"] = False
                        base_kwargs["fan_in_fan_out"] = self.peft_config.fan_in_fan_out = False
                new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs)
            else:
                # Previously this fell through and reused the replacement
                # built for an earlier target (or hit an unbound local).
                raise ValueError(
                    f"Target module {key} of type {type(target).__name__} has no Lora replacement "
                    f"when `enable_lora` is not set."
                )
            self._replace_module(parent, target_name, new_module, target)
        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {self.peft_config.target_modules} not found in the base model. "
                f"Please check the target modules and try again."
            )

    def _get_submodules(self, key):
        """Return ``(parent_module, target_module, attribute_name)`` for dotted name ``key``."""
        parent_name, _, target_name = key.rpartition(".")
        parent = self.model.get_submodule(parent_name)
        target = self.model.get_submodule(key)
        return parent, target, target_name

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        """Install ``new_module`` in place of ``old_module`` and hand over its tensors.

        4-bit layers hand over ``qweight``/``scales``/``zeros``/``bias``;
        every other layer hands over ``weight`` and, when present, ``bias``.
        The ``lora_*`` sub-modules are then moved to the device of the old
        weights.
        """
        setattr(parent_module, child_name, new_module)
        if isinstance(old_module, Autograd4bitQuantLinear) and isinstance(new_module, Linear4bitLt):
            new_module.qweight = old_module.qweight
            new_module.scales = old_module.scales
            new_module.zeros = old_module.zeros
            new_module.bias = old_module.bias
            device = old_module.qweight.device
        else:
            new_module.weight = old_module.weight
            if old_module.bias is not None:
                new_module.bias = old_module.bias
            device = old_module.weight.device

        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state
            new_module.to(device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "lora_" in name:
                module.to(device)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)

    @property
    def modules_to_save(self):
        return None

    def get_peft_config_as_dict(self, inference: bool = False):
        """Return the config as a plain dict, with Enum members replaced by their values.

        When ``inference`` is true, ``inference_mode`` is forced to ``True``
        in the returned dict.
        """
        config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
        if inference:
            config["inference_mode"] = True
        return config

    def _set_adapter_layers(self, enabled=True):
        """Set ``disable_adapters`` on every ``LoraLayer`` of the wrapped model."""
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.disable_adapters = not enabled

    def enable_adapter_layers(self):
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
        self._set_adapter_layers(enabled=False)
class Linear(nn.Linear, LoraLayer):
    """Dense layer with an optional low-rank (Lora) update.

    With ``r > 0`` the effective weight is ``W + scaling * (B @ A)`` where
    ``A`` is ``lora_A.weight``, ``B`` is ``lora_B.weight`` and
    ``scaling = lora_alpha / r``.  The update is either applied on the fly in
    ``forward`` or, when ``merge_weights`` is set, folded into ``weight``
    while the layer is in eval mode.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Linear(in_features, r, bias=False)
            self.lora_B = nn.Linear(r, out_features, bias=False)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Re-initialise the base layer and, once they exist, the Lora matrices.

        The ``hasattr`` check matters because ``nn.Linear.__init__`` already
        calls this method before ``lora_A`` / ``lora_B`` have been created.
        """
        nn.Linear.reset_parameters(self)
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def _delta_weight(self):
        """Return the scaled low-rank update in the layout of ``self.weight``."""
        return transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling

    def train(self, mode: bool = True):
        """Switch train/eval mode and merge or un-merge the Lora update accordingly.

        Leaving train mode merges the update into ``weight`` (if
        ``merge_weights``); entering train mode removes it again.  Repeating
        the same mode is a no-op: the previous version un-merged on a second
        ``eval()`` because its ``elif`` did not look at ``mode``.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        # nn.Module.train already propagates the mode to lora_A / lora_B (they
        # are registered children), so they are not touched directly here;
        # doing so raised AttributeError when r == 0 and they do not exist.
        nn.Linear.train(self, mode)
        if self.merge_weights:
            if not mode and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self._delta_weight()
                self.merged = True
            elif mode and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= self._delta_weight()
                self.merged = False
        return self

    def eval(self):
        """Put the layer in eval mode; equivalent to ``train(False)``. Returns ``self``."""
        return self.train(False)

    def forward(self, x: torch.Tensor):
        """Apply the layer to ``x``.

        With adapters disabled, any merged update is first removed from
        ``weight`` and only the base projection is returned.  Otherwise the
        Lora branch is added on the fly unless it is already merged.
        """
        adapters_active = not self.disable_adapters
        if not adapters_active and self.r > 0 and self.merged:
            self.weight.data -= self._delta_weight()
            self.merged = False

        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if adapters_active and self.r > 0 and not self.merged:
            result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
        return result
in_features, out_features, **kwargs) + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) + if out_features % len(enable_lora) != 0: + raise ValueError("The length of enable_lora must divide out_features") + self.enable_lora = enable_lora + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0 and any(enable_lora): + self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False) + self.lora_B = nn.Conv1d( + r * sum(enable_lora), + out_features // len(enable_lora) * sum(enable_lora), + kernel_size=1, + groups=2, + bias=False, + ) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + # Compute the indices + self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1) + self.lora_ind[enable_lora, :] = True + self.lora_ind = self.lora_ind.view(-1) + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.T + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B.weight) + + def zero_pad(self, x): + result = x.new_zeros((*x.shape[:-1], self.out_features)) + result = result.view(-1, self.out_features) + result[:, self.lora_ind] = x.reshape(-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)) + return result.view((*x.shape[:-1], self.out_features)) + + def train(self, mode: bool = True): + nn.Linear.train(self, mode) + self.lora_A.train(mode) + self.lora_B.train(mode) + if not mode and self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0 and any(self.enable_lora): + delta_w = ( + F.conv1d( + self.lora_A.weight.data.unsqueeze(0), + self.lora_B.weight.data, + groups=sum(self.enable_lora), + ) + 
.squeeze(0) + .transpose(-2, -1) + ) + self.weight.data += transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out) + self.merged = True + elif self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0 and any(self.enable_lora): + delta_w = ( + F.conv1d( + self.lora_A.weight.data.unsqueeze(0), + self.lora_B.weight.data, + groups=sum(self.enable_lora), + ) + .squeeze(0) + .transpose(-2, -1) + ) + self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out) + self.merged = False + + def eval(self): + nn.Linear.eval(self) + self.lora_A.eval() + self.lora_B.eval() + + def forward(self, x: torch.Tensor): + if self.disable_adapters: + if self.r > 0 and self.merged and any(self.enable_lora): + delta_w = ( + F.conv1d( + self.lora_A.weight.data.unsqueeze(0), + self.lora_B.weight.data, + groups=sum(self.enable_lora), + ) + .squeeze(0) + .transpose(-2, -1) + ) + self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out) + self.merged = False + return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) + elif self.merged: + return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) + else: + result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) + if self.r > 0: + after_A = self.lora_A(self.lora_dropout(x)) + after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1) + result += self.zero_pad(after_B) * self.scaling + return result + + +if is_bnb_available(): + + class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + in_features, + out_features, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + bnb.nn.Linear8bitLt.__init__( + self, + in_features, + out_features, + bias=kwargs.get("bias", True), + has_fp16_weights=kwargs.get("has_fp16_weights", True), + 
memory_efficient_backward=kwargs.get("memory_efficient_backward", False), + threshold=kwargs.get("threshold", 0.0), + index=kwargs.get("index", None), + ) + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Linear(in_features, r, bias=False) + self.lora_B = nn.Linear(r, out_features, bias=False) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + + def reset_parameters(self): + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B.weight) + + def forward(self, x: torch.Tensor): + result = super().forward(x) + + if self.disable_adapters: + return result + elif self.r > 0: + if not torch.is_autocast_enabled(): + expected_dtype = result.dtype + + if x.dtype != torch.float32: + x = x.float() + output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling + result += output + else: + output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + result += output + return result + + class MergedLinear8bitLt(bnb.nn.Linear8bitLt, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + enable_lora: List[bool] = [False], + **kwargs, + ): + bnb.nn.Linear8bitLt.__init__( + self, + in_features, + out_features, + bias=kwargs.get("bias", True), + has_fp16_weights=kwargs.get("has_fp16_weights", True), + memory_efficient_backward=kwargs.get("memory_efficient_backward", False), + threshold=kwargs.get("threshold", 0.0), + index=kwargs.get("index", None), + ) + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) + if out_features % 
len(enable_lora) != 0: + raise ValueError("The length of enable_lora must divide out_features") + self.enable_lora = enable_lora + # Actual trainable parameters + if r > 0 and any(enable_lora): + self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False) + self.lora_B = nn.Conv1d( + r * sum(enable_lora), + out_features // len(enable_lora) * sum(enable_lora), + kernel_size=1, + groups=2, + bias=False, + ) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + # Compute the indices + self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1) + self.lora_ind[enable_lora, :] = True + self.lora_ind = self.lora_ind.view(-1) + self.reset_parameters() + + def reset_parameters(self): + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B.weight) + + def zero_pad(self, x): + result = x.new_zeros((*x.shape[:-1], self.out_features)) + result = result.view(-1, self.out_features) + result[:, self.lora_ind] = x.reshape( + -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora) + ) + return result.view((*x.shape[:-1], self.out_features)) + + def forward(self, x: torch.Tensor): + result = super().forward(x) + if self.disable_adapters: + return result + elif self.r > 0: + if not torch.is_autocast_enabled(): + expected_dtype = result.dtype + if x.dtype != torch.float32: + x = x.float() + after_A = self.lora_A(self.lora_dropout(x)) + after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1) + output = self.zero_pad(after_B).to(expected_dtype) * self.scaling + result += output + else: + after_A = self.lora_A(self.lora_dropout(x)) + after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1) + output = self.zero_pad(after_B) * self.scaling + result += output + return result + +if is_gptq_available(): 
+ + from autograd_4bit import Autograd4bitQuantLinear + + class Linear4bitLt(Autograd4bitQuantLinear, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + in_features, + out_features, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + Autograd4bitQuantLinear.__init__( + self, + in_features, + out_features + ) + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Linear(in_features, r, bias=False) + self.lora_B = nn.Linear(r, out_features, bias=False) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.qweight.requires_grad = False + self.scales.requires_grad = False + self.zeros.requires_grad = False + self.bias.requires_grad = False + self.reset_parameters() + + def reset_parameters(self): + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B.weight) + + def forward(self, x: torch.Tensor): + result = super().forward(x) + + if self.disable_adapters: + return result + elif self.r > 0: + if not torch.is_autocast_enabled(): + expected_dtype = result.dtype + + if x.dtype != torch.float32: + x = x.float() + output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling + result += output + else: + output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + result += output + return result diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1f7dcbe --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +torch +accelerate +bitsandbytes +git+https://github.com/zphang/transformers@llama_push +git+https://github.com/qwopqwop200/GPTQ-for-LLaMa.git +git+https://github.com/huggingface/peft.git