@@ -16,7 +16,6 @@ from transformers.models.falcon.modeling_falcon import (
     FalconDecoderLayer,
     FalconLinear,
     FalconMLP,
-    FalconRotaryEmbedding,
     LayerNorm,
     dropout_add,
     rotate_half,
@@ -30,9 +29,14 @@ def apply_rotary(query, key, cos, sin):
     return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
 
 
-class OptimizedFalconRotaryEmbedding(FalconRotaryEmbedding):
+class OptimizedFalconRotaryEmbedding(nn.Module):
     def __init__(self, head_dim: int, base=10000):
-        super().__init__(head_dim, base)
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.head_dim = head_dim
+        self.seq_len_cached = -1
+
         self.cuda_graph = None
         self.input_surface = None
         self.static_outputs = None
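
For context, here is a minimal sketch (not part of the diff) of how a forward pass typically consumes the state initialized above: `inv_freq` is expanded into per-position cos/sin tables, rebuilt only when the sequence length exceeds `seq_len_cached`. The method name, signature, and cache attributes below are illustrative assumptions; the actual forward is outside the hunks shown, and the `cuda_graph`/`input_surface`/`static_outputs` fields suggest these tables are later captured into a CUDA graph.

import torch

class _RotaryCacheSketch:
    # Hypothetical helper, assuming self.inv_freq and self.seq_len_cached
    # are set up as in the __init__ shown in the diff.
    def cos_sin(self, seq_len: int, device, dtype):
        if seq_len > self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq)    # (seq_len, head_dim // 2)
            emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, head_dim)
            self.cos_cached = emb.cos().to(dtype)    # reused until seq_len grows
            self.sin_cached = emb.sin().to(dtype)
        return self.cos_cached, self.sin_cached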