Fix buffer registration

pull/500/head
Max Ryabinin 9 months ago
parent ce401f1163
commit 2c1452de5c

@@ -16,7 +16,6 @@ from transformers.models.falcon.modeling_falcon import (
FalconDecoderLayer,
FalconLinear,
FalconMLP,
FalconRotaryEmbedding,
LayerNorm,
dropout_add,
rotate_half,
@@ -30,9 +29,14 @@ def apply_rotary(query, key, cos, sin):
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
class OptimizedFalconRotaryEmbedding(FalconRotaryEmbedding):
class OptimizedFalconRotaryEmbedding(nn.Module):
def __init__(self, head_dim: int, base=10000):
super().__init__(head_dim, base)
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.head_dim = head_dim
self.seq_len_cached = -1
self.cuda_graph = None
self.input_surface = None
self.static_outputs = None

Loading…
Cancel
Save