Add position_ids to flash attention
This commit is contained in:
parent
7770e76c9c
commit
2bf5d42f28
|
|
@ -60,6 +60,7 @@ class LlamaAttention(nn.Module):
|
|||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
|
|
@ -80,16 +81,15 @@ class LlamaAttention(nn.Module):
|
|||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
offset = 0
|
||||
if past_key_value is not None:
|
||||
offset = past_key_value[0].shape[-2]
|
||||
kv_seq_len += offset
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states,
|
||||
key_states,
|
||||
cos,
|
||||
sin,
|
||||
offset=offset)
|
||||
position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
assert not output_attentions, "output_attentions is not supported"
|
||||
assert not use_cache, "use_cache is not supported"
|
||||
|
|
|
|||
Loading…
Reference in New Issue