|
|
|
@ -43,6 +43,13 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
|
|
|
|
|
self.input_surface = None
|
|
|
|
|
self.static_outputs = None
|
|
|
|
|
|
|
|
|
|
self.cos_sin(
|
|
|
|
|
seq_len=INFERENCE_MAX_LENGTH,
|
|
|
|
|
past_key_values_length=0,
|
|
|
|
|
device=self.inv_freq.device,
|
|
|
|
|
dtype=torch.get_default_dtype(),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _optimized_apply_rotary(self, query, key, cos, sin):
|
|
|
|
|
if self.cuda_graph is None:
|
|
|
|
|
self.cuda_graph = torch.cuda.CUDAGraph()
|
|
|
|
@ -80,11 +87,11 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
|
|
|
|
|
emb = emb.float()
|
|
|
|
|
|
|
|
|
|
self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype))
|
|
|
|
|
self.register_buffer("sin_cached", emb.cos()[None, :, :].type(dtype))
|
|
|
|
|
self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype))
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
|
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
|
|
|
|
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
|
|
|
|
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
|
|
|
|
|
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def forward(self, query, key, past_key_values_length=0):
|
|
|
|
|