|
|
|
@ -7,7 +7,7 @@ import logging
|
|
|
|
|
import random
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
|
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
|
|
|
|
|
from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Union
|
|
|
|
|
from weakref import WeakMethod
|
|
|
|
|
|
|
|
|
|
import dijkstar
|
|
|
|
@ -38,6 +38,7 @@ class SequenceManagerConfig:
|
|
|
|
|
|
|
|
|
|
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
|
|
|
|
|
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
|
|
|
|
|
blocked_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, do not use these servers
|
|
|
|
|
use_server_to_server: bool = True # Use direct server-to-server communication
|
|
|
|
|
|
|
|
|
|
connect_timeout: float = 5 # timeout for opening a connection
|
|
|
|
@ -116,6 +117,9 @@ class RemoteSequenceManager:
|
|
|
|
|
self._thread_start_lock = threading.Lock()
|
|
|
|
|
self.policy = NoSpendingPolicy()
|
|
|
|
|
|
|
|
|
|
self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
|
|
|
|
|
self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)
|
|
|
|
|
|
|
|
|
|
self.ping_aggregator = PingAggregator(dht)
|
|
|
|
|
|
|
|
|
|
if state.banned_peers is None:
|
|
|
|
@ -128,6 +132,23 @@ class RemoteSequenceManager:
|
|
|
|
|
self._thread.ready.set() # no need to await the first dht fetch
|
|
|
|
|
self._need_latest_infos = True
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _peer_ids_to_set(peer_ids: Optional[Collection[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
|
|
|
|
|
if peer_ids is None:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
result = set()
|
|
|
|
|
for peer_id in peer_ids:
|
|
|
|
|
if isinstance(peer_id, PeerID):
|
|
|
|
|
result.add(peer_id)
|
|
|
|
|
elif isinstance(peer_id, str):
|
|
|
|
|
result.add(PeerID.from_base58(peer_id))
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}"
|
|
|
|
|
)
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
def make_sequence(
|
|
|
|
|
self,
|
|
|
|
|
start_index: int = 0,
|
|
|
|
@ -341,13 +362,13 @@ class RemoteSequenceManager:
|
|
|
|
|
if not block_info:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# Apply whitelist, if defined
|
|
|
|
|
if self.config.allowed_servers is not None:
|
|
|
|
|
block_info.servers = {
|
|
|
|
|
peer_id: server_info
|
|
|
|
|
for peer_id, server_info in block_info.servers.items()
|
|
|
|
|
if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
|
|
|
|
|
}
|
|
|
|
|
# Apply allow and block lists
|
|
|
|
|
block_info.servers = {
|
|
|
|
|
peer_id: server_info
|
|
|
|
|
for peer_id, server_info in block_info.servers.items()
|
|
|
|
|
if (self.allowed_servers is None or peer_id in self.allowed_servers)
|
|
|
|
|
and (self.blocked_servers is None or peer_id not in self.blocked_servers)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Remove temporarily banned peers, unless there are no peers left
|
|
|
|
|
valid_servers = {
|
|
|
|
|