Fix args of flash attention
This commit is contained in:
parent
30bf938d03
commit
7770e76c9c
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 
 import transformers
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from einops import rearrange
 
@@ -16,13 +16,14 @@ class LlamaAttention(nn.Module):
 
     def __init__(
         self,
-        hidden_size: int,
-        num_heads: int,
+        config: LlamaConfig,
     ):
         super().__init__()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        self.head_dim = hidden_size // num_heads
+        self.head_dim = self.hidden_size // num_heads
 
         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(
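For context, the change replaces the explicit hidden_size/num_heads constructor arguments with a single LlamaConfig and derives both values from it. Below is a minimal, self-contained sketch of the new __init__ logic; the class name PatchedAttentionInit and the config values are illustrative only (they are not part of the commit), and the LlamaConfig import mirrors the one the commit adds.

from torch import nn
from transformers.models.llama.modeling_llama import LlamaConfig


class PatchedAttentionInit(nn.Module):
    # Illustrative stand-in for the patched __init__: hidden_size and
    # num_heads are now read from the config rather than passed as
    # separate constructor arguments.
    def __init__(self, config: LlamaConfig):
        super().__init__()
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )


# Example usage (values are illustrative): the module is now built from a config.
cfg = LlamaConfig(hidden_size=4096, num_attention_heads=32)
attn = PatchedAttentionInit(cfg)
assert attn.head_dim == 128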