From 62d9ed5ce7b08f621d7947679087fd54e8df723b Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Tue, 18 Jul 2023 08:46:36 +0400
Subject: [PATCH] Implement shortest-path routing for inference (#362)

This PR:

1. **Adds shortest-path routing for inference.** We build a graph with client-server and server-server latencies and compute costs, as well as empirically measured overheads. For client-server latencies, we ping the possible first and last servers of a sequence in `SequenceManager.update()`. We penalize servers that may not have enough cache for our request. This uses the info added to the DHT in #355, #356, #358. (A toy example of the resulting routing graph is sketched after the patch.)

2. **Makes a server ping neighboring servers in addition to the next ones.** This gives the client an opportunity to switch servers even before it has used all of the current server's blocks (e.g., because a neighboring server is faster). This feature is not enabled yet, since it increases the graph size for N servers to O(N^2), but we may enable it if needed.

3. **Fixes a `SequenceManager` bug with the first `update()`.** Previously, this update was likely to produce incorrect information and cause `MissingBlocksError`s until the next update happened.

---
setup.cfg | 1 +
src/petals/__init__.py | 2 +-
src/petals/cli/run_server.py | 2 +-
src/petals/client/inference_session.py | 4 +-
src/petals/client/routing/sequence_manager.py | 196 ++++++++++++++++--
src/petals/server/server.py | 31 +--
src/petals/utils/ping.py | 29 +--
src/petals/utils/random.py | 12 ++
8 files changed, 222 insertions(+), 55 deletions(-)
create mode 100644 src/petals/utils/random.py

diff --git a/setup.cfg b/setup.cfg
index 7fc930e..10f56b5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -48,6 +48,7 @@ install_requires =
sentencepiece>=0.1.99
peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
safetensors>=0.3.1
+ Dijkstar>=2.6.0
[options.extras_require]
dev =
diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index 3e67633..d02dbeb 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -11,7 +11,7 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
-__version__ = "1.2.0.dev2"
+__version__ = "1.2.0.dev3"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index f2a0168..ce69974 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -84,7 +84,7 @@ def main():
parser.add_argument('--attn_cache_tokens', type=int, default=8192,
help='The number of past attention key/value pairs that will be stored between inference steps. 
' 'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).') - parser.add_argument('--alloc_timeout', type=float, default=60, + parser.add_argument('--alloc_timeout', type=float, default=5, help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed ' 'before rejecting the request') parser.add_argument('--revision', type=str, default=None, diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 8c2dfc9..0e5d6b4 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -340,7 +340,9 @@ class InferenceSession: f"from block {block_idx} to {update_end} will be regenerated" ) - updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency") + updated_spans = self._sequence_manager.make_sequence( + block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length + ) # make_sequence() could return a longer sequence updated_spans[-1].end = min(updated_spans[-1].end, update_end) updated_sessions = self._enter_server_sessions(updated_spans) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 19b475b..5b1ab3f 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -10,6 +10,7 @@ import time from typing import Any, Collection, Dict, List, Optional, Sequence, Union from weakref import WeakMethod +import dijkstar import numpy as np from hivemind import DHT, P2P, MSGPackSerializer, PeerID from hivemind.dht.node import Blacklist @@ -23,6 +24,8 @@ from petals.client.routing.spending_policy import NoSpendingPolicy from petals.constants import PUBLIC_INITIAL_PEERS from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState from petals.server.handler import TransformerConnectionHandler +from petals.utils.ping import PingAggregator +from petals.utils.random import sample_up_to logger = get_logger(__name__) @@ -33,6 +36,7 @@ class SequenceManagerConfig: dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name) daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers + show_route: Union[str, bool] = "inference" # show chosen route through servers. 
one of [False, "inference", True] allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers use_server_to_server: bool = True # Use direct server-to-server communication @@ -43,7 +47,10 @@ class SequenceManagerConfig: min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1) max_backoff: float = 60 # limit maximal sleep time between retries to this value ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds - active_adapter: Optional[str] = None + active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo) + + max_pinged: int = 5 # max servers to ping from each sequence side, per update + ping_timeout: float = 2 # max time to wait for pings, per update @dataclasses.dataclass @@ -79,7 +86,6 @@ class RemoteSequenceManager: *, dht: Optional[DHT] = None, state: Optional[SequenceManagerState] = None, - active_adapter: Optional[str] = None, ): assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`" assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." @@ -94,7 +100,7 @@ class RemoteSequenceManager: dht = DHT( initial_peers=config.initial_peers, client_mode=True, - num_workers=config.num_hidden_layers, + num_workers=32, startup_timeout=config.daemon_startup_timeout, start=True, ) @@ -109,25 +115,25 @@ class RemoteSequenceManager: self._thread_start_lock = threading.Lock() self.policy = NoSpendingPolicy() + self.ping_aggregator = PingAggregator(dht) + if state.banned_peers is None: state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0) if state.sequence_info is None: state.sequence_info = RemoteSequenceInfo.make_empty(block_uids) - if state.sequence_info.last_updated_time is None: - # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records - # in the first _update() instead of the latest ones. This makes the first .update() faster. 
- petals.dht_utils.get_remote_module_infos(
- self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
- )
- self._need_latest_infos = False
- else:
+ if state.sequence_info.last_updated_time is not None:
assert block_uids == state.sequence_info.block_uids
self._thread.ready.set() # no need to await the first dht fetch
self._need_latest_infos = True
def make_sequence(
- self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str
+ self,
+ start_index: int = 0,
+ end_index: Optional[int] = None,
+ *,
+ mode: str,
+ cache_tokens_needed: Optional[int] = None,
) -> List[RemoteSpanInfo]:
"""
Form a sequence of remote servers that collectively serve all consecutive layers
@@ -143,6 +149,150 @@ class RemoteSequenceManager:
self.update(wait=True) # this will await an existing update or trigger a new one (if not updating)
end_index = end_index if end_index is not None else len(self)
+
+ if mode == "min_latency":
+ span_sequence = self._make_sequence_with_min_latency(
+ start_index, end_index, cache_tokens_needed=cache_tokens_needed
+ )
+ elif mode == "max_throughput":
+ span_sequence = self._make_sequence_with_max_throughput(start_index, end_index)
+ else:
+ raise RuntimeError(f"Unexpected mode {mode}")
+
+ if self.config.show_route is True or (mode == "min_latency" and self.config.show_route == "inference"):
+ route_repr = " => ".join(
+ [f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]
+ )
+ logger.info(f"Route found: {route_repr}")
+ return span_sequence
+
+ def _make_sequence_with_min_latency(
+ self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int]
+ ) -> List[RemoteSpanInfo]:
+ if start_index == end_index:
+ return []
+
+ with self.lock_changes:
+ missing_blocks = [
+ block_idx
+ for block_idx in range(start_index, end_index)
+ if not self.state.sequence_info.spans_containing_block[block_idx]
+ ]
+ if missing_blocks:
+ raise MissingBlocksError(missing_blocks)
+ server_infos = {
+ span.peer_id: span.server_info
+ for block_idx in range(start_index, end_index)
+ for span in self.state.sequence_info.spans_containing_block[block_idx]
+ }
+
+ graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed)
+
+ path = dijkstar.find_path(graph, "start", "end")
+ logger.debug(f"Path info: {path}")
+ if start_index == 0 and end_index == len(self):
+ logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec")
+
+ span_sequence = []
+ for peer_id, block_idx in path.nodes[1:-1]:
+ if not span_sequence or span_sequence[-1].peer_id != peer_id:
+ span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id]))
+ else:
+ span_sequence[-1].end = block_idx
+
+ # Remove empty spans that can appear if we don't force requests to go to the end of each server and network
+ # delays don't satisfy the triangle inequality (delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors
+ span_sequence = [span for span in span_sequence if span.length > 0]
+
+ return span_sequence
+
+ def _build_inference_graph(
+ self,
+ start_index: int,
+ end_index: int,
+ *,
+ cache_tokens_needed: Optional[int],
+ overhead_coeff: float = 1.82, # Backend overhead (empirically measured)
+ overhead_delay: float = 0.018, # Serialization overhead (empirically measured)
+ default_inference_rps: float = 300, # If inference RPS unknown
+ alloc_delay: float = 10, # If not enough cache left, we penalize the edge
+ ) -> dijkstar.Graph:
+ missing_blocks = [
+ block_idx
+ for block_idx in range(start_index, end_index)
+ if not self.state.sequence_info.spans_containing_block[block_idx]
+ ]
+ if missing_blocks:
+ raise MissingBlocksError(missing_blocks)
+
+ client_server_rtts = self.ping_aggregator.to_dict()
+
+ graph = dijkstar.Graph()
+
+ # Client -> server network delays
+ for span in self.state.sequence_info.spans_containing_block[start_index]:
+ delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
+ delay += overhead_delay
+ if not self._has_cache_for(span, cache_tokens_needed):
+ delay += alloc_delay
+ graph.add_edge("start", (span.peer_id, start_index), delay)
+
+ # Server -> client network delays
+ for span in self.state.sequence_info.spans_containing_block[end_index - 1]:
+ delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
+ graph.add_edge((span.peer_id, end_index), "end", delay)
+
+ # Server -> server network delays
+ for block_idx in range(start_index + 1, end_index):
+ for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]:
+ if cur_span.end != block_idx:
+ # If we choose a server, we force the route to go to the end of its span before switching to a new one
+ # to avoid O(N^2) graphs for N servers
+ continue
+
+ for next_span in self.state.sequence_info.spans_containing_block[block_idx]:
+ rtt = None
+ if cur_span.server_info.next_pings is not None:
+ rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58())
+ delay = self._rtt_to_delay(rtt)
+ delay += overhead_delay
+ if not self._has_cache_for(next_span, cache_tokens_needed):
+ delay += alloc_delay
+ graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay)
+
+ # Computation delays
+ for span in self.state.sequence_info.spans_by_priority:
+ for block_idx in range(max(span.start, start_index), min(span.end, end_index)):
+ inference_rps = span.server_info.inference_rps
+ if inference_rps is None:
+ inference_rps = default_inference_rps
+ graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), overhead_coeff / inference_rps)
+
+ return graph
+
+ @staticmethod
+ def _rtt_to_delay(
+ rtt: Optional[float],
+ *,
+ default_delay: float = 0.15, # If network delay unknown
+ max_delay: float = 5, # If unreachable, we don't want to discard the edge completely
+ ) -> float:
+ if rtt is None:
+ return default_delay
+ return min(rtt / 2, max_delay)
+
+ @staticmethod
+ def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool:
+ if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:
+ return True
+
+ # Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through
+ # this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage,
+ # so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate.
+ # This is okay since false positives are more costly than false negatives here. 
+ return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left + + def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]: span_sequence = [] current_index = start_index while current_index < end_index: @@ -150,20 +300,12 @@ class RemoteSequenceManager: if not candidate_spans: raise MissingBlocksError(current_index) - if mode == "max_throughput": - span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) - elif mode == "min_latency": - span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64) - else: - raise RuntimeError(f"Unexpected mode {mode}") + span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64) chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) assert chosen_span.start <= current_index < chosen_span.end span_sequence.append(dataclasses.replace(chosen_span, start=current_index)) current_index = chosen_span.end - - route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]) - logger.debug(f"Route found: {route_repr}") return span_sequence def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager: @@ -182,10 +324,10 @@ class RemoteSequenceManager: def _update(self): """Perform an immediate and synchronous refresh, may take time""" + new_block_infos = petals.dht_utils.get_remote_module_infos( - self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos + self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True ) - self._need_latest_infos = True # All future _update() should use latest infos for block_info in new_block_infos: if not block_info: @@ -217,6 +359,14 @@ class RemoteSequenceManager: with self.lock_changes: self.state.sequence_info.update_(new_block_infos) + + first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]] + last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]] + + pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged)) + pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged)) + self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout) + self.ready.set() def on_request_failure(self, peer_id: Optional[PeerID]): diff --git a/src/petals/server/server.py b/src/petals/server/server.py index aea57c7..0793dff 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -32,6 +32,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, check_device_balance, convert_block from petals.utils.ping import PingAggregator +from petals.utils.random import sample_up_to from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) @@ -61,7 +62,7 @@ class Server: cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, attn_cache_tokens: int = 8192, - alloc_timeout: float = 60, + alloc_timeout: float = 5, device: Optional[Union[str, torch.device]] = None, compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, @@ -637,7 +638,6 @@ class ModuleAnnouncerThread(threading.Thread): update_period: float, expiration: float, max_pinged: int = 5, - max_reported: int = 10, **kwargs, ): 
super().__init__(**kwargs) @@ -650,10 +650,11 @@ class ModuleAnnouncerThread(threading.Thread): self.expiration = expiration self.trigger = threading.Event() - self.max_pinged, self.max_reported = max_pinged, max_reported - last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1])) - dht_prefix, block_index = last_uid.split(UID_DELIMITER) - self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}" + self.max_pinged = max_pinged + dht_prefix = module_uids[0].split(UID_DELIMITER)[0] + block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids] + start_block, end_block = min(block_indices), max(block_indices) + 1 + self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)] self.ping_aggregator = PingAggregator(self.dht) def run(self) -> None: @@ -664,7 +665,7 @@ class ModuleAnnouncerThread(threading.Thread): if self.server_info.state != ServerState.OFFLINE: self._ping_next_servers() self.server_info.next_pings = { - peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items() + peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items() } else: self.server_info.next_pings = None # No need to ping if we're disconnecting @@ -691,14 +692,14 @@ class ModuleAnnouncerThread(threading.Thread): self.join() def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]: - [module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True) - if module_info is None: - return - - next_servers = list(module_info.servers) - if len(next_servers) > self.max_pinged: - next_servers = random.sample(next_servers, self.max_pinged) - self.ping_aggregator.ping(next_servers) + module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True) + middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers} + pinged_servers = set(sample_up_to(middle_servers, self.max_pinged)) + pinged_servers.discard(self.dht.peer_id) + if module_infos[-1] is not None: + # Sample servers hosting the block after the last one (most likely continuations) separately + pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged)) + self.ping_aggregator.ping(list(pinged_servers)) class RuntimeWithDeduplicatedPools(Runtime): diff --git a/src/petals/utils/ping.py b/src/petals/utils/ping.py index d5fd129..4245bf4 100644 --- a/src/petals/utils/ping.py +++ b/src/petals/utils/ping.py @@ -1,5 +1,6 @@ import asyncio import math +import threading import time from functools import partial from typing import Dict, Sequence @@ -34,27 +35,27 @@ async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> class PingAggregator: - def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 3600): + def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300): self.dht = dht self.ema_alpha = ema_alpha self.expiration = expiration self.ping_emas = hivemind.TimedStorage() + self.lock = threading.Lock() - def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs): + def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None: current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs)) logger.debug(f"Current RTTs: {current_rtts}") - expiration = hivemind.get_dht_time() + self.expiration - for peer_id, rtt in current_rtts.items(): - prev_rtt = self.ping_emas.get(peer_id) - if prev_rtt is not None and prev_rtt.value != math.inf: 
- rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing
- self.ping_emas.store(peer_id, rtt, expiration)
+ with self.lock:
+ expiration = hivemind.get_dht_time() + self.expiration
+ for peer_id, rtt in current_rtts.items():
+ prev_rtt = self.ping_emas.get(peer_id)
+ if prev_rtt is not None and prev_rtt.value != math.inf:
+ rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing
+ self.ping_emas.store(peer_id, rtt, expiration)
- def fastest(self, n_peers: int) -> Dict[hivemind.PeerID, float]:
- with self.ping_emas.freeze():
+ def to_dict(self) -> Dict[hivemind.PeerID, float]:
+ with self.lock, self.ping_emas.freeze():
smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
- logger.debug(f"Smothed RTTs: {smoothed_rtts}")
-
- fastest_rtts = sorted(smoothed_rtts.items(), key=lambda item: item[1])[:n_peers]
- return dict(fastest_rtts)
+ logger.debug(f"Smoothed RTTs: {smoothed_rtts}")
+ return smoothed_rtts
diff --git a/src/petals/utils/random.py b/src/petals/utils/random.py
new file mode 100644
index 0000000..15635ff
--- /dev/null
+++ b/src/petals/utils/random.py
@@ -0,0 +1,12 @@
+import random
+from typing import Collection, List, TypeVar
+
+T = TypeVar("T")
+
+
+def sample_up_to(population: Collection[T], k: int) -> List[T]:
+ if not isinstance(population, list):
+ population = list(population)
+ if len(population) > k:
+ population = random.sample(population, k)
+ return population
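To make the routing change concrete, here is a minimal, illustrative sketch (not part of the patch) of the kind of graph that `_build_inference_graph()` constructs and `dijkstar.find_path()` solves. All servers, spans, RTTs, and throughputs below are made up, and the `alloc_delay` cache penalty is omitted for brevity; only the node and edge scheme ("start" -> (peer_id, block) -> "end", with network-delay and computation-delay edges) mirrors the code above.

```python
# Toy routing graph: three hypothetical servers A (blocks 0:2), B (blocks 0:4), C (blocks 2:4)
# serving a 4-block model. All numbers are invented for illustration.
import dijkstar

overhead_coeff, overhead_delay = 1.82, 0.018
graph = dijkstar.Graph()

# Client -> first servers: half of a made-up client-server RTT plus serialization overhead
graph.add_edge("start", ("A", 0), 0.040 / 2 + overhead_delay)
graph.add_edge("start", ("B", 0), 0.120 / 2 + overhead_delay)

# Per-block computation edges: (peer, i) -> (peer, i + 1), cost = overhead_coeff / inference_rps
for peer, blocks, rps in [("A", range(0, 2), 400), ("B", range(0, 4), 250), ("C", range(2, 4), 350)]:
    for i in blocks:
        graph.add_edge((peer, i), (peer, i + 1), overhead_coeff / rps)

# Server -> server hand-offs where A's span ends (block 2), using made-up server-reported next_pings
graph.add_edge(("A", 2), ("B", 2), 0.030 / 2 + overhead_delay)
graph.add_edge(("A", 2), ("C", 2), 0.020 / 2 + overhead_delay)

# Last servers -> client: half of the made-up RTT back to the client
graph.add_edge(("B", 4), "end", 0.120 / 2)
graph.add_edge(("C", 4), "end", 0.050 / 2)

path = dijkstar.find_path(graph, "start", "end")
print(path.nodes)  # e.g. ['start', ('A', 0), ('A', 1), ('A', 2), ('C', 2), ('C', 3), ('C', 4), 'end']
print(f"{1 / path.total_cost:.1f} steps/sec")  # rough estimate, as logged by make_sequence()
```

The resulting node path is then folded into spans exactly as `_make_sequence_with_min_latency()` does: consecutive nodes with the same peer become one `RemoteSpanInfo`, and zero-length spans are dropped.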
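The cache check in `_has_cache_for()` is intentionally pessimistic: it assumes the route will use every block the server hosts. A worked example with invented numbers (the `* 2` factor is taken as-is from the check above):

```python
# Hypothetical server hosting a 20-block span with 100_000 cache tokens left,
# asked about a session that may need up to 2048 tokens of attention cache.
cache_tokens_needed, span_length, cache_tokens_left = 2048, 20, 100_000

required = cache_tokens_needed * 2 * span_length  # 2048 * 2 * 20 = 81_920 tokens (worst case: all hosted blocks used)
has_cache = required <= cache_tokens_left         # True here, so the edge gets no alloc_delay penalty
print(has_cache)
```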
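`PingAggregator` now exposes all smoothed RTTs via `to_dict()` (instead of only the `n_peers` fastest), and the smoothing itself is an exponential moving average with `ema_alpha = 0.2`. A small, self-contained illustration of that arithmetic with made-up RTT samples:

```python
# Exponential smoothing as in PingAggregator.ping(), with invented RTT samples (seconds)
import math

ema_alpha = 0.2
rtt_ema = None
for measured in [0.080, 0.120, 0.060]:
    if rtt_ema is None or rtt_ema == math.inf:
        rtt_ema = measured  # the first sample (or a sample after a failed ping) is stored as-is
    else:
        rtt_ema = ema_alpha * measured + (1 - ema_alpha) * rtt_ema
    print(round(rtt_ema, 4))  # 0.08 -> 0.088 -> 0.0824
```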