Use length-weighted sampling in routing for inference (#204)

This pull request implements a simple routing optimization — (1) greedy and (2) latency-agnostic — that should speed up both of our use cases.

Why this exists: our effort to merge full routing (ping-aware, throughput-aware, Dijkstra-based) is in a sorry state, split between several branches; merging it into main would take many days.

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
pull/205/head
justheuristic 1 year ago committed by GitHub
parent 42d1bbb568
commit 012f840f7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -255,7 +255,7 @@ class InferenceSession:
)
recovery_until = max(recovery_until, update_end)
updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="fastest")
# make_sequence() could return a longer sequence
updated_spans[-1].end = min(updated_spans[-1].end, update_end)
updated_sessions = self._enter_server_sessions(updated_spans)

@ -9,6 +9,7 @@ import time
from typing import Any, Dict, List, Optional, Sequence, Union
from weakref import WeakMethod
import numpy as np
from hivemind import DHT, P2P, MSGPackSerializer, PeerID
from hivemind.dht.node import Blacklist
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
@ -92,12 +93,15 @@ class RemoteSequenceManager:
if await_ready:
self._thread.ready.wait(timeout)
def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
def make_sequence(
self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random"
) -> List[RemoteSpanInfo]:
"""
Form a sequence of remote servers that collectively serve all consecutive layers
:param start_index: optional index of the first module in a sequence, default = the first of block_uids
:param end_index: optional index of the last module (non-inclusive), default = after last of block uids
:param mode: either random or fastest
"""
if not self.is_alive():
logger.error("Using a sequence manager that is not running: it has either crashed or never started")
@ -110,7 +114,14 @@ class RemoteSequenceManager:
current_index = start_index
while current_index < end_index:
candidate_spans = self.sequence_info.spans_containing_block[current_index]
chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
if mode == "random":
chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
elif mode == "fastest":
# note: this too is a heuristic that will be replaced once we integrate fastest wall time routing
span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
else:
raise RuntimeError(f"Unexpected mode {mode}")
assert chosen_span.start <= current_index < chosen_span.end
span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))

@ -60,7 +60,7 @@ async def sequential_forward(
span = None
try:
if not sequences or attempt_no >= 1:
sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="random"))
# make_sequence() could return a longer sequence
sequences[-1].end = min(sequences[-1].end, end_index)
logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")

@ -14,7 +14,8 @@ logger = get_logger(__file__)
@pytest.mark.forked
def test_sequence_manager_shutdown():
@pytest.mark.parametrize("mode", ["fastest", "random"])
def test_sequence_manager_basics(mode: str):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
sequential = RemoteSequential(config, dht)
@ -28,7 +29,7 @@ def test_sequence_manager_shutdown():
sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True),
)
sequence = sequential.sequence_manager.make_sequence()
sequence = sequential.sequence_manager.make_sequence(mode=mode)
assert all(sequence[i].peer_id != sequence[i + 1].peer_id for i in range(len(sequence) - 1))
assert sequential.sequence_manager.is_alive()

Loading…
Cancel
Save