|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
import itertools
|
|
|
|
|
import logging
|
|
|
|
|
import random
|
|
|
|
@ -17,7 +18,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
|
import petals.dht_utils
|
|
|
|
|
from petals.client.routing.sequence_info import RemoteSequenceInfo
|
|
|
|
|
from petals.client.routing.spending_policy import NoSpendingPolicy
|
|
|
|
|
from petals.data_structures import ModuleUID, RemoteSpanInfo
|
|
|
|
|
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
|
|
|
|
|
from petals.server.handler import TransformerConnectionHandler
|
|
|
|
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
|
@ -169,8 +170,7 @@ class RemoteSequenceManager:
|
|
|
|
|
except Exception as e:
|
|
|
|
|
delay = self.get_retry_delay(attempt_no)
|
|
|
|
|
logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
|
|
|
|
|
traceback_level = logging.DEBUG if str(e) else logging.WARNING
|
|
|
|
|
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
|
|
|
|
|
maybe_log_traceback(e)
|
|
|
|
|
time.sleep(delay)
|
|
|
|
|
|
|
|
|
|
def on_request_failure(self, peer_id: PeerID):
|
|
|
|
@ -215,7 +215,16 @@ class RemoteSequenceManager:
|
|
|
|
|
try:
|
|
|
|
|
if not self.ready.is_set():
|
|
|
|
|
self.update(wait=True)
|
|
|
|
|
peer_id, _ = random.choice(list(self.sequence_info.block_infos[0].servers.items()))
|
|
|
|
|
|
|
|
|
|
active_servers = [
|
|
|
|
|
peer_id
|
|
|
|
|
for peer_id, server in self.sequence_info.block_infos[0].servers.items()
|
|
|
|
|
if server.state == ServerState.ONLINE
|
|
|
|
|
]
|
|
|
|
|
if not active_servers:
|
|
|
|
|
raise MissingBlocksError("no servers holding the first block are online")
|
|
|
|
|
peer_id = random.choice(active_servers)
|
|
|
|
|
|
|
|
|
|
stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
|
|
|
|
|
outputs = RemoteExpertWorker.run_coroutine(
|
|
|
|
|
stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
|
|
|
|
@ -231,8 +240,7 @@ class RemoteSequenceManager:
|
|
|
|
|
f"Caught exception when gathering information from peer {peer_id} "
|
|
|
|
|
f"(retry in {delay:.0f} sec): {repr(e)}"
|
|
|
|
|
)
|
|
|
|
|
traceback_level = logging.DEBUG if str(e) else logging.WARNING
|
|
|
|
|
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
|
|
|
|
|
maybe_log_traceback(e)
|
|
|
|
|
time.sleep(delay)
|
|
|
|
|
|
|
|
|
|
return self._rpc_info
|
|
|
|
@ -298,6 +306,11 @@ class _SequenceManagerUpdateThread(threading.Thread):
|
|
|
|
|
self.shutdown()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def maybe_log_traceback(exc: Exception):
|
|
|
|
|
traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING
|
|
|
|
|
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MissingBlocksError(Exception):
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
return self.args[0]
|
|
|
|
|