Implement shortest-path routing for inference (#362)

This PR:

1. **Adds shortest path routing for inference.** We build a graph with client-server and server-server latencies and compute costs, as well as empirically measured overheads. For client-server latencies, we ping possible first and last servers in a sequence in `SequenceManager.update()`. We penalize servers who may not have enough cache for our request. This uses info added to DHT in #355, #356, #358.

2. **Makes a server ping neighboring servers in addition to next ones.** This is to get an opportunity to change the server even before we use all its blocks (e.g., because a neighboring server is faster). This feature is not enabled though, since it increases graph size for N servers to O(N^2) - but we may enable it if needed.

3. **Fixes a `SequenceManager` bug with the first `update()`.** Previously, this update was likely to produce incorrect information and cause to `MissingBlocksErrors` until the next update happens.
pull/365/head
Alexander Borzunov 10 months ago committed by GitHub
parent fd30f7ce10
commit 62d9ed5ce7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -48,6 +48,7 @@ install_requires =
sentencepiece>=0.1.99
peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
safetensors>=0.3.1
Dijkstar>=2.6.0
[options.extras_require]
dev =

@ -11,7 +11,7 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "1.2.0.dev2"
__version__ = "1.2.0.dev3"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

@ -84,7 +84,7 @@ def main():
parser.add_argument('--attn_cache_tokens', type=int, default=8192,
help='The number of past attention key/value pairs that will be stored between inference steps. '
'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
parser.add_argument('--alloc_timeout', type=float, default=60,
parser.add_argument('--alloc_timeout', type=float, default=5,
help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
'before rejecting the request')
parser.add_argument('--revision', type=str, default=None,

@ -340,7 +340,9 @@ class InferenceSession:
f"from block {block_idx} to {update_end} will be regenerated"
)
updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
updated_spans = self._sequence_manager.make_sequence(
block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length
)
# make_sequence() could return a longer sequence
updated_spans[-1].end = min(updated_spans[-1].end, update_end)
updated_sessions = self._enter_server_sessions(updated_spans)

@ -10,6 +10,7 @@ import time
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
from weakref import WeakMethod
import dijkstar
import numpy as np
from hivemind import DHT, P2P, MSGPackSerializer, PeerID
from hivemind.dht.node import Blacklist
@ -23,6 +24,8 @@ from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
logger = get_logger(__name__)
@ -33,6 +36,7 @@ class SequenceManagerConfig:
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
use_server_to_server: bool = True # Use direct server-to-server communication
@ -43,7 +47,10 @@ class SequenceManagerConfig:
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 5 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update
@dataclasses.dataclass
@ -79,7 +86,6 @@ class RemoteSequenceManager:
*,
dht: Optional[DHT] = None,
state: Optional[SequenceManagerState] = None,
active_adapter: Optional[str] = None,
):
assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@ -94,7 +100,7 @@ class RemoteSequenceManager:
dht = DHT(
initial_peers=config.initial_peers,
client_mode=True,
num_workers=config.num_hidden_layers,
num_workers=32,
startup_timeout=config.daemon_startup_timeout,
start=True,
)
@ -109,25 +115,25 @@ class RemoteSequenceManager:
self._thread_start_lock = threading.Lock()
self.policy = NoSpendingPolicy()
self.ping_aggregator = PingAggregator(dht)
if state.banned_peers is None:
state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
if state.sequence_info is None:
state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
if state.sequence_info.last_updated_time is None:
# Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
# in the first _update() instead of the latest ones. This makes the first .update() faster.
petals.dht_utils.get_remote_module_infos(
self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
)
self._need_latest_infos = False
else:
if state.sequence_info.last_updated_time is not None:
assert block_uids == state.sequence_info.block_uids
self._thread.ready.set() # no need to await the first dht fetch
self._need_latest_infos = True
def make_sequence(
self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str
self,
start_index: int = 0,
end_index: Optional[int] = None,
*,
mode: str,
cache_tokens_needed: Optional[int] = None,
) -> List[RemoteSpanInfo]:
"""
Form a sequence of remote servers that collectively serve all consecutive layers
@ -143,6 +149,150 @@ class RemoteSequenceManager:
self.update(wait=True) # this will await an existing update or trigger a new one (if not updating)
end_index = end_index if end_index is not None else len(self)
if mode == "min_latency":
span_sequence = self._make_sequence_with_min_latency(
start_index, end_index, cache_tokens_needed=cache_tokens_needed
)
elif mode == "max_throughput":
span_sequence = self._make_sequence_with_max_throughput(start_index, end_index)
else:
raise RuntimeError(f"Unexpected mode {mode}")
if self.config.show_route is True or (mode == "min_latency" and self.config.show_route == "inference"):
route_repr = " => ".join(
[f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]
)
logger.info(f"Route found: {route_repr}")
return span_sequence
def _make_sequence_with_min_latency(
self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int]
) -> List[RemoteSpanInfo]:
if start_index == end_index:
return []
with self.lock_changes:
missing_blocks = [
block_idx
for block_idx in range(start_index, end_index)
if not self.state.sequence_info.spans_containing_block[block_idx]
]
if missing_blocks:
raise MissingBlocksError(missing_blocks)
server_infos = {
span.peer_id: span.server_info
for block_idx in range(start_index, end_index)
for span in self.state.sequence_info.spans_containing_block[block_idx]
}
graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed)
path = dijkstar.find_path(graph, "start", "end")
logger.debug(f"Path info: {path}")
if start_index == 0 and end_index == len(self):
logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec")
span_sequence = []
for peer_id, block_idx in path.nodes[1:-1]:
if not span_sequence or span_sequence[-1].peer_id != peer_id:
span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id]))
else:
span_sequence[-1].end = block_idx
# Remove empty spans that can appear if we don't force to go to the end of each server and network delay
# don't follow triangle inequality (delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors
span_sequence = [span for span in span_sequence if span.length > 0]
return span_sequence
def _build_inference_graph(
self,
start_index: int,
end_index: int,
*,
cache_tokens_needed: Optional[int],
overhead_coeff: float = 1.82, # Backend overhead (empirically measured)
overhead_delay: float = 0.018, # Serialization overhead (empirically measured)
default_inference_rps: float = 300, # If inference RPS unknown
alloc_delay: float = 10, # If not enough cache left, we penalize the edge
) -> dijkstar.Graph:
missing_blocks = [
block_idx
for block_idx in range(start_index, end_index)
if not self.state.sequence_info.spans_containing_block[block_idx]
]
if missing_blocks:
raise MissingBlocksError(missing_blocks)
client_server_rtts = self.ping_aggregator.to_dict()
graph = dijkstar.Graph()
# Clent -> server network delays
for span in self.state.sequence_info.spans_containing_block[start_index]:
delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
delay += overhead_delay
if not self._has_cache_for(span, cache_tokens_needed):
delay += alloc_delay
graph.add_edge("start", (span.peer_id, start_index), delay)
# Server -> client network delays
for span in self.state.sequence_info.spans_containing_block[end_index - 1]:
delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
graph.add_edge((span.peer_id, end_index), "end", delay)
# Server -> server network delays
for block_idx in range(start_index + 1, end_index):
for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]:
if cur_span.end != block_idx:
# If we choose a server, we force to go to the end of it before switching to a new one
# to avoid O(N^2) graphs for N servers
continue
for next_span in self.state.sequence_info.spans_containing_block[block_idx]:
rtt = None
if cur_span.server_info.next_pings is not None:
rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58())
delay = self._rtt_to_delay(rtt)
delay += overhead_delay
if not self._has_cache_for(next_span, cache_tokens_needed):
delay += alloc_delay
graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay)
# Compute delays
for span in self.state.sequence_info.spans_by_priority:
for block_idx in range(max(span.start, start_index), min(span.end, end_index)):
inference_rps = span.server_info.inference_rps
if inference_rps is None:
inference_rps = default_inference_rps
graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), overhead_coeff / inference_rps)
return graph
@staticmethod
def _rtt_to_delay(
rtt: float,
*,
default_delay: float = 0.15, # If network delay unknown
max_delay: float = 5, # If unreachable, we don't want to discard the edge completely
) -> float:
if rtt is None:
return default_delay
return min(rtt / 2, max_delay)
@staticmethod
def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool:
if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:
return True
# Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through
# this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage,
# so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate.
# This is okay since false positives are more costly than false negatives here.
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
span_sequence = []
current_index = start_index
while current_index < end_index:
@ -150,20 +300,12 @@ class RemoteSequenceManager:
if not candidate_spans:
raise MissingBlocksError(current_index)
if mode == "max_throughput":
span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
elif mode == "min_latency":
span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
else:
raise RuntimeError(f"Unexpected mode {mode}")
span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
assert chosen_span.start <= current_index < chosen_span.end
span_sequence.append(dataclasses.replace(chosen_span, start=current_index))
current_index = chosen_span.end
route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence])
logger.debug(f"Route found: {route_repr}")
return span_sequence
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
@ -182,10 +324,10 @@ class RemoteSequenceManager:
def _update(self):
"""Perform an immediate and synchronous refresh, may take time"""
new_block_infos = petals.dht_utils.get_remote_module_infos(
self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos
self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
)
self._need_latest_infos = True # All future _update() should use latest infos
for block_info in new_block_infos:
if not block_info:
@ -217,6 +359,14 @@ class RemoteSequenceManager:
with self.lock_changes:
self.state.sequence_info.update_(new_block_infos)
first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
self.ready.set()
def on_request_failure(self, peer_id: Optional[PeerID]):

@ -32,6 +32,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
from petals.utils.version import get_compatible_model_repo
logger = get_logger(__name__)
@ -61,7 +62,7 @@ class Server:
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
attn_cache_tokens: int = 8192,
alloc_timeout: float = 60,
alloc_timeout: float = 5,
device: Optional[Union[str, torch.device]] = None,
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
@ -637,7 +638,6 @@ class ModuleAnnouncerThread(threading.Thread):
update_period: float,
expiration: float,
max_pinged: int = 5,
max_reported: int = 10,
**kwargs,
):
super().__init__(**kwargs)
@ -650,10 +650,11 @@ class ModuleAnnouncerThread(threading.Thread):
self.expiration = expiration
self.trigger = threading.Event()
self.max_pinged, self.max_reported = max_pinged, max_reported
last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1]))
dht_prefix, block_index = last_uid.split(UID_DELIMITER)
self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}"
self.max_pinged = max_pinged
dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
start_block, end_block = min(block_indices), max(block_indices) + 1
self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
self.ping_aggregator = PingAggregator(self.dht)
def run(self) -> None:
@ -664,7 +665,7 @@ class ModuleAnnouncerThread(threading.Thread):
if self.server_info.state != ServerState.OFFLINE:
self._ping_next_servers()
self.server_info.next_pings = {
peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items()
peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items()
}
else:
self.server_info.next_pings = None # No need to ping if we're disconnecting
@ -691,14 +692,14 @@ class ModuleAnnouncerThread(threading.Thread):
self.join()
def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
[module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True)
if module_info is None:
return
next_servers = list(module_info.servers)
if len(next_servers) > self.max_pinged:
next_servers = random.sample(next_servers, self.max_pinged)
self.ping_aggregator.ping(next_servers)
module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
pinged_servers.discard(self.dht.peer_id)
if module_infos[-1] is not None:
# Sample servers hosting the block after the last one (most likely continuations) separately
pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
self.ping_aggregator.ping(list(pinged_servers))
class RuntimeWithDeduplicatedPools(Runtime):

@ -1,5 +1,6 @@
import asyncio
import math
import threading
import time
from functools import partial
from typing import Dict, Sequence
@ -34,27 +35,27 @@ async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) ->
class PingAggregator:
def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 3600):
def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300):
self.dht = dht
self.ema_alpha = ema_alpha
self.expiration = expiration
self.ping_emas = hivemind.TimedStorage()
self.lock = threading.Lock()
def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs):
def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None:
current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
logger.debug(f"Current RTTs: {current_rtts}")
expiration = hivemind.get_dht_time() + self.expiration
for peer_id, rtt in current_rtts.items():
prev_rtt = self.ping_emas.get(peer_id)
if prev_rtt is not None and prev_rtt.value != math.inf:
rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing
self.ping_emas.store(peer_id, rtt, expiration)
with self.lock:
expiration = hivemind.get_dht_time() + self.expiration
for peer_id, rtt in current_rtts.items():
prev_rtt = self.ping_emas.get(peer_id)
if prev_rtt is not None and prev_rtt.value != math.inf:
rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing
self.ping_emas.store(peer_id, rtt, expiration)
def fastest(self, n_peers: int) -> Dict[hivemind.PeerID, float]:
with self.ping_emas.freeze():
def to_dict(self) -> Dict[hivemind.PeerID, float]:
with self.lock, self.ping_emas.freeze():
smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
logger.debug(f"Smothed RTTs: {smoothed_rtts}")
fastest_rtts = sorted(smoothed_rtts.items(), key=lambda item: item[1])[:n_peers]
return dict(fastest_rtts)
logger.debug(f"Smothed RTTs: {smoothed_rtts}")
return smoothed_rtts

@ -0,0 +1,12 @@
import random
from typing import Collection, TypeVar
T = TypeVar("T")
def sample_up_to(population: Collection[T], k: int) -> T:
if not isinstance(population, list):
population = list(population)
if len(population) > k:
population = random.sample(population, k)
return population
Loading…
Cancel
Save