From 4261bd8070d1dfd5c90d65517772c849af2dd79d Mon Sep 17 00:00:00 2001
From: John Smith
Date: Wed, 12 Apr 2023 12:59:44 +0800
Subject: [PATCH] add xformers support

---
 Finetune4bConfig.py                          |   4 +-
 arg_parser.py                                |   2 +
 finetune.py                                  |   3 +
 monkeypatch/__init__.py                      |   0
 monkeypatch/llama_attn_hijack_xformers.py    | 101 +++++++++++++++++++
 monkeypatch/llama_flash_attn_monkey_patch.py |   7 +-
 requirements.txt                             |   3 +-
 7 files changed, 115 insertions(+), 5 deletions(-)
 create mode 100644 monkeypatch/__init__.py
 create mode 100644 monkeypatch/llama_attn_hijack_xformers.py

diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py
index 2169f46..f3ab3e9 100644
--- a/Finetune4bConfig.py
+++ b/Finetune4bConfig.py
@@ -15,7 +15,7 @@ class Finetune4bConfig:
         warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
         checkpoint: bool, skip: bool, verbose: bool,
         txt_row_thd: int, use_eos_token: bool, groupsize: int, v1: bool,
-        local_rank: int, flash_attention: bool, backend: str
+        local_rank: int, flash_attention: bool, xformers: bool, backend: str
     ):
         """
         Args:
@@ -50,6 +50,7 @@ class Finetune4bConfig:
             v1 (bool): v1 model flag
             local_rank (int): local rank if using torch.distributed.launch
             flash_attention (bool): Enables flash attention
+            xformers (bool): use xformers or not
         """
         self.dataset = dataset
         self.ds_type = ds_type
@@ -88,6 +89,7 @@ class Finetune4bConfig:
         self.groupsize = groupsize
         self.v1 = v1
         self.flash_attention = flash_attention
+        self.xformers = xformers
         self.backend = backend
 
 
diff --git a/arg_parser.py b/arg_parser.py
index 02c56d2..8ab0fe1 100644
--- a/arg_parser.py
+++ b/arg_parser.py
@@ -70,6 +70,7 @@ def parse_commandline():
 
     # Flash Attention
     parser_training.add_argument("--flash_attention", action="store_true", help="enables flash attention, can improve performance and reduce VRAM use")
+    parser_training.add_argument("--xformers", action="store_true", help="enables xformers memory efficient attention, can improve performance and reduce VRAM use")
 
     # Train Backend
     parser_training.add_argument("--backend", type=str, default='cuda', help="Backend to use. Triton or Cuda.")
@@ -111,5 +112,6 @@ def get_config() -> Finetune4bConfig:
         v1=args["v1"],
         local_rank=args["local_rank"],
         flash_attention=args["flash_attention"],
+        xformers=args["xformers"],
         backend=args["backend"],
     )
diff --git a/finetune.py b/finetune.py
index ff29538..523dbad 100644
--- a/finetune.py
+++ b/finetune.py
@@ -23,6 +23,9 @@ ft_config = get_config()
 if ft_config.flash_attention:
     from monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
     replace_llama_attn_with_flash_attn()
+elif ft_config.xformers:
+    from monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
+    hijack_llama_attention()
 
 import autograd_4bit
 if ft_config.backend.lower() == 'triton':
diff --git a/monkeypatch/__init__.py b/monkeypatch/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/monkeypatch/llama_attn_hijack_xformers.py b/monkeypatch/llama_attn_hijack_xformers.py
new file mode 100644
index 0000000..4c4feaf
--- /dev/null
+++ b/monkeypatch/llama_attn_hijack_xformers.py
@@ -0,0 +1,101 @@
+'''
+Directly copied the code from https://github.com/oobabooga/text-generation-webui/pull/950/commits and made some adjustments
+'''
+import math
+import sys
+import torch
+import torch.nn as nn
+import transformers.models.llama.modeling_llama
+
+from typing import Optional
+from typing import Tuple
+
+try:
+    import xformers.ops
+except ImportError:
+    raise ImportError("Please install xformers to use this module")
+
+def hijack_llama_attention():
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+    print("Replaced attention with xformers_attention")
+
+def xformers_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    #We only apply xformers optimizations if we don't need to output the whole attention matrix
+    if not output_attentions:
+        dtype = query_states.dtype
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        #This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+        #We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+        else:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py
index 0c80227..aa5c877 100644
--- a/monkeypatch/llama_flash_attn_monkey_patch.py
+++ b/monkeypatch/llama_flash_attn_monkey_patch.py
@@ -8,8 +8,11 @@ from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmb
 
 from einops import rearrange
 
-from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+    from flash_attn.bert_padding import unpad_input, pad_input
+except ImportError:
+    raise ImportError("Please install flash_attn to use this module")
 
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
diff --git a/requirements.txt b/requirements.txt
index 9832253..26bc641 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,8 +4,7 @@ bitsandbytes
 datasets
 sentencepiece
 safetensors
-flash-attn
-triton
+einops
 colorama
 git+https://github.com/huggingface/transformers.git
 git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit
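
Usage note (a minimal sketch, not part of the diff above): the new path is enabled by passing --xformers to finetune.py, which applies the monkeypatch before the model is built. The snippet below mirrors that order outside the finetuning script; the checkpoint path is a placeholder, not something this patch ships.

    # Sketch: apply the xformers hijack first, as finetune.py does, then build the model.
    from monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
    from transformers import LlamaForCausalLM

    hijack_llama_attention()  # replaces LlamaAttention.forward with xformers_forward
    model = LlamaForCausalLM.from_pretrained("path/to/llama-7b-hf")  # illustrative path
    # From here, every LlamaAttention layer routes through xformers.ops.memory_efficient_attention
    # whenever output_attentions is False, and falls back to the stock matmul path otherwise.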