From 329f7d31e87781e48039e159f32340d611c18a2f Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 14 Aug 2023 10:41:13 +0400 Subject: [PATCH] Add `blocked_servers` argument (#462) Should be used as: ```python model = AutoDistributedModelForCausalLM(model_name, blocked_servers=[peer_id1, peer_id2]) ``` --- src/petals/client/routing/sequence_manager.py | 37 +++++++++++++++---- tests/test_remote_sequential.py | 2 +- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 7328cdc..bd15c2b 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -7,7 +7,7 @@ import logging import random import threading import time -from typing import Any, Collection, Dict, List, Optional, Sequence, Union +from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Union from weakref import WeakMethod import dijkstar @@ -38,6 +38,7 @@ class SequenceManagerConfig: show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True] allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers + blocked_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, do not use these servers use_server_to_server: bool = True # Use direct server-to-server communication connect_timeout: float = 5 # timeout for opening a connection @@ -116,6 +117,9 @@ class RemoteSequenceManager: self._thread_start_lock = threading.Lock() self.policy = NoSpendingPolicy() + self.allowed_servers = self._peer_ids_to_set(config.allowed_servers) + self.blocked_servers = self._peer_ids_to_set(config.blocked_servers) + self.ping_aggregator = PingAggregator(dht) if state.banned_peers is None: @@ -128,6 +132,23 @@ class RemoteSequenceManager: self._thread.ready.set() # no need to await the first dht fetch self._need_latest_infos = True + @staticmethod + def _peer_ids_to_set(peer_ids: Optional[Collection[Union[PeerID, str]]]) -> Optional[Set[PeerID]]: + if peer_ids is None: + return None + + result = set() + for peer_id in peer_ids: + if isinstance(peer_id, PeerID): + result.add(peer_id) + elif isinstance(peer_id, str): + result.add(PeerID.from_base58(peer_id)) + else: + raise TypeError( + f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}" + ) + return result + def make_sequence( self, start_index: int = 0, @@ -341,13 +362,13 @@ class RemoteSequenceManager: if not block_info: continue - # Apply whitelist, if defined - if self.config.allowed_servers is not None: - block_info.servers = { - peer_id: server_info - for peer_id, server_info in block_info.servers.items() - if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers - } + # Apply allow and block lists + block_info.servers = { + peer_id: server_info + for peer_id, server_info in block_info.servers.items() + if (self.allowed_servers is None or peer_id in self.allowed_servers) + and (self.blocked_servers is None or peer_id not in self.blocked_servers) + } # Remove temporarily banned peers, unless there are no peers left valid_servers = { diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 30698c5..533ba73 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -43,7 +43,7 @@ def test_remote_sequential(): assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3) (second_half_outputs * grad_proj).sum().backward() - assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2) + assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2) # test RemoteSequential with lossy compression block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]