|
|
|
@ -40,6 +40,7 @@ class RemoteSequenceManager:
|
|
|
|
|
:param update_period: by default, refresh DHT information once in this many seconds
|
|
|
|
|
:param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
|
|
|
|
|
:param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
|
|
|
|
|
:param max_backoff: limit maximal sleep time between retries to this value
|
|
|
|
|
:param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
|
|
|
|
|
:param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
|
|
|
|
|
:param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
|
|
|
|
@ -59,6 +60,7 @@ class RemoteSequenceManager:
|
|
|
|
|
request_timeout: float = 30,
|
|
|
|
|
max_retries: Optional[int] = None,
|
|
|
|
|
min_backoff: float = 1,
|
|
|
|
|
max_backoff: float = 15 * 60,
|
|
|
|
|
ban_timeout: float = 15,
|
|
|
|
|
sequence_info: Optional[RemoteSequenceInfo] = None,
|
|
|
|
|
rpc_info: Optional[dict] = None,
|
|
|
|
@ -70,7 +72,7 @@ class RemoteSequenceManager:
|
|
|
|
|
assert len(block_uids) > 0, "Sequences must contain at least one block"
|
|
|
|
|
self.dht, self.p2p = dht, p2p
|
|
|
|
|
self.request_timeout, self.max_retries = request_timeout, max_retries
|
|
|
|
|
self.ban_timeout, self.min_backoff = ban_timeout, min_backoff
|
|
|
|
|
self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff
|
|
|
|
|
self.lock_changes = threading.Lock()
|
|
|
|
|
self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
|
|
|
|
|
self.policy = NoSpendingPolicy()
|
|
|
|
@ -156,6 +158,7 @@ class RemoteSequenceManager:
|
|
|
|
|
request_timeout=self.request_timeout,
|
|
|
|
|
ban_timeout=self.ban_timeout,
|
|
|
|
|
min_backoff=self.min_backoff,
|
|
|
|
|
max_backoff=self.max_backoff,
|
|
|
|
|
sequence_info=self.sequence_info[ix],
|
|
|
|
|
rpc_info=self._rpc_info,
|
|
|
|
|
allowed_servers=self.allowed_servers,
|
|
|
|
@ -296,7 +299,7 @@ class RemoteSequenceManager:
|
|
|
|
|
def get_retry_delay(self, attempt_no: int) -> float:
|
|
|
|
|
if attempt_no == 0:
|
|
|
|
|
return 0
|
|
|
|
|
return self.min_backoff * 2 ** (attempt_no - 1)
|
|
|
|
|
return min(self.min_backoff * 2 ** (attempt_no - 1), self.max_backoff)
|
|
|
|
|
|
|
|
|
|
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
|
|
|
|
|
"""
|
|
|
|
|