Fix rotary embeddings

pull/500/head
Max Ryabinin 9 months ago
parent 841a0d5262
commit d56f57acd2

@@ -43,6 +43,13 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
         self.input_surface = None
         self.static_outputs = None
 
+        self.cos_sin(
+            seq_len=INFERENCE_MAX_LENGTH,
+            past_key_values_length=0,
+            device=self.inv_freq.device,
+            dtype=torch.get_default_dtype(),
+        )
+
     def _optimized_apply_rotary(self, query, key, cos, sin):
         if self.cuda_graph is None:
             self.cuda_graph = torch.cuda.CUDAGraph()
@@ -80,11 +87,11 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
                 emb = emb.float()
 
             self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype))
-            self.register_buffer("sin_cached", emb.cos()[None, :, :].type(dtype))
+            self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype))
 
         return (
-            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
-            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
+            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
+            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
        )
 
     def forward(self, query, key, past_key_values_length=0):
