Fix args of flash attention

yamashi 2023-04-06 17:32:01 +02:00 committed by GitHub
parent 30bf938d03
commit 7770e76c9c
1 changed file with 5 additions and 4 deletions


@@ -4,7 +4,7 @@ import torch
 from torch import nn
 import transformers
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb
 from einops import rearrange
@@ -16,13 +16,14 @@ class LlamaAttention(nn.Module):
     def __init__(
         self,
-        hidden_size: int,
-        num_heads: int,
+        config: LlamaConfig,
     ):
         super().__init__()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        self.head_dim = hidden_size // num_heads
+        self.head_dim = self.hidden_size // num_heads
         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(
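
For reference, a minimal sketch of what the patched constructor amounts to and how it is now called, assuming transformers' LlamaConfig with its hidden_size and num_attention_heads fields; the class here mirrors only the constructor logic shown in the diff, not the full flash-attention forward pass.

from torch import nn
from transformers.models.llama.modeling_llama import LlamaConfig

class LlamaAttention(nn.Module):
    # After this change, the constructor derives its sizes from a LlamaConfig
    # instead of taking hidden_size / num_heads as separate arguments.
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )

# Usage: build the module from a config, matching the upstream signature.
config = LlamaConfig(hidden_size=4096, num_attention_heads=32)
attn = LlamaAttention(config)  # previously: LlamaAttention(hidden_size=4096, num_heads=32)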