diff --git a/src/petals/models/mixtral/model.py b/src/petals/models/mixtral/model.py
index 7e127ab..9840204 100644
--- a/src/petals/models/mixtral/model.py
+++ b/src/petals/models/mixtral/model.py
@@ -122,10 +122,19 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
     def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
         return self.embed_tokens
 
+    @property
+    def word_embeddings_layernorm(self) -> nn.Module:  # No-op; for compatibility with RemoteGenerationMixin in tests
+        return nn.Identity()
+
     @property
     def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
         return self.layers
 
+    @property
+    def ln_f(self) -> nn.Module:  # Alias of self.norm; for compatibility with RemoteGenerationMixin in tests
+        return self.norm
+
+
 class DistributedMixtralForCausalLM(
     DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM