Llama rotary dims from 4 to 2

pull/545/head
younesbelkada 6 months ago
parent 2bdbf2da58
commit fa254cff02

@@ -85,8 +85,8 @@ class OptimizedLlamaAttention(LlamaAttention):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[:, :, kv_seq_len - q_len :]
-        sin = sin[:, :, kv_seq_len - q_len :]
+        cos = cos[kv_seq_len - q_len :]
+        sin = sin[kv_seq_len - q_len :]
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
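
Context for the change above: older transformers releases cached the rotary cos/sin tensors in 4-D form, [1, 1, seq_len, head_dim], so the last q_len positions were sliced along dim 2; newer releases cache them as 2-D [seq_len, head_dim] tensors, so the same slice moves to dim 0. The sketch below is illustrative only (the helper name slice_rotary_cache and the concrete shapes are assumptions, not code from this repository); it shows that both slicing conventions select the same values.

# Minimal sketch, not the repository's code: demonstrates the 4-D vs 2-D
# rotary cos/sin cache layouts and why the slice dimension changes.
import torch

def slice_rotary_cache(cos: torch.Tensor, q_len: int, kv_seq_len: int) -> torch.Tensor:
    """Keep only the cos (or sin) values for the q_len newest positions."""
    if cos.dim() == 4:
        # assumed old layout: [1, 1, seq_len, head_dim] -> slice along dim 2
        return cos[:, :, kv_seq_len - q_len :]
    # assumed new layout: [seq_len, head_dim] -> slice along dim 0
    return cos[kv_seq_len - q_len :]

head_dim, kv_seq_len, q_len = 64, 10, 1
cos_4d = torch.randn(1, 1, kv_seq_len, head_dim)  # old-style cache
cos_2d = cos_4d[0, 0]                             # new-style cache (same values)
assert torch.equal(
    slice_rotary_cache(cos_4d, q_len, kv_seq_len)[0, 0],
    slice_rotary_cache(cos_2d, q_len, kv_seq_len),
)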
