Add position_ids to flash attention

Authored by yamashi on 2023-04-06 17:46:15 +02:00; committed by GitHub
parent 7770e76c9c
commit 2bf5d42f28
1 changed file with 4 additions and 4 deletions


@@ -60,6 +60,7 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
@@ -80,16 +81,15 @@ class LlamaAttention(nn.Module):
         # [bsz, nh, q_len, hd]
         kv_seq_len = key_states.shape[-2]
-        offset = 0
         if past_key_value is not None:
-            offset = past_key_value[0].shape[-2]
-            kv_seq_len += offset
+            kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states,
                                                         key_states,
                                                         cos,
                                                         sin,
-                                                        offset=offset)
+                                                        position_ids)
         # [bsz, nh, t, hd]
         assert not output_attentions, "output_attentions is not supported"
         assert not use_cache, "use_cache is not supported"
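
For context on why the diff looks this way: newer transformers LLaMA code passes per-token position_ids into apply_rotary_pos_emb instead of a single scalar offset, so the flash-attention forward has to accept position_ids and pass them through. Below is a minimal sketch of what a position_ids-driven rotary helper looks like, written to mirror the transformers-style implementation of that period; it is illustrative only, and the shapes noted in the comments are assumptions rather than part of this commit.

# Sketch of position_ids-based rotary embedding application (assumed shapes:
# q, k are [bsz, num_heads, q_len, head_dim]; position_ids is [bsz, q_len]).
import torch

def rotate_half(x):
    # Swap the two halves of the last dimension, negating the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin come from the rotary cache with shape [1, 1, seq_len, head_dim];
    # gather the rows that correspond to each token's absolute position.
    cos = cos.squeeze(1).squeeze(0)        # [seq_len, head_dim]
    sin = sin.squeeze(1).squeeze(0)        # [seq_len, head_dim]
    cos = cos[position_ids].unsqueeze(1)   # [bsz, 1, q_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)   # [bsz, 1, q_len, head_dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# The old offset-based call could only express "every new token starts at
# position offset"; position_ids lets the caller state each token's absolute
# position explicitly, e.g. when decoding one token after 10 cached tokens:
# position_ids = torch.tensor([[10]])  # shape [bsz, q_len]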