Add position_ids to flash attention
parent 7770e76c9c
commit 2bf5d42f28
@@ -60,6 +60,7 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
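The new position_ids argument lets rotary embeddings be gathered per token instead of sliced with a scalar offset. For context only (not part of this commit), here is a minimal sketch of a position_ids-based apply_rotary_pos_emb, modeled on the Hugging Face transformers helper this signature lines up with; exact tensor shapes may differ between library versions:

# Sketch only, not from this commit: position_ids-based rotary application.
import torch

def rotate_half(x):
    # Rotates the second half of the hidden dims in front of the first half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos, sin: [1, 1, seq_len, dim]; position_ids: [bsz, q_len]
    cos = cos.squeeze(1).squeeze(0)        # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)        # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)   # [bsz, 1, q_len, dim]
    sin = sin[position_ids].unsqueeze(1)   # [bsz, 1, q_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed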
@@ -80,16 +81,15 @@ class LlamaAttention(nn.Module):
         # [bsz, nh, q_len, hd]

         kv_seq_len = key_states.shape[-2]
-        offset = 0
         if past_key_value is not None:
-            offset = past_key_value[0].shape[-2]
-            kv_seq_len += offset
+            kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states,
                                                         key_states,
                                                         cos,
                                                         sin,
-                                                        offset=offset)
+                                                        position_ids)
         # [bsz, nh, t, hd]
         assert not output_attentions, "output_attentions is not supported"
         assert not use_cache, "use_cache is not supported"
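With the offset bookkeeping removed, the caller is responsible for passing position_ids that continue after any cached prefix. A hedged sketch of how such ids can be built; the helper name and signature below are hypothetical, not from the patched repository:

# Sketch only: building position_ids when a KV cache is present, replacing
# the removed `offset = past_key_value[0].shape[-2]` logic.
import torch

def make_position_ids(batch_size, q_len, past_len, device):
    # New tokens occupy positions [past_len, past_len + q_len).
    positions = torch.arange(past_len, past_len + q_len,
                             dtype=torch.long, device=device)
    return positions.unsqueeze(0).expand(batch_size, q_len)  # [bsz, q_len]

# Example: the first forward pass uses past_len=0; a later decoding step with
# a 10-token cache and one new token would use past_len=10, q_len=1.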