Fix inference and rpc_info() fault tolerance (#131)

Alexander Borzunov committed f56edaa13f (parent 79a4308992) via GitHub

@@ -20,7 +20,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
-from petals.client.routing.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
@@ -307,12 +307,11 @@ class InferenceSession:
                     f"Caught exception when running inference from block {block_idx} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                traceback_level = logging.DEBUG if str(e) else logging.WARNING
-                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                maybe_log_traceback(e)
                 time.sleep(delay)

         self._position += n_input_tokens
         inputs = inputs[:, -n_input_tokens:]
         outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
         return outputs
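
Every retry site touched by this commit shares one shape: catch the exception, warn with the backoff delay, hand the traceback to the new maybe_log_traceback() helper, sleep, and try again. Below is a minimal self-contained sketch of that pattern, not the exact Petals code: do_remote_call and the backoff formula are hypothetical stand-ins, and the helper's two-line body is inlined here.

    import logging
    import random
    import time

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)

    def get_retry_delay(attempt_no: int) -> float:
        # Simplified exponential backoff; the real policy lives in RemoteSequenceManager.
        return min(1.5 ** attempt_no, 5.0)

    def do_remote_call():
        # Hypothetical stand-in for the RPC that may fail.
        if random.random() < 0.5:
            raise ConnectionError("peer unreachable")
        return "ok"

    attempt_no = 0
    while True:
        try:
            result = do_remote_call()
            break
        except Exception as e:
            attempt_no += 1
            delay = get_retry_delay(attempt_no)
            logger.warning(f"Caught exception (retry in {delay:.0f} sec): {repr(e)}")
            # Equivalent of maybe_log_traceback(e), defined later in this commit:
            logger.log(logging.DEBUG if str(e) else logging.WARNING,
                       "See detailed traceback below:", exc_info=True)
            time.sleep(delay)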

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import itertools
 import logging
 import random
@@ -17,7 +18,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 import petals.dht_utils
 from petals.client.routing.sequence_info import RemoteSequenceInfo
 from petals.client.routing.spending_policy import NoSpendingPolicy
-from petals.data_structures import ModuleUID, RemoteSpanInfo
+from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler

 use_hivemind_log_handler("in_root_logger")
@ -169,8 +170,7 @@ class RemoteSequenceManager:
except Exception as e:
delay = self.get_retry_delay(attempt_no)
logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
traceback_level = logging.DEBUG if str(e) else logging.WARNING
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
maybe_log_traceback(e)
time.sleep(delay)
def on_request_failure(self, peer_id: PeerID):
@@ -215,7 +215,16 @@ class RemoteSequenceManager:
         try:
             if not self.ready.is_set():
                 self.update(wait=True)
-            peer_id, _ = random.choice(list(self.sequence_info.block_infos[0].servers.items()))
+            active_servers = [
+                peer_id
+                for peer_id, server in self.sequence_info.block_infos[0].servers.items()
+                if server.state == ServerState.ONLINE
+            ]
+            if not active_servers:
+                raise MissingBlocksError("no servers holding the first block are online")
+            peer_id = random.choice(active_servers)
             stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
             outputs = RemoteExpertWorker.run_coroutine(
                 stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
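
This hunk is the substance of the rpc_info() fix: instead of drawing a random peer from every server listed for the first block, dead or alive, the draw is now restricted to peers whose DHT entry reports ServerState.ONLINE, and it fails fast when none qualify. A self-contained sketch of the selection logic (the enum values follow petals.data_structures; the Server record and peer ids are simplified stand-ins):

    import random
    from dataclasses import dataclass
    from enum import Enum

    class ServerState(Enum):  # values follow petals.data_structures
        OFFLINE = 0
        JOINING = 1
        ONLINE = 2

    @dataclass
    class Server:  # simplified stand-in for the real server record
        state: ServerState

    servers = {
        "peerA": Server(ServerState.OFFLINE),
        "peerB": Server(ServerState.ONLINE),
        "peerC": Server(ServerState.JOINING),
    }

    # Before: any listed server could be drawn, including dead or joining ones.
    # peer_id, _ = random.choice(list(servers.items()))

    # After: only peers announcing ONLINE are eligible; fail fast otherwise.
    active_servers = [pid for pid, srv in servers.items() if srv.state == ServerState.ONLINE]
    if not active_servers:
        raise Exception("no servers holding the first block are online")
    peer_id = random.choice(active_servers)
    print(peer_id)  # always "peerB" here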
@@ -231,8 +240,7 @@ class RemoteSequenceManager:
                     f"Caught exception when gathering information from peer {peer_id} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                traceback_level = logging.DEBUG if str(e) else logging.WARNING
-                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                maybe_log_traceback(e)
                 time.sleep(delay)

         return self._rpc_info
@@ -298,6 +306,11 @@ class _SequenceManagerUpdateThread(threading.Thread):
         self.shutdown()

+def maybe_log_traceback(exc: Exception):
+    traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING
+    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+
+
+class MissingBlocksError(Exception):
+    def __repr__(self):
+        return self.args[0]
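
Two details are worth noting here. maybe_log_traceback() keeps the traceback at DEBUG when the exception message (or its TimeoutError type, since timeouts are routine) already tells the story, but escalates to WARNING when the message is empty, because the traceback is then the only diagnostic left. And MissingBlocksError overrides __repr__ so that the {repr(e)} interpolation in the retry warnings prints the bare message rather than MissingBlocksError('...'). A runnable illustration; the helper body is copied from the hunk above, the logging setup is mine:

    import asyncio
    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    def maybe_log_traceback(exc: Exception):
        # DEBUG when the message (or a routine TimeoutError) already says enough;
        # WARNING when the message is empty and the traceback is all we have.
        traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING
        logger.log(traceback_level, "See detailed traceback below:", exc_info=True)

    try:
        raise ValueError("peer unreachable")  # has a message -> traceback at DEBUG
    except Exception as e:
        maybe_log_traceback(e)

    try:
        raise AssertionError  # empty message, not a timeout -> traceback at WARNING
    except Exception as e:
        maybe_log_traceback(e)

    class MissingBlocksError(Exception):
        def __repr__(self):
            return self.args[0]

    print(repr(MissingBlocksError("no servers holding the first block are online")))
    # prints the bare message, keeping "{repr(e)}" in the retry warnings readable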

@@ -13,7 +13,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.utils.logging import get_logger

 from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
-from petals.client.routing.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
@@ -100,8 +100,7 @@ async def sequential_forward(
                     f"Caught exception when running forward from block {block_idx} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                traceback_level = logging.DEBUG if str(e) else logging.WARNING
-                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                maybe_log_traceback(e)
                 await asyncio.sleep(delay)

     outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
@@ -178,8 +177,7 @@ async def sequential_backward(
                     f"Caught exception when running backward between blocks {span.start}-{span.end} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                traceback_level = logging.DEBUG if str(e) else logging.WARNING
-                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                maybe_log_traceback(e)
                 await asyncio.sleep(delay)

     # For now, we do not support mixed dummy and grad prompts

@@ -146,7 +146,7 @@ def measure_compute_rps(

 def get_device_name(device: torch.device) -> str:
-    return f"{torch.cuda.get_device_name(device)} GPU" if device == "cuda" else "CPU"
+    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"

 def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
