From 3ea18575c7af99356b52f884dc9dc6eec8c50852 Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 13:49:12 +0200 Subject: [PATCH 1/8] Use flash attention monkeypatch --- finetune.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/finetune.py b/finetune.py index f374e2b..998b1d6 100644 --- a/finetune.py +++ b/finetune.py @@ -16,6 +16,9 @@ } ] """ +from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn + +replace_llama_attn_with_flash_attn() import sys From 7b18b39dd8f29545f3e728a7777a100cfe64bb0a Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 13:49:36 +0200 Subject: [PATCH 2/8] Create llama_flash_attn_monkey_patch.py --- monkeypatch/llama_flash_attn_monkey_patch.py | 143 +++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 monkeypatch/llama_flash_attn_monkey_patch.py diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000..1d48bb5 --- /dev/null +++ b/monkeypatch/llama_flash_attn_monkey_patch.py @@ -0,0 +1,143 @@ +from typing import List, Optional, Tuple + +import torch +from torch import nn + +import transformers +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb + +from einops import rearrange + +from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads}).") + self.q_proj = nn.Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.k_proj = nn.Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.v_proj = nn.Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.o_proj = nn.Linear( + num_heads * self.head_dim, + hidden_size, + bias=False, + ) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel + + attention_mask: [bsz, q_len] + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] + + kv_seq_len = key_states.shape[-2] + offset = 0 + if past_key_value is not None: + offset = past_key_value[0].shape[-2] + kv_seq_len += offset + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, + key_states, + cos, + sin, + offset=offset) + # [bsz, nh, t, hd] + assert not output_attentions, "output_attentions is not supported" + assert not use_cache, "use_cache is not supported" + assert past_key_value is None, "past_key_value is not supported" + + # Flash attention codes from + # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + + # transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + # We have disabled _prepare_decoder_attention_mask in LlamaModel + # the attention_mask should be the same as the key_padding_mask + key_padding_mask = attention_mask + + + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = q_len + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, + device=qkv.device) + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, + softmax_scale=None, causal=True + ) + output = rearrange(output, '(b s) ... -> b s ...', b=bsz) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_attn_unpadded_qkvpacked_func( + x_unpad, cu_q_lens, max_s, 0.0, + softmax_scale=None, causal=True + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, bsz, q_len), + 'b s (h d) -> b s h d', h=nheads) + return self.o_proj(rearrange(output, + 'b s h d -> b s (h d)')), None, None + + +# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask +def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + inputs_embeds, past_key_values_length): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention From 30bf938d03ada62a680820529b0db025484a67c1 Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 13:50:25 +0200 Subject: [PATCH 3/8] Update requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 605c0d1..2db1f01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ bitsandbytes datasets sentencepiece safetensors +flash-attn git+https://github.com/huggingface/transformers.git git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit git+https://github.com/sterlind/peft.git From 7770e76c9c143eaab04657d082d3f3406ebede71 Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 17:32:01 +0200 Subject: [PATCH 4/8] Fix args of flash attention --- monkeypatch/llama_flash_attn_monkey_patch.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py index 1d48bb5..dd2589d 100644 --- a/monkeypatch/llama_flash_attn_monkey_patch.py +++ b/monkeypatch/llama_flash_attn_monkey_patch.py @@ -4,7 +4,7 @@ import torch from torch import nn import transformers -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb +from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb from einops import rearrange @@ -16,13 +16,14 @@ class LlamaAttention(nn.Module): def __init__( self, - hidden_size: int, - num_heads: int, + config: LlamaConfig, ): super().__init__() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads self.hidden_size = hidden_size self.num_heads = num_heads - self.head_dim = hidden_size // num_heads + self.head_dim = self.hidden_size // num_heads if (self.head_dim * num_heads) != self.hidden_size: raise ValueError( From 2bf5d42f287127427e09951c78c4b023203b7b22 Mon Sep 17 00:00:00 2001 From: yamashi Date: Thu, 6 Apr 2023 17:46:15 +0200 Subject: [PATCH 5/8] Add position_ids to flash attention --- monkeypatch/llama_flash_attn_monkey_patch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py index dd2589d..0c80227 100644 --- a/monkeypatch/llama_flash_attn_monkey_patch.py +++ b/monkeypatch/llama_flash_attn_monkey_patch.py @@ -60,6 +60,7 @@ class LlamaAttention(nn.Module): hidden_states: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], @@ -80,16 +81,15 @@ class LlamaAttention(nn.Module): # [bsz, nh, q_len, hd] kv_seq_len = key_states.shape[-2] - offset = 0 if past_key_value is not None: - offset = past_key_value[0].shape[-2] - kv_seq_len += offset + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, - offset=offset) + position_ids) # [bsz, nh, t, hd] assert not output_attentions, "output_attentions is not supported" assert not use_cache, "use_cache is not supported" From 778035152d7ba63a2f980eea98142d5e6eec7fe3 Mon Sep 17 00:00:00 2001 From: yamashi Date: Fri, 7 Apr 2023 00:42:34 +0200 Subject: [PATCH 6/8] Update arg_parser.py --- arg_parser.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arg_parser.py b/arg_parser.py index a985380..be18154 100644 --- a/arg_parser.py +++ b/arg_parser.py @@ -66,6 +66,8 @@ def parse_commandline(): # Multi GPU Support parser_training.add_argument("--local_rank", type=int, default=0, help="local rank if using torch.distributed.launch") + + parser_training.add_argument("--flash_attention", help="enables flash attention, can improve performance and reduce VRAM use") return vars(parser.parse_args()) @@ -102,4 +104,5 @@ def get_config() -> Finetune4bConfig: use_eos_token=args["use_eos_token"]!=0, groupsize=args["groupsize"], local_rank=args["local_rank"], + flash_attention=args["flash_attention"], ) From 95cd390d2528dcaaa618b6c2f4765b136e4e313b Mon Sep 17 00:00:00 2001 From: yamashi Date: Fri, 7 Apr 2023 00:43:15 +0200 Subject: [PATCH 7/8] Update Finetune4bConfig.py --- Finetune4bConfig.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py index a8a33bf..09c535c 100644 --- a/Finetune4bConfig.py +++ b/Finetune4bConfig.py @@ -15,7 +15,7 @@ class Finetune4bConfig: warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int, checkpoint: bool, skip: bool, verbose: bool, txt_row_thd: int, use_eos_token: bool, groupsize: int, - local_rank: int, + local_rank: int, flash_attention: bool ): """ Args: @@ -48,6 +48,7 @@ class Finetune4bConfig: use_eos_token (bool): Use Eos token instead of padding with 0 groupsize (int): Group size of V2 model, use -1 to load V1 model local_rank (int): local rank if using torch.distributed.launch + flash_attention (bool): Enables flash attention """ self.dataset = dataset self.ds_type = ds_type @@ -84,6 +85,7 @@ class Finetune4bConfig: if self.ddp: self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size self.groupsize = groupsize + self.flash_attention = flash_attention def __str__(self) -> str: From c5aa7fb6951ac573089daf7a1aa990bc68270f1f Mon Sep 17 00:00:00 2001 From: yamashi Date: Fri, 7 Apr 2023 00:43:36 +0200 Subject: [PATCH 8/8] Update finetune.py --- finetune.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/finetune.py b/finetune.py index 998b1d6..eea0e66 100644 --- a/finetune.py +++ b/finetune.py @@ -16,9 +16,13 @@ } ] """ -from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn +# Early load config to replace attn if needed +from arg_parser import get_config +ft_config = get_config() -replace_llama_attn_with_flash_attn() +if ft_config.flash_attention: + from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn + replace_llama_attn_with_flash_attn() import sys @@ -32,10 +36,9 @@ from autograd_4bit import load_llama_model_4bit_low_ram from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, PeftModel # ! Config -from arg_parser import get_config import train_data -ft_config = get_config() + # * Show loaded parameters if ft_config.local_rank == 0: