From 2ca531623ce4233d377205c6d0f3601939db0f32 Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Mon, 8 Apr 2024 19:40:03 +0200 Subject: [PATCH] Return compatibility with tests --- src/petals/models/mixtral/model.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/petals/models/mixtral/model.py b/src/petals/models/mixtral/model.py index 7e127ab..9840204 100644 --- a/src/petals/models/mixtral/model.py +++ b/src/petals/models/mixtral/model.py @@ -122,10 +122,19 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin return self.embed_tokens + @property + def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests + return nn.Identity() + @property def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin return self.layers + @property + def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests + return self.norm + + class DistributedMixtralForCausalLM( DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM