diff --git a/src/petals/models/mixtral/block.py b/src/petals/models/mixtral/block.py
index c9fcd8d..b90a39b 100644
--- a/src/petals/models/mixtral/block.py
+++ b/src/petals/models/mixtral/block.py
@@ -12,10 +12,11 @@ from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, Mi
 
 class WrappedMixtralBlock(MixtralDecoderLayer):
     def __init__(self, config: MixtralConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+
         self._attn_implementation = config._attn_implementation
         self.sliding_window = config.sliding_window
         self.layer_idx = layer_idx
-        super().__init__(config, layer_idx)
 
     def forward(
         self,