|
|
|
@ -6,7 +6,7 @@ import logging
|
|
|
|
|
import random
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
|
from typing import Any, Dict, List, Optional, Sequence, Union
|
|
|
|
|
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
|
|
|
|
|
from weakref import WeakMethod
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
@ -40,9 +40,10 @@ class RemoteSequenceManager:
|
|
|
|
|
:param update_period: by default, refresh DHT information once in this many seconds
|
|
|
|
|
:param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
|
|
|
|
|
:param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
|
|
|
|
|
:param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
|
|
|
|
|
:param sequence_info: optionally, specify pre-generated sequence info. By default, create a new one using the DHT
|
|
|
|
|
:param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
|
|
|
|
|
:param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
|
|
|
|
|
:param allowed_servers: if defined, send requests only to these servers
|
|
|
|
|
:param start: start the background thread (see the note below). If false, you will need to start it manually.
|
|
|
|
|
:note: RemoteSequenceManager takes up some CPU and network I/O to operate in the background. It is recommended to avoid
|
|
|
|
|
running redundant sequence managers for the same set of layers.
|
|
|
|
@ -56,21 +57,30 @@ class RemoteSequenceManager:
|
|
|
|
|
p2p: P2P,
|
|
|
|
|
update_period: float = 30,
|
|
|
|
|
request_timeout: float = 30,
|
|
|
|
|
max_retries: Optional[int] = None,
|
|
|
|
|
min_backoff: float = 1,
|
|
|
|
|
ban_timeout: float = 15,
|
|
|
|
|
sequence_info: Optional[RemoteSequenceInfo] = None,
|
|
|
|
|
rpc_info: Optional[dict] = None,
|
|
|
|
|
allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None,
|
|
|
|
|
banned_peers: Optional[Blacklist] = None,
|
|
|
|
|
*, # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
|
|
|
|
|
start: bool,
|
|
|
|
|
):
|
|
|
|
|
assert len(block_uids) > 0, "Sequences must contain at least one block"
|
|
|
|
|
self.dht, self.p2p = dht, p2p
|
|
|
|
|
self.request_timeout, self.ban_timeout, self.min_backoff = request_timeout, ban_timeout, min_backoff
|
|
|
|
|
self.request_timeout, self.max_retries = request_timeout, max_retries
|
|
|
|
|
self.ban_timeout, self.min_backoff = ban_timeout, min_backoff
|
|
|
|
|
self.lock_changes = threading.Lock()
|
|
|
|
|
self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
|
|
|
|
|
self.policy = NoSpendingPolicy()
|
|
|
|
|
self._rpc_info = rpc_info
|
|
|
|
|
|
|
|
|
|
if allowed_servers is not None:
|
|
|
|
|
allowed_servers = {
|
|
|
|
|
PeerID.from_base58(peer_id) if isinstance(peer_id, str) else peer_id for peer_id in allowed_servers
|
|
|
|
|
}
|
|
|
|
|
self.allowed_servers = allowed_servers
|
|
|
|
|
self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers
|
|
|
|
|
|
|
|
|
|
if sequence_info is None:
|
|
|
|
@ -148,6 +158,7 @@ class RemoteSequenceManager:
|
|
|
|
|
min_backoff=self.min_backoff,
|
|
|
|
|
sequence_info=self.sequence_info[ix],
|
|
|
|
|
rpc_info=self._rpc_info,
|
|
|
|
|
allowed_servers=self.allowed_servers,
|
|
|
|
|
banned_peers=self.banned_peers,
|
|
|
|
|
start=True,
|
|
|
|
|
)
|
|
|
|
@ -169,6 +180,16 @@ class RemoteSequenceManager:
|
|
|
|
|
for block_info in new_block_infos:
|
|
|
|
|
if not block_info:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# Apply whitelist, if defined
|
|
|
|
|
if self.allowed_servers is not None:
|
|
|
|
|
block_info.servers = {
|
|
|
|
|
peer_id: server_info
|
|
|
|
|
for peer_id, server_info in block_info.servers.items()
|
|
|
|
|
if peer_id in self.allowed_servers
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Remove temporarily banned peers, unless there are no peers left
|
|
|
|
|
valid_servers = {
|
|
|
|
|
peer_id: server_info
|
|
|
|
|
for peer_id, server_info in block_info.servers.items()
|
|
|
|
@ -260,6 +281,8 @@ class RemoteSequenceManager:
|
|
|
|
|
except Exception as e:
|
|
|
|
|
if peer_id is not None and not isinstance(e, P2PHandlerError):
|
|
|
|
|
self.on_request_failure(peer_id)
|
|
|
|
|
if attempt_no + 1 == self.max_retries:
|
|
|
|
|
raise
|
|
|
|
|
delay = self.get_retry_delay(attempt_no)
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Caught exception when gathering information from peer {peer_id} "
|
|
|
|
|