@@ -50,9 +50,15 @@ class OptimizedLlamaAttention(LlamaAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         assert not output_attentions
-        assert position_ids is None
+        if position_ids is None:
+            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
+            position_ids = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            ).unsqueeze(0)
+
         bsz, q_len, _ = hidden_states.size()

         if self.config.pretraining_tp > 1:
@@ -84,9 +90,8 @@ class OptimizedLlamaAttention(LlamaAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[kv_seq_len - q_len :]
-        sin = sin[kv_seq_len - q_len :]
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -160,6 +165,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -190,6 +197,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
         )

         hidden_states = residual + hidden_states
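For reference, a minimal standalone sketch of the `position_ids` fallback introduced in the first hunk, assuming the tuple-style KV cache the patch indexes (`past_key_value[0]` with shape `(batch, num_heads, past_seq_len, head_dim)`); the helper name below is ours for illustration and not part of the patch:

```python
import torch

def default_position_ids(hidden_states, past_key_value=None):
    # Hypothetical helper mirroring the fallback added in the patch.
    # Number of tokens already in the cache; 0 on the first (prefill) call.
    past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
    # Positions for the current chunk continue from the cached length.
    return torch.arange(
        past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
    ).unsqueeze(0)  # shape (1, q_len), broadcast over the batch

# Decoding one new token after 10 cached tokens yields position 10.
hidden_states = torch.randn(1, 1, 4096)
past = (torch.zeros(1, 32, 10, 128), torch.zeros(1, 32, 10, 128))
print(default_position_ids(hidden_states, past))  # tensor([[10]])
```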