You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
24 lines
862 B
Python
24 lines
862 B
Python
from typing import Type
|
|
|
|
from transformers import AutoConfig, PretrainedConfig
|
|
|
|
# Registry of Petals config classes, keyed by each class's `model_type` string.
# Entries are added by AutoDistributedConfig.register() and looked up by
# AutoDistributedConfig.from_pretrained().
CONFIG_MAPPING: "dict[str, Type[PretrainedConfig]]" = {}
|
|
|
|
|
|
class AutoDistributedConfig:
    """Load a model config using the Petals config class registered for its model type.

    Config classes are added to the module-level ``CONFIG_MAPPING`` registry with
    :meth:`register`, keyed by their ``model_type`` attribute.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
        """Load a config with the registered Petals config class for its ``model_type``.

        All positional and keyword arguments are forwarded unchanged: first to
        ``transformers.AutoConfig.from_pretrained`` (only to discover ``model_type``),
        then to the registered class's own ``from_pretrained``, whose result is returned.

        Raises:
            ValueError: if no config class is registered for the detected ``model_type``.
        """
        config = AutoConfig.from_pretrained(*args, **kwargs)
        # With return_unused_kwargs=True, AutoConfig returns (config, unused_kwargs);
        # only the config object is needed here to read `model_type`.
        if isinstance(config, tuple):
            config = config[0]

        dist_config_class = CONFIG_MAPPING.get(config.model_type)
        if dist_config_class is None:
            raise ValueError(f"Petals does not support model type {config.model_type}")

        # Re-load from the same arguments so the returned object is an instance of
        # the registered class rather than the generic AutoConfig result.
        return dist_config_class.from_pretrained(*args, **kwargs)

    @staticmethod
    def register(config_class: Type[PretrainedConfig]) -> Type[PretrainedConfig]:
        """Register ``config_class`` in ``CONFIG_MAPPING`` under its ``model_type``.

        Returns ``config_class`` unchanged, so this may also be used as a class decorator.

        Raises:
            TypeError: if ``config_class`` is not a subclass of ``PretrainedConfig``.
            ValueError: if ``config_class.model_type`` is empty or already registered.
        """
        # Explicit raises rather than `assert`: assertions are stripped under `python -O`,
        # which would let an invalid or duplicate registration silently overwrite an entry.
        if not (isinstance(config_class, type) and issubclass(config_class, PretrainedConfig)):
            raise TypeError(f"{config_class!r} is not a subclass of PretrainedConfig")

        model_type = config_class.model_type
        if not model_type:
            raise ValueError(f"{config_class.__name__} does not define a model_type")
        if model_type in CONFIG_MAPPING:
            existing = CONFIG_MAPPING[model_type]
            raise ValueError(
                f"Model type {model_type!r} is already registered to {existing.__name__}"
            )

        CONFIG_MAPPING[model_type] = config_class
        return config_class
|