Add AutoDistributed{Model, ModelForCausalLM, ModelForSequenceClassification} (#329)

This PR adds `petals.AutoDistributed{Model, ModelForCausalLM, ModelForSequenceClassification}` classes, similar to their `transformers.Auto{Model, ModelForCausalLM, ModelForSequenceClassification}` counterparts.
pull/330/head
Alexander Borzunov 11 months ago committed by GitHub
parent cb3f018f9f
commit 7a37513f77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,3 +5,11 @@ from petals.models.bloom.model import (
DistributedBloomForSequenceClassification,
DistributedBloomModel,
)
from petals.utils.auto_config import register_model_classes
register_model_classes(
config=DistributedBloomConfig,
model=DistributedBloomModel,
model_for_causal_lm=DistributedBloomForCausalLM,
model_for_sequence_classification=DistributedBloomForSequenceClassification,
)

@ -30,6 +30,3 @@ class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LM
dht_prefix = str(model_name_or_path) + "-petals"
logger.info(f"Using DHT prefix: {dht_prefix}")
return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
AutoDistributedConfig.register(DistributedBloomConfig)

@ -5,3 +5,11 @@ from petals.models.llama.model import (
DistributedLlamaForSequenceClassification,
DistributedLlamaModel,
)
from petals.utils.auto_config import register_model_classes
register_model_classes(
config=DistributedLlamaConfig,
model=DistributedLlamaModel,
model_for_causal_lm=DistributedLlamaForCausalLM,
model_for_sequence_classification=DistributedLlamaForSequenceClassification,
)

@ -30,6 +30,3 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
logger.info(f"Using DHT prefix: {dht_prefix}")
return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
AutoDistributedConfig.register(DistributedLlamaConfig)

@ -1 +1,6 @@
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.auto_config import (
AutoDistributedConfig,
AutoDistributedModel,
AutoDistributedModelForCausalLM,
AutoDistributedModelForSequenceClassification,
)

@ -1,23 +1,54 @@
from typing import Type
from dataclasses import dataclass
from typing import Optional, Type
from transformers import AutoConfig, PretrainedConfig
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
CONFIG_MAPPING = {} # Populated with AutoDistributedConfig.register()
@dataclass
class _ModelClasses:
config: Type[PretrainedConfig]
model: Optional[Type[PreTrainedModel]] = None
model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
_CLASS_MAPPING = {} # Populated by petals.models.* subpackages with register_model_classes()
def register_model_classes(*, config: Type[PretrainedConfig], **kwargs):
assert issubclass(config, PretrainedConfig)
assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered"
_CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs)
class _AutoDistributedBase:
_mapping_field = None # Should be defined in child classes
class AutoDistributedConfig:
@classmethod
def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
config = AutoConfig.from_pretrained(*args, **kwargs)
if config.model_type not in CONFIG_MAPPING:
if config.model_type not in _CLASS_MAPPING:
raise ValueError(f"Petals does not support model type {config.model_type}")
dist_config_class = CONFIG_MAPPING[config.model_type]
return dist_config_class.from_pretrained(*args, **kwargs)
proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field)
if proper_cls is None:
raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
return proper_cls.from_pretrained(*args, **kwargs)
class AutoDistributedConfig(_AutoDistributedBase):
_mapping_field = "config"
class AutoDistributedModel(_AutoDistributedBase):
_mapping_field = "model"
class AutoDistributedModelForCausalLM(_AutoDistributedBase):
_mapping_field = "model_for_causal_lm"
@staticmethod
def register(config_class: Type[PretrainedConfig]) -> None:
assert issubclass(config_class, PretrainedConfig)
assert config_class.model_type not in CONFIG_MAPPING
CONFIG_MAPPING[config_class.model_type] = config_class
class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
_mapping_field = "model_for_sequence_classification"

Loading…
Cancel
Save