From 7770e76c9c143eaab04657d082d3f3406ebede71 Mon Sep 17 00:00:00 2001
From: yamashi
Date: Thu, 6 Apr 2023 17:32:01 +0200
Subject: [PATCH] Fix args of flash attention

---
 monkeypatch/llama_flash_attn_monkey_patch.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py
index 1d48bb5..dd2589d 100644
--- a/monkeypatch/llama_flash_attn_monkey_patch.py
+++ b/monkeypatch/llama_flash_attn_monkey_patch.py
@@ -4,7 +4,7 @@ import torch
 from torch import nn

 import transformers
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb

 from einops import rearrange

@@ -16,13 +16,14 @@ class LlamaAttention(nn.Module):

     def __init__(
         self,
-        hidden_size: int,
-        num_heads: int,
+        config: LlamaConfig,
     ):
         super().__init__()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        self.head_dim = hidden_size // num_heads
+        self.head_dim = self.hidden_size // num_heads
         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(