@@ -114,7 +114,7 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
-class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
+class WrappedLlamaBlock(LlamaDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,