From 2a150770a47d2ee8fd9f9bf129a54d1afdf29e64 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 7 Aug 2023 21:43:21 +0400 Subject: [PATCH] Prefer longer servers for fine-tuning, exclude unreachable (#448) We choose longer servers to minimize the number of hops but leave some randomization to distribute the load. We also exclude servers known to be unreachable. --- src/petals/client/routing/sequence_manager.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index b19d468..7328cdc 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -50,7 +50,7 @@ class SequenceManagerConfig: ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo) - max_pinged: int = 5 # max servers to ping from each sequence side, per update + max_pinged: int = 3 # max servers to ping from each sequence side, per update ping_timeout: float = 2 # max time to wait for pings, per update @@ -293,6 +293,8 @@ class RemoteSequenceManager: return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]: + client_server_rtts = self.ping_aggregator.to_dict() + span_sequence = [] current_index = start_index while current_index < end_index: @@ -300,7 +302,13 @@ class RemoteSequenceManager: if not candidate_spans: raise MissingBlocksError(current_index) - span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) + # We choose longer servers to minimize the number of hops but leave some randomization + # to distribute the load. We also exclude servers known to be unreachable. + eps = 1e-6 + span_weights = np.array( + [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans], + dtype=np.float64, + ) chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) assert chosen_span.start <= current_index < chosen_span.end @@ -361,9 +369,13 @@ class RemoteSequenceManager: self.state.sequence_info.update_(new_block_infos) first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]] + middle_servers = [ + span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans + ] last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]] pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged)) + pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged)) pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged)) self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)