add xformers support
This commit is contained in:
parent
7871baf311
commit
4261bd8070
|
|
@ -15,7 +15,7 @@ class Finetune4bConfig:
|
||||||
warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
|
warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
|
||||||
checkpoint: bool, skip: bool, verbose: bool,
|
checkpoint: bool, skip: bool, verbose: bool,
|
||||||
txt_row_thd: int, use_eos_token: bool, groupsize: int, v1: bool,
|
txt_row_thd: int, use_eos_token: bool, groupsize: int, v1: bool,
|
||||||
local_rank: int, flash_attention: bool, backend: str
|
local_rank: int, flash_attention: bool, xformers: bool, backend: str
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -50,6 +50,7 @@ class Finetune4bConfig:
|
||||||
v1 (bool): v1 model flag
|
v1 (bool): v1 model flag
|
||||||
local_rank (int): local rank if using torch.distributed.launch
|
local_rank (int): local rank if using torch.distributed.launch
|
||||||
flash_attention (bool): Enables flash attention
|
flash_attention (bool): Enables flash attention
|
||||||
|
xformers (bool): use xformers or not
|
||||||
"""
|
"""
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.ds_type = ds_type
|
self.ds_type = ds_type
|
||||||
|
|
@ -88,6 +89,7 @@ class Finetune4bConfig:
|
||||||
self.groupsize = groupsize
|
self.groupsize = groupsize
|
||||||
self.v1 = v1
|
self.v1 = v1
|
||||||
self.flash_attention = flash_attention
|
self.flash_attention = flash_attention
|
||||||
|
self.xformers = xformers
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,7 @@ def parse_commandline():
|
||||||
|
|
||||||
# Flash Attention
|
# Flash Attention
|
||||||
parser_training.add_argument("--flash_attention", action="store_true", help="enables flash attention, can improve performance and reduce VRAM use")
|
parser_training.add_argument("--flash_attention", action="store_true", help="enables flash attention, can improve performance and reduce VRAM use")
|
||||||
|
parser_training.add_argument("--xformers", action="store_true", help="enables xformers memory efficient attention, can improve performance and reduce VRAM use")
|
||||||
|
|
||||||
# Train Backend
|
# Train Backend
|
||||||
parser_training.add_argument("--backend", type=str, default='cuda', help="Backend to use. Triton or Cuda.")
|
parser_training.add_argument("--backend", type=str, default='cuda', help="Backend to use. Triton or Cuda.")
|
||||||
|
|
@ -111,5 +112,6 @@ def get_config() -> Finetune4bConfig:
|
||||||
v1=args["v1"],
|
v1=args["v1"],
|
||||||
local_rank=args["local_rank"],
|
local_rank=args["local_rank"],
|
||||||
flash_attention=args["flash_attention"],
|
flash_attention=args["flash_attention"],
|
||||||
|
xformers=args["xformers"],
|
||||||
backend=args["backend"],
|
backend=args["backend"],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,9 @@ ft_config = get_config()
|
||||||
if ft_config.flash_attention:
|
if ft_config.flash_attention:
|
||||||
from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
||||||
replace_llama_attn_with_flash_attn()
|
replace_llama_attn_with_flash_attn()
|
||||||
|
elif ft_config.xformers:
|
||||||
|
from monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
|
||||||
|
hijack_llama_attention()
|
||||||
|
|
||||||
import autograd_4bit
|
import autograd_4bit
|
||||||
if ft_config.backend.lower() == 'triton':
|
if ft_config.backend.lower() == 'triton':
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,101 @@
|
||||||
|
'''
|
||||||
|
Directly copied the code from https://github.com/oobabooga/text-generation-webui/pull/950/commits and made some adjustments
|
||||||
|
'''
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
try:
|
||||||
|
import xformers.ops
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install xformers to use this module")
|
||||||
|
|
||||||
|
def hijack_llama_attention():
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||||
|
print("Replaced attention with xformers_attention")
|
||||||
|
|
||||||
|
def xformers_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
#We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||||
|
if not output_attentions:
|
||||||
|
dtype = query_states.dtype
|
||||||
|
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||||
|
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||||
|
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||||
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
|
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
||||||
|
else:
|
||||||
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
|
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
@ -8,8 +8,11 @@ from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmb
|
||||||
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
try:
|
||||||
from flash_attn.bert_padding import unpad_input, pad_input
|
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||||
|
from flash_attn.bert_padding import unpad_input, pad_input
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install flash_attn to use this module")
|
||||||
|
|
||||||
class LlamaAttention(nn.Module):
|
class LlamaAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,7 @@ bitsandbytes
|
||||||
datasets
|
datasets
|
||||||
sentencepiece
|
sentencepiece
|
||||||
safetensors
|
safetensors
|
||||||
flash-attn
|
einops
|
||||||
triton
|
|
||||||
colorama
|
colorama
|
||||||
git+https://github.com/huggingface/transformers.git
|
git+https://github.com/huggingface/transformers.git
|
||||||
git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
|
git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue