|
|
|
@ -135,9 +135,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
|
|
|
|
|
return self.norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedMixtralForCausalLM(
|
|
|
|
|
DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
|
|
|
|
|
):
|
|
|
|
|
class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
|
|
|
|
|
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
|
|
|
|
|
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
|
|
|
|
|
|
|
|
|
@ -159,9 +157,12 @@ class DistributedMixtralForCausalLM(
|
|
|
|
|
return self.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedMixtralForSequenceClassification(
|
|
|
|
|
DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
|
|
|
|
|
):
|
|
|
|
|
class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
|
|
|
|
|
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
|
|
|
|
|
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
|
|
|
|
|
|
|
|
|
|
config_class = DistributedMixtralConfig
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: DistributedMixtralConfig):
|
|
|
|
|
MixtralPreTrainedModel.__init__(self, config)
|
|
|
|
|
self.num_labels = config.num_labels
|
|
|
|
|