From 063e94b4c8027e1e8d47061681007e9db292734f Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Mon, 14 Aug 2023 17:05:20 +0400
Subject: [PATCH] Move SequenceManagerConfig -> ClientConfig, petals.dht_utils -> petals.utils.dht (#463)

---
 src/petals/client/__init__.py                 |   4 +-
 src/petals/client/config.py                   |  31 +++++
 src/petals/client/inference_session.py        |   7 +-
 src/petals/client/remote_forward_backward.py  |  14 +-
 src/petals/client/remote_sequential.py        |   5 +-
 src/petals/client/routing/__init__.py         |   3 +-
 src/petals/client/routing/sequence_manager.py |  45 +++---
 src/petals/client/sequential_autograd.py      |   2 +-
 src/petals/data_structures.py                 |   4 +-
 src/petals/dht_utils.py                       | 129 +-----------------
 src/petals/models/bloom/config.py             |   4 +-
 src/petals/models/llama/config.py             |   4 +-
 src/petals/server/block_functions.py          |   3 +-
 src/petals/server/handler.py                  |   3 +-
 src/petals/server/memory_cache.py             |   3 +-
 src/petals/server/server.py                   |   2 +-
 src/petals/utils/__init__.py                  |   1 +
 src/petals/utils/dht.py                       | 124 +++++++++++++++++
 18 files changed, 208 insertions(+), 180 deletions(-)
 create mode 100644 src/petals/client/config.py
 create mode 100644 src/petals/utils/dht.py

diff --git a/src/petals/client/__init__.py b/src/petals/client/__init__.py
index f80c4b1..4b728e7 100644
--- a/src/petals/client/__init__.py
+++ b/src/petals/client/__init__.py
@@ -1,4 +1,4 @@
+from petals.client.config import ClientConfig
 from petals.client.inference_session import InferenceSession
 from petals.client.remote_sequential import RemoteSequential
-from petals.client.routing.sequence_manager import RemoteSequenceManager
-from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
+from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase
diff --git a/src/petals/client/config.py b/src/petals/client/config.py
new file mode 100644
index 0000000..e255024
--- /dev/null
+++ b/src/petals/client/config.py
@@ -0,0 +1,31 @@
+import dataclasses
+from typing import Optional, Sequence, Union
+
+from hivemind import PeerID
+
+from petals.constants import PUBLIC_INITIAL_PEERS
+
+
+@dataclasses.dataclass
+class ClientConfig:
+    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
+    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
+    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
+
+    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
+    allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+    blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, do not use these servers
+    use_server_to_server: bool = True  # Use direct server-to-server communication
+
+    connect_timeout: float = 5  # timeout for opening a connection
+    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
+    update_period: float = 60  # refresh DHT information once in this many seconds
+
+    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
+    max_backoff: float = 60  # limit maximal sleep time between retries to this value
+    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
+
+    max_pinged: int = 3  # max servers to ping from each sequence side, per update
+    ping_timeout: float = 2  # max time to wait for pings, per update
diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py
index e4d36f6..7f467b6 100644
--- a/src/petals/client/inference_session.py
+++ b/src/petals/client/inference_session.py
@@ -13,7 +13,8 @@ from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
-from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
+from petals.client.config import ClientConfig
+from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
@@ -31,7 +32,7 @@ class _ServerInferenceSession:
 
     def __init__(
         self,
-        config: SequenceManagerConfig,
+        config: ClientConfig,
         span: RemoteSpanInfo,
         uid: ModuleUID,
         rpc_info: RPCInfo,
@@ -58,7 +59,7 @@ class _ServerInferenceSession:
     @classmethod
     async def create(
         cls,
-        config: SequenceManagerConfig,
+        config: ClientConfig,
         p2p: P2P,
         span: RemoteSpanInfo,
         uid: ModuleUID,
diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py
index c7cb7c2..44abe26 100644
--- a/src/petals/client/remote_forward_backward.py
+++ b/src/petals/client/remote_forward_backward.py
@@ -14,12 +14,12 @@ from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
-from petals.client.routing.sequence_manager import SequenceManagerConfig
+from petals.client.config import ClientConfig
 from petals.data_structures import ModuleUID, RPCInfo
 
 
 async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
@@ -29,7 +29,7 @@ async def _forward_unary(
 
 
 async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
@@ -39,7 +39,7 @@ async def _backward_unary(
 
 
 async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
         runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
@@ -52,7 +52,7 @@ async def _forward_stream(
 
 
 async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
         runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
@@ -69,7 +69,7 @@ async def run_remote_forward(
     stub: StubBase,
     rpc_info: RPCInfo,
     *inputs: torch.Tensor,
-    config: SequenceManagerConfig,
+    config: ClientConfig,
     metadata: Optional[bytes] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, ...]:
@@ -115,7 +115,7 @@ async def run_remote_backward(
     stub: StubBase,
     rpc_info: RPCInfo,
     *inputs_and_grad_outputs: torch.Tensor,
-    config: SequenceManagerConfig,
+    config: ClientConfig,
     metadata: Optional[bytes] = None,
     **kwargs,
 ) -> Sequence[torch.Tensor]:
diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py
index 6ae664a..1df4a42 100644
--- a/src/petals/client/remote_sequential.py
+++ b/src/petals/client/remote_sequential.py
@@ -6,8 +6,9 @@ import torch
 from hivemind import DHT, get_logger
 from torch import nn
 
+from petals.client.config import ClientConfig
 from petals.client.inference_session import InferenceSession
-from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig
+from petals.client.routing import RemoteSequenceManager
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
 from petals.utils.misc import DUMMY
@@ -22,7 +23,7 @@ class RemoteSequential(nn.Module):
 
     def __init__(
         self,
-        config: SequenceManagerConfig,
+        config: ClientConfig,
         *,
         sequence_manager: Optional[RemoteSequenceManager] = None,
         dht: Optional[DHT] = None,
diff --git a/src/petals/client/routing/__init__.py b/src/petals/client/routing/__init__.py
index 62885d8..3be2710 100644
--- a/src/petals/client/routing/__init__.py
+++ b/src/petals/client/routing/__init__.py
@@ -1 +1,2 @@
-"""Client-side functions responsible for choosing the best server, """
+from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
+from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
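With petals/client/routing/__init__.py now re-exporting these names, downstream code can drop the deep module paths. The old imports still resolve (the modules themselves did not move), so the change is additive:

    # Before this patch:
    # from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
    # from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

    # After this patch:
    from petals.client.routing import (
        NoSpendingPolicy,
        RemoteSequenceManager,
        SpendingPolicyBase,
        maybe_log_traceback,
    )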
diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py
index 0c97bb2..3e239b4 100644
--- a/src/petals/client/routing/sequence_manager.py
+++ b/src/petals/client/routing/sequence_manager.py
@@ -7,7 +7,8 @@ import logging
 import random
 import threading
 import time
-from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Union
+import warnings
+from typing import Any, Dict, List, Optional, Sequence, Set, Union
 from weakref import WeakMethod
 
 import dijkstar
@@ -18,41 +19,27 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 
-import petals.dht_utils
+from petals.client.config import ClientConfig
 from petals.client.routing.sequence_info import RemoteSequenceInfo
 from petals.client.routing.spending_policy import NoSpendingPolicy
-from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler
+from petals.utils.dht import get_remote_module_infos
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 
 logger = get_logger(__name__)
 
 
-@dataclasses.dataclass
-class SequenceManagerConfig:
-    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
-    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
-    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
-
-    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
-    allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
-    blocked_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, do not use these servers
-    use_server_to_server: bool = True  # Use direct server-to-server communication
-
-    connect_timeout: float = 5  # timeout for opening a connection
-    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
-    update_period: float = 60  # refresh DHT information once in this many seconds
-
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
-    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
-    max_backoff: float = 60  # limit maximal sleep time between retries to this value
-    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
-    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
-
-    max_pinged: int = 3  # max servers to ping from each sequence side, per update
-    ping_timeout: float = 2  # max time to wait for pings, per update
+class SequenceManagerConfig(ClientConfig):
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "petals.client.routing.SequenceManagerConfig has been moved to petals.ClientConfig. "
+            "This alias will be removed in Petals 2.2.0+",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
 
 
 @dataclasses.dataclass
@@ -83,7 +70,7 @@ class RemoteSequenceManager:
 
     def __init__(
         self,
-        config: SequenceManagerConfig,
+        config: ClientConfig,
         block_uids: Sequence[ModuleUID],
         *,
         dht: Optional[DHT] = None,
@@ -133,7 +120,7 @@ class RemoteSequenceManager:
         self._need_latest_infos = True
 
     @staticmethod
-    def _peer_ids_to_set(peer_ids: Optional[Collection[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
+    def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
         if peer_ids is None:
             return None
 
@@ -354,7 +341,7 @@ class RemoteSequenceManager:
 
     def _update(self):
         """Perform an immediate and synchronous refresh, may take time"""
-        new_block_infos = petals.dht_utils.get_remote_module_infos(
+        new_block_infos = get_remote_module_infos(
             self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
         )
diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py
index 7996ff5..41bc994 100644
--- a/src/petals/client/sequential_autograd.py
+++ b/src/petals/client/sequential_autograd.py
@@ -12,7 +12,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.utils.logging import get_logger
 
 from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
-from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
+from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py
index 38d706f..c1e31b4 100644
--- a/src/petals/data_structures.py
+++ b/src/petals/data_structures.py
@@ -6,8 +6,6 @@ import pydantic
 from hivemind import PeerID
 from hivemind.moe.expert_uid import ExpertUID
 
-from petals.server.memory_cache import Handle
-
 ModuleUID = str
 UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
@@ -78,6 +76,8 @@ class RemoteSpanInfo:
 
 RPCInfo = Dict[str, Any]
 
+Handle = int
+
 
 @dataclasses.dataclass(frozen=True)
 class InferenceMetadata:
diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py
index 0710f60..c884479 100644
--- a/src/petals/dht_utils.py
+++ b/src/petals/dht_utils.py
@@ -1,124 +1,9 @@
-"""
-Utilities for declaring and retrieving active model layers using a shared DHT.
-"""
-from __future__ import annotations
+import warnings
 
-import math
-from functools import partial
-from typing import Dict, List, Optional, Sequence, Union
+warnings.warn(
+    "petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
+    DeprecationWarning,
+    stacklevel=2,
+)
 
-from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
-
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
-
-logger = get_logger(__name__)
-
-
-def declare_active_modules(
-    dht: DHT,
-    uids: Sequence[ModuleUID],
-    server_info: ServerInfo,
-    expiration_time: DHTExpiration,
-    wait: bool = True,
-) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
-    """
-    Declare that your node serves the specified modules; update timestamps if declared previously
-
-    :param uids: a list of module ids to declare
-    :param wait: if True, awaits for declaration to finish, otherwise runs in background
-    :param throughput: specify your performance in terms of compute throughput
-    :param expiration_time: declared modules will be visible for this many seconds
-    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
-    """
-    if isinstance(uids, str):
-        uids = [uids]
-    if not isinstance(uids, list):
-        uids = list(uids)
-    for uid in uids:
-        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
-
-    return dht.run_coroutine(
-        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
-        return_future=not wait,
-    )
-
-
-async def _declare_active_modules(
-    dht: DHT,
-    node: DHTNode,
-    uids: List[ModuleUID],
-    server_info: ServerInfo,
-    expiration_time: DHTExpiration,
-) -> Dict[ModuleUID, bool]:
-    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
-    return await node.store_many(
-        keys=uids,
-        subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[server_info.to_tuple()] * len(uids),
-        expiration_time=expiration_time,
-        num_workers=num_workers,
-    )
-
-
-def get_remote_module_infos(
-    dht: DHT,
-    uids: Sequence[ModuleUID],
-    expiration_time: Optional[DHTExpiration] = None,
-    active_adapter: Optional[str] = None,
-    *,
-    latest: bool = False,
-    return_future: bool = False,
-) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
-    return dht.run_coroutine(
-        partial(
-            _get_remote_module_infos,
-            uids=uids,
-            active_adapter=active_adapter,
-            expiration_time=expiration_time,
-            latest=latest,
-        ),
-        return_future=return_future,
-    )
-
-
-async def _get_remote_module_infos(
-    dht: DHT,
-    node: DHTNode,
-    uids: List[ModuleUID],
-    active_adapter: Optional[str],
-    expiration_time: Optional[DHTExpiration],
-    latest: bool,
-) -> List[Optional[RemoteModuleInfo]]:
-    if latest:
-        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
-        expiration_time = math.inf
-    elif expiration_time is None:
-        expiration_time = get_dht_time()
-    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
-    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
-
-    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
-    for i, uid in enumerate(uids):
-        metadata = found[uid]
-        if metadata is None or not isinstance(metadata.value, dict):
-            if metadata is not None:
-                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
-            continue
-        servers = {}
-        for peer_id, server_info in metadata.value.items():
-            try:
-                peer_id = PeerID.from_base58(peer_id)
-                server_info = ServerInfo.from_tuple(server_info.value)
-
-                if active_adapter and active_adapter not in server_info.adapters:
-                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
-                    continue
-
-                servers[peer_id] = server_info
-            except (TypeError, ValueError) as e:
-                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
-        if servers:
-            modules[i] = RemoteModuleInfo(uid, servers)
-    return modules
+from petals.utils.dht import *
diff --git a/src/petals/models/bloom/config.py b/src/petals/models/bloom/config.py
index 494c187..cc5c839 100644
--- a/src/petals/models/bloom/config.py
+++ b/src/petals/models/bloom/config.py
@@ -5,15 +5,15 @@ from hivemind import get_logger
 from transformers.models.bloom import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
+from petals.client.config import ClientConfig
 from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
-from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.bloom.block import WrappedBloomBlock
 
 logger = get_logger(__name__)
 
 
-class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
+class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):
     block_class = WrappedBloomBlock
     attn_class = BloomAttention
     block_prefix = "h"
diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py
index 241525a..c5144c2 100644
--- a/src/petals/models/llama/config.py
+++ b/src/petals/models/llama/config.py
@@ -5,15 +5,15 @@ from hivemind import get_logger
 from transformers.models.llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention
 
+from petals.client.config import ClientConfig
 from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
-from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.llama.block import WrappedLlamaBlock
 
 logger = get_logger(__name__)
 
 
-class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
+class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):
     block_class = WrappedLlamaBlock
     attn_class = LlamaAttention
     block_prefix = "model.layers"
diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py
index f3e512f..c682663 100644
--- a/src/petals/server/block_functions.py
+++ b/src/petals/server/block_functions.py
@@ -12,9 +12,8 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_flatten
 
-from petals.data_structures import InferenceMetadata
+from petals.data_structures import Handle, InferenceMetadata
 from petals.server.backend import TransformerBackend
-from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
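Because DistributedBloomConfig and DistributedLlamaConfig now mix in ClientConfig instead of SequenceManagerConfig, a loaded model config exposes the client-side fields directly. A sketch, assuming Hub access for from_pretrained and using an illustrative checkpoint name:

    from petals.client.config import ClientConfig
    from petals.models.bloom.config import DistributedBloomConfig

    assert issubclass(DistributedBloomConfig, ClientConfig)

    # One object now carries both model hyperparameters and client settings
    config = DistributedBloomConfig.from_pretrained("bigscience/bloom-560m")  # example checkpoint
    config.show_route = "inference"
    config.request_timeout = 2 * 60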
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index c4db8ef..0dd63bd 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -29,10 +29,9 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 import petals
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, Handle, ModuleUID
 from petals.server.backend import TransformerBackend
 from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
-from petals.server.memory_cache import Handle
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 
diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py
index c2aa192..9e79f17 100644
--- a/src/petals/server/memory_cache.py
+++ b/src/petals/server/memory_cache.py
@@ -16,12 +16,11 @@ import hivemind
 import torch
 from hivemind.utils import TensorDescriptor, get_logger
 
+from petals.data_structures import Handle
 from petals.utils.asyncio import shield_and_wait
 
 logger = get_logger(__name__)
 
-Handle = int
-
 
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 405dd9b..ba0403c 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -20,7 +20,6 @@ from transformers import PretrainedConfig
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
-from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -31,6 +30,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.dht import declare_active_modules, get_remote_module_infos
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
diff --git a/src/petals/utils/__init__.py b/src/petals/utils/__init__.py
index 0852074..c8aa484 100644
--- a/src/petals/utils/__init__.py
+++ b/src/petals/utils/__init__.py
@@ -4,3 +4,4 @@ from petals.utils.auto_config import (
     AutoDistributedModelForCausalLM,
     AutoDistributedModelForSequenceClassification,
 )
+from petals.utils.dht import declare_active_modules, get_remote_module_infos
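After this change the DHT helpers are reachable from three places: the new canonical module, the petals.utils re-export added above, and the deprecated petals.dht_utils shim that emits a DeprecationWarning on import and is slated for removal in Petals 2.2.0+:

    # Canonical location after this patch
    from petals.utils.dht import declare_active_modules, get_remote_module_infos

    # Re-exported one level up for convenience
    from petals.utils import declare_active_modules, get_remote_module_infos

    # Old location still works for now, but warns on import
    from petals.dht_utils import declare_active_modules, get_remote_module_infos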
+""" +from __future__ import annotations + +import math +from functools import partial +from typing import Dict, List, Optional, Sequence, Union + +from hivemind.dht import DHT, DHTNode, DHTValue +from hivemind.p2p import PeerID +from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger + +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo + +logger = get_logger(__name__) + + +def declare_active_modules( + dht: DHT, + uids: Sequence[ModuleUID], + server_info: ServerInfo, + expiration_time: DHTExpiration, + wait: bool = True, +) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]: + """ + Declare that your node serves the specified modules; update timestamps if declared previously + + :param uids: a list of module ids to declare + :param wait: if True, awaits for declaration to finish, otherwise runs in background + :param throughput: specify your performance in terms of compute throughput + :param expiration_time: declared modules will be visible for this many seconds + :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected) + """ + if isinstance(uids, str): + uids = [uids] + if not isinstance(uids, list): + uids = list(uids) + for uid in uids: + assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid + + return dht.run_coroutine( + partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time), + return_future=not wait, + ) + + +async def _declare_active_modules( + dht: DHT, + node: DHTNode, + uids: List[ModuleUID], + server_info: ServerInfo, + expiration_time: DHTExpiration, +) -> Dict[ModuleUID, bool]: + num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) + return await node.store_many( + keys=uids, + subkeys=[dht.peer_id.to_base58()] * len(uids), + values=[server_info.to_tuple()] * len(uids), + expiration_time=expiration_time, + num_workers=num_workers, + ) + + +def get_remote_module_infos( + dht: DHT, + uids: Sequence[ModuleUID], + expiration_time: Optional[DHTExpiration] = None, + active_adapter: Optional[str] = None, + *, + latest: bool = False, + return_future: bool = False, +) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]: + return dht.run_coroutine( + partial( + _get_remote_module_infos, + uids=uids, + active_adapter=active_adapter, + expiration_time=expiration_time, + latest=latest, + ), + return_future=return_future, + ) + + +async def _get_remote_module_infos( + dht: DHT, + node: DHTNode, + uids: List[ModuleUID], + active_adapter: Optional[str], + expiration_time: Optional[DHTExpiration], + latest: bool, +) -> List[Optional[RemoteModuleInfo]]: + if latest: + assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both" + expiration_time = math.inf + elif expiration_time is None: + expiration_time = get_dht_time() + num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) + found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers) + + modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids) + for i, uid in enumerate(uids): + metadata = found[uid] + if metadata is None or not isinstance(metadata.value, dict): + if metadata is not None: + logger.warning(f"Incorrect metadata for {uid}: {metadata}") + continue + servers = {} + for peer_id, server_info in metadata.value.items(): + try: + peer_id = 
PeerID.from_base58(peer_id) + server_info = ServerInfo.from_tuple(server_info.value) + + if active_adapter and active_adapter not in server_info.adapters: + logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}") + continue + + servers[peer_id] = server_info + except (TypeError, ValueError) as e: + logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}") + if servers: + modules[i] = RemoteModuleInfo(uid, servers) + return modules
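For reference, a minimal client-side sketch of calling the relocated helpers against a swarm. The block UIDs below are placeholders (the real prefix is the model's dht_prefix), and the call needs network access to the initial peers:

    import hivemind

    from petals.constants import PUBLIC_INITIAL_PEERS
    from petals.utils.dht import get_remote_module_infos

    dht = hivemind.DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)

    # UIDs look like "<dht_prefix>.<block_index>"; the prefix here is hypothetical
    uids = [f"example-model-petals.{i}" for i in range(3)]
    infos = get_remote_module_infos(dht, uids, latest=True)

    for info in infos:
        if info is not None:
            print(info.uid, "is served by", len(info.servers), "server(s)")

    dht.shutdown()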