diff --git a/monkeypatch/llama_flash_attn_monkey_patch.py b/monkeypatch/llama_flash_attn_monkey_patch.py
index dd2589d..0c80227 100644
--- a/monkeypatch/llama_flash_attn_monkey_patch.py
+++ b/monkeypatch/llama_flash_attn_monkey_patch.py
@@ -60,6 +60,7 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
@@ -80,16 +81,15 @@ class LlamaAttention(nn.Module):
         # [bsz, nh, q_len, hd]
 
         kv_seq_len = key_states.shape[-2]
-        offset = 0
         if past_key_value is not None:
-            offset = past_key_value[0].shape[-2]
-            kv_seq_len += offset
+            kv_seq_len += past_key_value[0].shape[-2]
+
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states,
                                                         key_states,
                                                         cos,
                                                         sin,
-                                                        offset=offset)
+                                                        position_ids)
         # [bsz, nh, t, hd]
         assert not output_attentions, "output_attentions is not supported"
         assert not use_cache, "use_cache is not supported"
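
The second hunk drops the scalar `offset` bookkeeping and instead passes `position_ids` directly to `apply_rotary_pos_emb`, which indexes the rotary cos/sin cache per token rather than slicing it by a single offset (this assumes the `position_ids`-based signature used by newer transformers LLaMA code). A minimal sketch of that indexing scheme is below; the helper name `apply_rotary_pos_emb_by_ids` is hypothetical, and the `rotate_half` convention is assumed to be the usual LLaMA one.

import torch


def rotate_half(x):
    # Rotate the last dimension: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_by_ids(q, k, cos, sin, position_ids):
    # cos/sin come from the rotary cache with shape [1, 1, kv_seq_len, head_dim].
    # Indexing them by each token's absolute position replaces the old scalar
    # offset, so every example in the batch can use its own positions.
    cos = cos.squeeze(1).squeeze(0)       # [kv_seq_len, head_dim]
    sin = sin.squeeze(1).squeeze(0)       # [kv_seq_len, head_dim]
    cos = cos[position_ids].unsqueeze(1)  # [bsz, 1, q_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)  # [bsz, 1, q_len, head_dim]
    # q, k: [bsz, nh, q_len, head_dim]; cos/sin broadcast over the head dim.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed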