@ -50,9 +50,15 @@ class OptimizedLlamaAttention(LlamaAttention):
past_key_value : Optional [ Tuple [ torch . Tensor ] ] = None ,
output_attentions : bool = False ,
use_cache : bool = False ,
cache_position : Optional [ torch . LongTensor ] = None ,
) - > Tuple [ torch . Tensor , Optional [ torch . Tensor ] , Optional [ Tuple [ torch . Tensor ] ] ] :
assert not output_attentions
assert position_ids is None
if position_ids is None :
past_seen_tokens = past_key_value [ 0 ] . shape [ 2 ] if past_key_value is not None else 0
position_ids = torch . arange (
past_seen_tokens , past_seen_tokens + hidden_states . shape [ 1 ] , device = hidden_states . device
) . unsqueeze ( 0 )
bsz , q_len , _ = hidden_states . size ( )
if self . config . pretraining_tp > 1 :
@ -84,9 +90,8 @@ class OptimizedLlamaAttention(LlamaAttention):
kv_seq_len = key_states . shape [ - 2 ]
if past_key_value is not None :
kv_seq_len + = past_key_value [ 0 ] . shape [ - 2 ]
cos , sin = self . rotary_emb ( value_states , seq_len = kv_seq_len )
cos = cos [ kv_seq_len - q_len : ]
sin = sin [ kv_seq_len - q_len : ]
cos , sin = self . rotary_emb ( value_states , position_ids , seq_len = kv_seq_len )
cos , sin = cos . unsqueeze ( 1 ) , sin . unsqueeze ( 1 )
if q_len == 1 and torch . is_inference_mode_enabled ( ) and hidden_states . device . type == " cuda " :
query_states , key_states = self . _optimized_apply_rotary ( query_states , key_states , cos , sin )
@ -160,6 +165,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
past_key_value : Optional [ Tuple [ torch . Tensor ] ] = None ,
output_attentions : Optional [ bool ] = False ,
use_cache : Optional [ bool ] = False ,
cache_position : Optional [ torch . LongTensor ] = None ,
* * kwargs ,
) - > Tuple [ torch . FloatTensor , Optional [ Tuple [ torch . FloatTensor , torch . FloatTensor ] ] ] :
"""
Args :
@ -190,6 +197,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
past_key_value = past_key_value ,
output_attentions = output_attentions ,
use_cache = use_cache ,
cache_position = cache_position ,
* * kwargs ,
)
hidden_states = residual + hidden_states