Make client ignore blacklist if all servers holding a block are blacklisted (#197)

If all servers holding a certain block are blacklisted, we should display errors from them instead of raising `No peers holding blocks`.

Indeed, if the error is client-caused, the client should learn the cause from the latest error messages. Conversely, if the error is server- or network-caused and only a few servers are available, it is better to surface the actual error than to ban every server and leave the user thinking that no servers are available.
pull/196/head^2
Alexander Borzunov 1 year ago committed by GitHub
parent 127cf66bee
commit b4f3224cda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,7 +17,7 @@ from hivemind import (
serialize_torch_tensor,
)
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2PHandlerError, StubBase
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
@ -305,7 +305,7 @@ class InferenceSession:
self._sequence_manager.on_request_success(span.peer_id)
break
except Exception as e:
if span is not None and not isinstance(e, P2PHandlerError):
if span is not None:
self._sequence_manager.on_request_failure(span.peer_id)
delay = self._sequence_manager.get_retry_delay(attempt_no)
logger.warning(

@ -156,10 +156,20 @@ class RemoteSequenceManager:
for block_info in new_block_infos:
if not block_info:
continue
for peer_id in tuple(block_info.servers.keys()):
if peer_id in self.banned_peers:
logger.debug(f"Ignoring banned {peer_id} for block {block_info.uid}")
block_info.servers.pop(peer_id)
valid_servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if peer_id not in self.banned_peers
}
if len(valid_servers) < len(block_info.servers):
if valid_servers:
logger.debug(
f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
)
block_info.servers = valid_servers
else:
# If we blacklisted all servers, the error may actually be client-caused
logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
with self.lock_changes:
self.sequence_info.update_(new_block_infos)

@ -10,7 +10,6 @@ from typing import List, Optional, Sequence, Tuple
import torch
from hivemind import MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2PHandlerError
from hivemind.utils.logging import get_logger
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
@ -94,7 +93,7 @@ async def sequential_forward(
sequence_manager.on_request_success(span.peer_id)
break
except Exception as e:
if span is not None and not isinstance(e, P2PHandlerError):
if span is not None:
sequence_manager.on_request_failure(span.peer_id)
delay = sequence_manager.get_retry_delay(attempt_no)
logger.warning(
@ -171,7 +170,7 @@ async def sequential_backward(
sequence_manager.on_request_success(span.peer_id)
break
except Exception as e:
if span is not None and not isinstance(e, P2PHandlerError):
if span is not None:
sequence_manager.on_request_failure(span.peer_id)
delay = sequence_manager.get_retry_delay(attempt_no)
logger.warning(

Loading…
Cancel
Save