You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
36 lines
1.4 KiB
Python
36 lines
1.4 KiB
Python
import os
|
|
from typing import Optional, Union
|
|
|
|
from hivemind import get_logger
|
|
from transformers.models.bloom import BloomConfig
|
|
from transformers.models.bloom.modeling_bloom import BloomAttention
|
|
|
|
from petals.client.lm_head import LMHeadConfig
|
|
from petals.client.ptune import PTuneConfig
|
|
from petals.client.routing.sequence_manager import SequenceManagerConfig
|
|
from petals.models.bloom.block import WrappedBloomBlock
|
|
from petals.utils.auto_config import AutoDistributedConfig
|
|
from petals.utils.version import get_compatible_model_repo
|
|
|
|
# Module-level logger from hivemind's get_logger, named after this module; used by from_pretrained below.
logger = get_logger(__name__)
|
|
|
|
|
|
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
    """BLOOM configuration combined with the Petals client-side config mixins.

    Inherits from ``BloomConfig`` plus ``SequenceManagerConfig``, ``PTuneConfig`` and ``LMHeadConfig``,
    and pins three class attributes describing the model's transformer blocks. How those attributes
    are consumed is not visible in this module -- presumably by the distributed model / server code
    (confirm against their readers).
    """

    # Class used for each transformer block (WrappedBloomBlock from petals.models.bloom.block).
    block_class = WrappedBloomBlock
    # Attention implementation taken from transformers' BLOOM modeling code.
    attn_class = BloomAttention
    # Name prefix of the transformer blocks; presumably the "h" block list of HF BLOOM checkpoints
    # -- TODO confirm against the code that reads ``block_prefix``.
    block_prefix = "h"

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        """Delegate to the parent ``from_pretrained``, filling in a default ``dht_prefix`` when possible.

        Args:
            model_name_or_path: Repo name, local directory path, or ``None``; forwarded unchanged.
            *args: Extra positional arguments, forwarded unchanged to the parent ``from_pretrained``.
            dht_prefix: Explicit DHT prefix. When it is ``None`` and ``model_name_or_path`` is neither
                ``None`` nor an existing local directory, it defaults to
                ``str(model_name_or_path) + "-petals"`` (and the chosen prefix is logged). Otherwise it
                is forwarded exactly as given, which may be ``None``.
            **kwargs: Extra keyword arguments, forwarded unchanged to the parent ``from_pretrained``.

        Returns:
            Whatever the parent ``from_pretrained`` returns (presumably a config instance -- confirm).
        """
        # Anything that is neither None nor an existing directory is treated as a repo name.
        is_remote_name = model_name_or_path is not None and not os.path.isdir(model_name_or_path)

        if dht_prefix is None and is_remote_name:
            # The "-petals" suffix keeps derived prefixes compatible with Petals < 1.2.0.
            dht_prefix = f"{model_name_or_path!s}-petals"
            logger.info(f"Using DHT prefix: {dht_prefix}")

        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
|
|
|
|
|
# Import-time side effect: register DistributedBloomConfig with AutoDistributedConfig, presumably so the
# auto-config lookup can resolve BLOOM checkpoints to this class -- confirm in petals.utils.auto_config.
AutoDistributedConfig.register(DistributedBloomConfig)
|