From b4f3224cda1405a002643d25e42ab30313b1c266 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 11 Jan 2023 16:50:24 +0400 Subject: [PATCH] Make client ignore blacklist if all servers holding a block are blacklisted (#197) If all servers holding a certain block are blacklisted, we should display errors from them instead of raising `No peers holding blocks`. Indeed, if the error is client-caused, the client should learn its reason from the latest error messages. In turn, if the error is server/network-caused and we only have a few servers, we'd better know the error instead of banning all the servers and making the user think that no servers are available. --- src/petals/client/inference_session.py | 4 ++-- src/petals/client/routing/sequence_manager.py | 18 ++++++++++++++---- src/petals/client/sequential_autograd.py | 5 ++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index b7a068b..3d41b6f 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -17,7 +17,7 @@ from hivemind import ( serialize_torch_tensor, ) from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import P2PHandlerError, StubBase +from hivemind.p2p import StubBase from hivemind.proto import runtime_pb2 from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback @@ -305,7 +305,7 @@ class InferenceSession: self._sequence_manager.on_request_success(span.peer_id) break except Exception as e: - if span is not None and not isinstance(e, P2PHandlerError): + if span is not None: self._sequence_manager.on_request_failure(span.peer_id) delay = self._sequence_manager.get_retry_delay(attempt_no) logger.warning( diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index bb93158..d77a575 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -156,10 +156,20 @@ class RemoteSequenceManager: for block_info in new_block_infos: if not block_info: continue - for peer_id in tuple(block_info.servers.keys()): - if peer_id in self.banned_peers: - logger.debug(f"Ignoring banned {peer_id} for block {block_info.uid}") - block_info.servers.pop(peer_id) + valid_servers = { + peer_id: server_info + for peer_id, server_info in block_info.servers.items() + if peer_id not in self.banned_peers + } + if len(valid_servers) < len(block_info.servers): + if valid_servers: + logger.debug( + f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}" + ) + block_info.servers = valid_servers + else: + # If we blacklisted all servers, the error may actually be client-caused + logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist") with self.lock_changes: self.sequence_info.update_(new_block_infos) diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 8ee786d..30c20ad 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -10,7 +10,6 @@ from typing import List, Optional, Sequence, Tuple import torch from hivemind import MSGPackSerializer from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import P2PHandlerError from hivemind.utils.logging import get_logger from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward @@ -94,7 +93,7 @@ async def sequential_forward( sequence_manager.on_request_success(span.peer_id) break except Exception as e: - if span is not None and not isinstance(e, P2PHandlerError): + if span is not None: sequence_manager.on_request_failure(span.peer_id) delay = sequence_manager.get_retry_delay(attempt_no) logger.warning( @@ -171,7 +170,7 @@ async def sequential_backward( sequence_manager.on_request_success(span.peer_id) break except Exception as e: - if span is not None and not isinstance(e, P2PHandlerError): + if span is not None: sequence_manager.on_request_failure(span.peer_id) delay = sequence_manager.get_retry_delay(attempt_no) logger.warning(