Return compatibility with tests

pull/570/head
Artem Chumachenko 1 month ago
parent 0f498814be
commit 2ca531623c

@ -122,10 +122,19 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
return self.embed_tokens
@property
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return nn.Identity()
@property
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
return self.layers
@property
def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return self.norm
class DistributedMixtralForCausalLM(
DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM

Loading…
Cancel
Save