You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
petals/src/petals/utils/auto_config.py

24 lines
862 B
Python

from typing import Type
from transformers import AutoConfig, PretrainedConfig
CONFIG_MAPPING = {}  # model_type -> distributed config class; populated with AutoDistributedConfig.register()


class AutoDistributedConfig:
    """Load the Petals (distributed) config class registered for a model's ``model_type``.

    Usage mirrors ``transformers.AutoConfig.from_pretrained``: the plain config is
    loaded first to discover ``model_type``, then loading is dispatched to the config
    class registered for that type via :meth:`register`.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
        """Return the registered distributed config for the requested model.

        All positional and keyword arguments are forwarded unchanged, first to
        ``AutoConfig.from_pretrained`` (to read ``model_type``) and then to the
        registered class's own ``from_pretrained``.

        :raises ValueError: if no config class is registered for the model's ``model_type``.
        """
        config = AutoConfig.from_pretrained(*args, **kwargs)
        dist_config_class = CONFIG_MAPPING.get(config.model_type)
        if dist_config_class is None:
            raise ValueError(f"Petals does not support model type {config.model_type}")
        # NOTE(review): the config is loaded a second time here instead of converting
        # `config`; presumably so the registered class's own from_pretrained() sees the
        # original arguments — confirm before "optimizing" this away.
        return dist_config_class.from_pretrained(*args, **kwargs)

    @staticmethod
    def register(config_class: Type[PretrainedConfig]) -> None:
        """Register ``config_class`` as the distributed config for its ``model_type``.

        :param config_class: a ``PretrainedConfig`` subclass with a non-empty ``model_type``.
        :raises TypeError: if ``config_class`` is not a subclass of ``PretrainedConfig``.
        :raises ValueError: if ``model_type`` is empty or already registered.
        """
        # Explicit raises rather than `assert`: asserts are stripped under `python -O`,
        # which would let a duplicate registration silently overwrite an existing entry.
        if not (isinstance(config_class, type) and issubclass(config_class, PretrainedConfig)):
            raise TypeError(f"Expected a subclass of PretrainedConfig, got {config_class!r}")

        model_type = config_class.model_type
        if not model_type:
            # An empty model_type could never match a loaded config in from_pretrained().
            raise ValueError(f"{config_class.__name__} must define a non-empty model_type")
        if model_type in CONFIG_MAPPING:
            raise ValueError(
                f"Model type {model_type!r} is already registered "
                f"(by {CONFIG_MAPPING[model_type].__name__})"
            )
        CONFIG_MAPPING[model_type] = config_class