diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py
index 1d48bb5..dd2589d 100644
--- a/monkeypatch/llama_flash_attn_monkey_patch.py
+++ b/monkeypatch/llama_flash_attn_monkey_patch.py
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 
 import transformers
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from einops import rearrange
 
@@ -16,13 +16,14 @@ class LlamaAttention(nn.Module):
 
     def __init__(
         self,
-        hidden_size: int,
-        num_heads: int,
+        config: LlamaConfig,
     ):
         super().__init__()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        self.head_dim = hidden_size // num_heads
+        self.head_dim = self.hidden_size // num_heads
 
         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(
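
Not part of the patch itself, just a minimal usage sketch of the new constructor signature: the patched `LlamaAttention` now takes a `LlamaConfig` instead of separate `hidden_size`/`num_heads` arguments and derives both from the config. The standalone instantiation below is an assumption for illustration (importing the monkeypatch module also requires `flash-attn` to be installed).

```python
# Sketch only (assumes the monkeypatch module is importable and flash-attn is installed):
# the patched class is constructed from a LlamaConfig rather than (hidden_size, num_heads).
from transformers.models.llama.modeling_llama import LlamaConfig

from monkeypatch.llama_flash_attn_monkey_patch import LlamaAttention

config = LlamaConfig(hidden_size=4096, num_attention_heads=32)
attn = LlamaAttention(config)

# head_dim is now derived from the config values set above.
assert attn.head_dim == config.hidden_size // config.num_attention_heads  # 128
```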