From 351e96bc469b260c86d30cd823a262a0b71be66e Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Thu, 3 Aug 2023 02:00:43 +0200
Subject: [PATCH] Penalize servers that use relays during rebalancing (#428)

Servers accessible only via relays may introduce issues if they are
the only type of servers holding certain blocks. Specifically, a
connection to such servers may be unstable or opened after a certain
delay.

This PR changes their self-reported throughput, so that the rebalancing
algorithm prefers to put directly available servers for hosting each
block.
---
 src/petals/client/routing/sequence_manager.py | 12 ++----------
 src/petals/server/server.py                   | 15 ++++++++-------
 src/petals/server/throughput.py               |  7 ++++++-
 3 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py
index a7a0f1d..b19d468 100644
--- a/src/petals/client/routing/sequence_manager.py
+++ b/src/petals/client/routing/sequence_manager.py
@@ -292,9 +292,7 @@ class RemoteSequenceManager:
             # This is okay since false positives are more costly than false negatives here.
             return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
 
-    def _make_sequence_with_max_throughput(
-        self, start_index: int, end_index: int, *, relay_penalty: float = 0.5
-    ) -> List[RemoteSpanInfo]:
+    def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
@@ -302,13 +300,7 @@ class RemoteSequenceManager:
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
 
-            span_weights = np.array(
-                [
-                    span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty)
-                    for span in candidate_spans
-                ],
-                dtype=np.float64,
-            )
+            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index fec6e82..5c47270 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -83,7 +83,7 @@ class Server:
         quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
-        dht_client_mode: Optional[bool] = None,
+        reachable_via_relay: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
         adapters: Sequence[str] = (),
@@ -129,20 +129,20 @@ class Server:
             for block_index in range(self.block_config.num_hidden_layers)
         ]
 
-        if dht_client_mode is None:
+        if reachable_via_relay is None:
             is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
-            dht_client_mode = is_reachable is False  # if could not check reachability (returns None), run a full peer
-            logger.info(f"This server is accessible {'via relays' if dht_client_mode else 'directly'}")
+            reachable_via_relay = is_reachable is False  # if can't check reachability (returns None), run a full peer
+            logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")
         self.dht = DHT(
             initial_peers=initial_peers,
             start=True,
             num_workers=self.block_config.num_hidden_layers,
             use_relay=use_relay,
             use_auto_relay=use_auto_relay,
-            client_mode=dht_client_mode,
+            client_mode=reachable_via_relay,
             **kwargs,
         )
-        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
+        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None
 
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
@@ -227,6 +227,7 @@ class Server:
                 num_blocks=num_blocks,
                 quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
+                reachable_via_relay=reachable_via_relay,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
@@ -239,7 +240,7 @@ class Server:
             adapters=tuple(adapters),
             torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
-            using_relay=self.dht.client_mode,
+            using_relay=reachable_via_relay,
             **throughput_info,
         )
 
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index fbea3d2..d977611 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -41,6 +41,8 @@ def get_server_throughput(
     num_blocks: int,
     quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
+    reachable_via_relay: bool,
+    relay_penalty: float = 0.2,
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
 ) -> Dict[str, float]:
@@ -94,7 +96,10 @@ def get_server_throughput(
     # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
     average_blocks_used = (num_blocks + 1) / 2
     throughput = throughput_info["forward_rps"] / average_blocks_used
-    throughput = min(throughput, throughput_info.get("network_rps", math.inf))
+
+    network_rps = throughput_info["network_rps"] * (relay_penalty if reachable_via_relay else 1)
+    throughput = min(throughput, network_rps)
+
     throughput_info["throughput"] = throughput
     logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")
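
Note (not part of the patch): below is a minimal sketch of the arithmetic the patched
get_server_throughput() performs, for reviewers who want to see the penalty in isolation.
The helper name reported_throughput and the sample numbers are made up for illustration;
in Petals, forward_rps and network_rps come from the benchmark results stored in
throughput_info.

    def reported_throughput(
        forward_rps: float,
        network_rps: float,
        num_blocks: int,
        reachable_via_relay: bool,
        relay_penalty: float = 0.2,
    ) -> float:
        # E[Uniform{1, ..., num_blocks}] = (num_blocks + 1) / 2 blocks used per request
        average_blocks_used = (num_blocks + 1) / 2
        compute_rps = forward_rps / average_blocks_used
        # Relay-only servers advertise a discounted network throughput, so the
        # rebalancing algorithm prefers directly reachable peers for each block
        effective_network_rps = network_rps * (relay_penalty if reachable_via_relay else 1)
        return min(compute_rps, effective_network_rps)

    # Two servers with identical hardware, one directly reachable and one relay-only:
    reported_throughput(1000.0, 300.0, num_blocks=4, reachable_via_relay=False)  # -> 300.0
    reported_throughput(1000.0, 300.0, num_blocks=4, reachable_via_relay=True)   # -> 60.0

With the default relay_penalty of 0.2, a relay-only server reports a five times lower
network-bound throughput, which is why the client-side relay weighting removed from
sequence_manager.py is no longer needed.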