Move SequenceManagerConfig -> ClientConfig, petals.dht_utils -> petals.utils.dht (#463)
parent
568f21dc3b
commit
063e94b4c8
@ -1,4 +1,4 @@
|
|||||||
|
from petals.client.config import ClientConfig
|
||||||
from petals.client.inference_session import InferenceSession
|
from petals.client.inference_session import InferenceSession
|
||||||
from petals.client.remote_sequential import RemoteSequential
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
from petals.client.routing.sequence_manager import RemoteSequenceManager
|
from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase
|
||||||
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
import dataclasses
|
||||||
|
from typing import Optional, Sequence, Union
|
||||||
|
|
||||||
|
from hivemind import PeerID
|
||||||
|
|
||||||
|
from petals.constants import PUBLIC_INITIAL_PEERS
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class ClientConfig:
    """Client-side configuration: DHT bootstrap, server routing filters, timeouts, and retry/backoff policy."""

    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers

    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
    allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
    blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, do not use these servers
    use_server_to_server: bool = True  # Use direct server-to-server communication

    connect_timeout: float = 5  # timeout for opening a connection
    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
    update_period: float = 60  # refresh DHT information once in this many seconds

    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
    max_backoff: float = 60  # limit maximal sleep time between retries to this value
    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)

    max_pinged: int = 3  # max servers to ping from each sequence side, per update
    ping_timeout: float = 2  # max time to wait for pings, per update
