diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index b7616a1..46898b9 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -114,7 +114,7 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
-class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
+class WrappedLlamaBlock(LlamaDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,