|
|
|
@ -122,10 +122,19 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
|
|
|
|
|
def word_embeddings(self) -> nn.Embedding:
    """Return ``self.embed_tokens``.

    Exposed under this name for compatibility with RemoteGenerationMixin.
    """
    # NOTE(review): the neighbouring accessors are declared with @property; no decorator is
    # visible above this def in the excerpt -- confirm whether one belongs here.
    return self.embed_tokens
|
|
|
|
|
|
|
|
|
|
@property
def word_embeddings_layernorm(self) -> nn.Module:
    """Return an ``nn.Identity`` module.

    Provided for compatibility with RemoteGenerationMixin in tests. A new ``nn.Identity``
    instance is constructed on every access; nothing is read from ``self``.
    """
    return nn.Identity()
|
|
|
|
|
|
|
|
|
|
@property
def h(self) -> RemoteSequential:
    """Return ``self.layers``.

    Exposed under the name ``h`` for compatibility with RemoteGenerationMixin.
    """
    return self.layers
|
|
|
|
|
|
|
|
|
|
@property
def ln_f(self) -> nn.Module:
    """Return ``self.norm``.

    Exposed under the name ``ln_f`` for compatibility with RemoteGenerationMixin in tests.
    """
    return self.norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedMixtralForCausalLM(
|
|
|
|
|
DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
|
|
|
|
|