Add AutoDistributed{Model, ModelForCausalLM, ModelForSequenceClassification} (#329)
This PR adds `petals.AutoDistributed{Model, ModelForCausalLM, ModelForSequenceClassification}` classes, similar to their `transformers.Auto{Model, ModelForCausalLM, ModelForSequenceClassification}` counterparts.pull/330/head
parent
cb3f018f9f
commit
7a37513f77
@ -1 +1,6 @@
|
||||
from petals.utils.auto_config import AutoDistributedConfig
|
||||
from petals.utils.auto_config import (
|
||||
AutoDistributedConfig,
|
||||
AutoDistributedModel,
|
||||
AutoDistributedModelForCausalLM,
|
||||
AutoDistributedModelForSequenceClassification,
|
||||
)
|
||||
|
@ -1,23 +1,54 @@
|
||||
from typing import Type
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type
|
||||
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
||||
|
||||
CONFIG_MAPPING = {} # Populated with AutoDistributedConfig.register()
|
||||
|
||||
@dataclass
|
||||
class _ModelClasses:
|
||||
config: Type[PretrainedConfig]
|
||||
model: Optional[Type[PreTrainedModel]] = None
|
||||
model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
|
||||
model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
|
||||
|
||||
|
||||
_CLASS_MAPPING = {} # Populated by petals.models.* subpackages with register_model_classes()
|
||||
|
||||
|
||||
def register_model_classes(*, config: Type[PretrainedConfig], **kwargs):
|
||||
assert issubclass(config, PretrainedConfig)
|
||||
assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered"
|
||||
|
||||
_CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs)
|
||||
|
||||
|
||||
class _AutoDistributedBase:
|
||||
_mapping_field = None # Should be defined in child classes
|
||||
|
||||
class AutoDistributedConfig:
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
|
||||
config = AutoConfig.from_pretrained(*args, **kwargs)
|
||||
if config.model_type not in CONFIG_MAPPING:
|
||||
if config.model_type not in _CLASS_MAPPING:
|
||||
raise ValueError(f"Petals does not support model type {config.model_type}")
|
||||
|
||||
dist_config_class = CONFIG_MAPPING[config.model_type]
|
||||
return dist_config_class.from_pretrained(*args, **kwargs)
|
||||
proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field)
|
||||
if proper_cls is None:
|
||||
raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
|
||||
|
||||
return proper_cls.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
class AutoDistributedConfig(_AutoDistributedBase):
|
||||
_mapping_field = "config"
|
||||
|
||||
|
||||
class AutoDistributedModel(_AutoDistributedBase):
|
||||
_mapping_field = "model"
|
||||
|
||||
|
||||
class AutoDistributedModelForCausalLM(_AutoDistributedBase):
|
||||
_mapping_field = "model_for_causal_lm"
|
||||
|
||||
@staticmethod
|
||||
def register(config_class: Type[PretrainedConfig]) -> None:
|
||||
assert issubclass(config_class, PretrainedConfig)
|
||||
assert config_class.model_type not in CONFIG_MAPPING
|
||||
|
||||
CONFIG_MAPPING[config_class.model_type] = config_class
|
||||
class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
|
||||
_mapping_field = "model_for_sequence_classification"
|
||||
|
Loading…
Reference in New Issue