|
@ -1 +1,2 @@
|
|||||||
"""Client-side functions responsible for choosing the best server, """
|
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
|
||||||
|
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
|
||||||
|
@ -1,124 +1,9 @@
|
|||||||
"""
|
import warnings

# Backwards-compatibility shim: the real implementation moved to petals.utils.dht in PR #463.
warnings.warn(
    "petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
    DeprecationWarning,
    stacklevel=2,
)

# Re-export everything from the new location so existing `petals.dht_utils.X` references keep working.
from petals.utils.dht import *
|
||||||
from hivemind.p2p import PeerID
|
|
||||||
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
|
|
||||||
|
|
||||||
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def declare_active_modules(
    dht: DHT,
    uids: Sequence[ModuleUID],
    server_info: ServerInfo,
    expiration_time: DHTExpiration,
    wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
    """
    Declare that your node serves the specified modules; update timestamps if declared previously

    :param dht: a running hivemind DHT instance used to store the declarations
    :param uids: a list of module ids to declare
    :param server_info: server metadata stored in the DHT for each declared module
    :param expiration_time: DHT expiration time until which the declared modules remain visible
    :param wait: if True, awaits for declaration to finish, otherwise runs in background
    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
    """
    # Normalize: a bare string is one uid; any other sequence is materialized into a list
    if isinstance(uids, str):
        uids = [uids]
    if not isinstance(uids, list):
        uids = list(uids)
    for uid in uids:
        # Each uid must name a single block (no chain delimiters) within a model namespace
        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid

    return dht.run_coroutine(
        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
        return_future=not wait,
    )
|
|
||||||
|
|
||||||
|
|
||||||
async def _declare_active_modules(
    dht: DHT,
    node: DHTNode,
    uids: List[ModuleUID],
    server_info: ServerInfo,
    expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
    """Store one subkey entry (this peer's id -> server_info) per module uid; runs inside the DHT process."""
    # Use one worker per uid unless the DHT caps the worker count
    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
    return await node.store_many(
        keys=uids,
        subkeys=[dht.peer_id.to_base58()] * len(uids),  # subkey identifies which peer made each declaration
        values=[server_info.to_tuple()] * len(uids),
        expiration_time=expiration_time,
        num_workers=num_workers,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def get_remote_module_infos(
    dht: DHT,
    uids: Sequence[ModuleUID],
    expiration_time: Optional[DHTExpiration] = None,
    active_adapter: Optional[str] = None,
    *,
    latest: bool = False,
    return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
    """
    Fetch the server records declared in the DHT for each of the given module uids.

    :param uids: module ids to look up
    :param expiration_time: only return records valid at this DHT time (default: now); mutually exclusive with latest
    :param active_adapter: if set, only keep servers that advertise this LoRA adapter
    :param latest: if True, bypass caches and fetch the newest records
    :param return_future: if True, return an MPFuture instead of blocking for the result
    :returns: one RemoteModuleInfo (or None if no servers found) per requested uid, in order
    """
    return dht.run_coroutine(
        partial(
            _get_remote_module_infos,
            uids=uids,
            active_adapter=active_adapter,
            expiration_time=expiration_time,
            latest=latest,
        ),
        return_future=return_future,
    )
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_remote_module_infos(
    dht: DHT,
    node: DHTNode,
    uids: List[ModuleUID],
    active_adapter: Optional[str],
    expiration_time: Optional[DHTExpiration],
    latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
    """Fetch and parse server records for each uid; malformed entries are logged and skipped, never raised."""
    if latest:
        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
        expiration_time = math.inf  # math.inf requests the freshest values the DHT can provide
    elif expiration_time is None:
        expiration_time = get_dht_time()
    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

    # Result list is positionally aligned with `uids`; entries stay None when nothing valid was found
    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
    for i, uid in enumerate(uids):
        metadata = found[uid]
        if metadata is None or not isinstance(metadata.value, dict):
            if metadata is not None:
                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
            continue
        servers = {}
        for peer_id, server_info in metadata.value.items():
            try:
                peer_id = PeerID.from_base58(peer_id)
                server_info = ServerInfo.from_tuple(server_info.value)

                if active_adapter and active_adapter not in server_info.adapters:
                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                    continue

                servers[peer_id] = server_info
            except (TypeError, ValueError) as e:
                # A bad entry from one peer must not prevent using the others
                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
        if servers:
            modules[i] = RemoteModuleInfo(uid, servers)
    return modules
|
|
||||||
|
@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
Utilities for declaring and retrieving active model layers using a shared DHT.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from functools import partial
|
||||||
|
from typing import Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
from hivemind.dht import DHT, DHTNode, DHTValue
|
||||||
|
from hivemind.p2p import PeerID
|
||||||
|
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
|
||||||
|
|
||||||
|
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def declare_active_modules(
    dht: DHT,
    uids: Sequence[ModuleUID],
    server_info: ServerInfo,
    expiration_time: DHTExpiration,
    wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
    """
    Declare that your node serves the specified modules; update timestamps if declared previously

    :param dht: a running hivemind DHT instance used to store the declarations
    :param uids: a list of module ids to declare
    :param server_info: server metadata stored in the DHT for each declared module
    :param expiration_time: DHT expiration time until which the declared modules remain visible
    :param wait: if True, awaits for declaration to finish, otherwise runs in background
    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
    """
    # Normalize: a bare string is one uid; any other sequence is materialized into a list
    if isinstance(uids, str):
        uids = [uids]
    if not isinstance(uids, list):
        uids = list(uids)
    for uid in uids:
        # Each uid must name a single block (no chain delimiters) within a model namespace
        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid

    return dht.run_coroutine(
        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
        return_future=not wait,
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def _declare_active_modules(
    dht: DHT,
    node: DHTNode,
    uids: List[ModuleUID],
    server_info: ServerInfo,
    expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
    """Store one subkey entry (this peer's id -> server_info) per module uid; runs inside the DHT process."""
    # Use one worker per uid unless the DHT caps the worker count
    if dht.num_workers is None:
        num_workers = len(uids)
    else:
        num_workers = min(len(uids), dht.num_workers)

    # Every uid gets the same subkey (this peer) and the same serialized server record
    peer_subkey = dht.peer_id.to_base58()
    packed_info = server_info.to_tuple()
    return await node.store_many(
        keys=uids,
        subkeys=[peer_subkey] * len(uids),
        values=[packed_info] * len(uids),
        expiration_time=expiration_time,
        num_workers=num_workers,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_remote_module_infos(
    dht: DHT,
    uids: Sequence[ModuleUID],
    expiration_time: Optional[DHTExpiration] = None,
    active_adapter: Optional[str] = None,
    *,
    latest: bool = False,
    return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
    """
    Fetch the server records declared in the DHT for each of the given module uids.

    :param uids: module ids to look up
    :param expiration_time: only return records valid at this DHT time (default: now); mutually exclusive with latest
    :param active_adapter: if set, only keep servers that advertise this LoRA adapter
    :param latest: if True, bypass caches and fetch the newest records
    :param return_future: if True, return an MPFuture instead of blocking for the result
    :returns: one RemoteModuleInfo (or None if no servers found) per requested uid, in order
    """
    # Bind all lookup parameters up front; the DHT process runs the coroutine on our behalf
    lookup = partial(
        _get_remote_module_infos,
        uids=uids,
        active_adapter=active_adapter,
        expiration_time=expiration_time,
        latest=latest,
    )
    return dht.run_coroutine(lookup, return_future=return_future)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_remote_module_infos(
    dht: DHT,
    node: DHTNode,
    uids: List[ModuleUID],
    active_adapter: Optional[str],
    expiration_time: Optional[DHTExpiration],
    latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
    """Fetch and parse server records for each uid; malformed entries are logged and skipped, never raised."""
    if latest:
        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
        expiration_time = math.inf  # math.inf requests the freshest values the DHT can provide
    elif expiration_time is None:
        expiration_time = get_dht_time()
    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

    # Result list is positionally aligned with `uids`; entries stay None when nothing valid was found
    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
    for i, uid in enumerate(uids):
        metadata = found[uid]
        if metadata is None or not isinstance(metadata.value, dict):
            if metadata is not None:
                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
            continue
        servers = {}
        for peer_id, server_info in metadata.value.items():
            try:
                peer_id = PeerID.from_base58(peer_id)
                server_info = ServerInfo.from_tuple(server_info.value)

                if active_adapter and active_adapter not in server_info.adapters:
                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                    continue

                servers[peer_id] = server_info
            except (TypeError, ValueError) as e:
                # A bad entry from one peer must not prevent using the others
                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
        if servers:
            modules[i] = RemoteModuleInfo(uid, servers)
    return modules
|
Loading…
Reference in New Issue