Penalize servers that use relays during rebalancing (#428)

Servers accessible only via relays may introduce issues if they are the only servers holding certain blocks. Specifically, a connection to such a server may be unstable, or may be established only after a significant delay.

This PR lowers their self-reported throughput, so that the rebalancing algorithm prefers directly reachable servers when choosing hosts for each block.
pull/430/head
Alexander Borzunov 9 months ago committed by GitHub
parent 6a1b8a6a90
commit 351e96bc46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -292,9 +292,7 @@ class RemoteSequenceManager:
# This is okay since false positives are more costly than false negatives here.
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
def _make_sequence_with_max_throughput(
self, start_index: int, end_index: int, *, relay_penalty: float = 0.5
) -> List[RemoteSpanInfo]:
def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
span_sequence = []
current_index = start_index
while current_index < end_index:
@ -302,13 +300,7 @@ class RemoteSequenceManager:
if not candidate_spans:
raise MissingBlocksError(current_index)
span_weights = np.array(
[
span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty)
for span in candidate_spans
],
dtype=np.float64,
)
span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
assert chosen_span.start <= current_index < chosen_span.end

@ -83,7 +83,7 @@ class Server:
quant_type: Optional[QuantType] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
skip_reachability_check: bool = False,
dht_client_mode: Optional[bool] = None,
reachable_via_relay: Optional[bool] = None,
use_relay: bool = True,
use_auto_relay: bool = True,
adapters: Sequence[str] = (),
@ -129,20 +129,20 @@ class Server:
for block_index in range(self.block_config.num_hidden_layers)
]
if dht_client_mode is None:
if reachable_via_relay is None:
is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
dht_client_mode = is_reachable is False # if could not check reachability (returns None), run a full peer
logger.info(f"This server is accessible {'via relays' if dht_client_mode else 'directly'}")
reachable_via_relay = is_reachable is False # if can't check reachability (returns None), run a full peer
logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")
self.dht = DHT(
initial_peers=initial_peers,
start=True,
num_workers=self.block_config.num_hidden_layers,
use_relay=use_relay,
use_auto_relay=use_auto_relay,
client_mode=dht_client_mode,
client_mode=reachable_via_relay,
**kwargs,
)
self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
@ -227,6 +227,7 @@ class Server:
num_blocks=num_blocks,
quant_type=quant_type,
tensor_parallel_devices=self.tensor_parallel_devices,
reachable_via_relay=reachable_via_relay,
force_eval=(throughput == "eval"),
cache_dir=cache_dir,
)
@ -239,7 +240,7 @@ class Server:
adapters=tuple(adapters),
torch_dtype=str(torch_dtype).replace("torch.", ""),
quant_type=quant_type.name.lower(),
using_relay=self.dht.client_mode,
using_relay=reachable_via_relay,
**throughput_info,
)

@ -41,6 +41,8 @@ def get_server_throughput(
num_blocks: int,
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
reachable_via_relay: bool,
relay_penalty: float = 0.2,
force_eval: bool = False,
cache_dir: Optional[str] = None,
) -> Dict[str, float]:
@ -94,7 +96,10 @@ def get_server_throughput(
# E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
average_blocks_used = (num_blocks + 1) / 2
throughput = throughput_info["forward_rps"] / average_blocks_used
throughput = min(throughput, throughput_info.get("network_rps", math.inf))
network_rps = throughput_info["network_rps"] * (relay_penalty if reachable_via_relay else 1)
throughput = min(throughput, network_rps)
throughput_info["throughput"] = throughput
logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")

Loading…
Cancel
Save