Fix args of flash attention
commit 7770e76c9c
parent 30bf938d03
@@ -4,7 +4,7 @@ import torch
 from torch import nn

 import transformers
-from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb

 from einops import rearrange

@@ -16,13 +16,14 @@ class LlamaAttention(nn.Module):

     def __init__(
         self,
-        hidden_size: int,
-        num_heads: int,
+        config: LlamaConfig,
     ):
         super().__init__()
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        self.head_dim = hidden_size // num_heads
+        self.head_dim = self.hidden_size // num_heads

         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(
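For reference, a minimal sketch of the constructor after this change. It only covers what the hunks above show: the flash-attention forward pass, projection layers, and rotary embedding setup from the real file are omitted, the ValueError message is paraphrased (the diff truncates it), and the example config values at the bottom are illustrative, not part of the commit.

    # Sketch of the patched __init__; everything outside the diff above is omitted.
    from torch import nn
    from transformers.models.llama.modeling_llama import LlamaConfig


    class LlamaAttention(nn.Module):
        def __init__(
            self,
            config: LlamaConfig,
        ):
            super().__init__()
            # Sizes now come from the config object instead of separate arguments.
            hidden_size = config.hidden_size
            num_heads = config.num_attention_heads
            self.hidden_size = hidden_size
            self.num_heads = num_heads
            self.head_dim = self.hidden_size // num_heads

            if (self.head_dim * num_heads) != self.hidden_size:
                # Paraphrased message; the original file's text is cut off in the diff.
                raise ValueError(
                    f"hidden_size ({self.hidden_size}) must be divisible by "
                    f"num_heads ({num_heads})"
                )


    # Illustrative usage: the layer is built from a single LlamaConfig, matching how
    # transformers constructs its own attention modules, rather than from separate
    # hidden_size / num_heads arguments as before.
    config = LlamaConfig(hidden_size=4096, num_attention_heads=32)
    attn = LlamaAttention(config)
    assert attn.head_dim == 128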