|
|
|
@ -85,8 +85,8 @@ class OptimizedLlamaAttention(LlamaAttention):
|
|
|
|
|
if past_key_value is not None:
|
|
|
|
|
kv_seq_len += past_key_value[0].shape[-2]
|
|
|
|
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
|
|
|
cos = cos[:, :, kv_seq_len - q_len :]
|
|
|
|
|
sin = sin[:, :, kv_seq_len - q_len :]
|
|
|
|
|
cos = cos[kv_seq_len - q_len :]
|
|
|
|
|
sin = sin[kv_seq_len - q_len :]
|
|
|
|
|
|
|
|
|
|
if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
|
|
|
|
|
query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
|
|
|
|
|