|
|
|
@ -158,6 +158,11 @@ class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
|
|
|
|
|
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
|
|
|
|
|
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
|
|
|
|
|
|
|
|
|
|
config_class = DistributedMixtralConfig
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: DistributedMixtralConfig):
|
|
|
|
|
MixtralPreTrainedModel.__init__(self, config)
|
|
|
|
|
self.num_labels = config.num_labels
|
|
|
|
|