diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py index bce6712..2c9137b 100644 --- a/src/petals/client/routing/sequence_info.py +++ b/src/petals/client/routing/sequence_info.py @@ -1,17 +1,15 @@ import dataclasses import time -from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar +from typing import Iterable, List, Optional, Tuple from hivemind import get_logger from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState +from petals.utils.dht import compute_spans logger = get_logger(__name__) -T = TypeVar("T") - - @dataclasses.dataclass class RemoteSequenceInfo: """ @@ -30,7 +28,7 @@ class RemoteSequenceInfo: last_updated_time: Optional[float] @classmethod - def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T: + def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo": block_uids = tuple(block_uids) empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids) empty_spans = tuple([] for _ in range(len(block_uids))) @@ -39,7 +37,7 @@ class RemoteSequenceInfo: def __getitem__(self, ix: slice): assert isinstance(ix, slice) block_uids, block_infos = self.block_uids[ix], self.block_infos[ix] - spans_by_priority, spans_containing_block = self.compute_spans(block_infos) + spans_by_priority, spans_containing_block = self._sort_spans(block_infos) return RemoteSequenceInfo( block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time ) @@ -47,60 +45,23 @@ class RemoteSequenceInfo: def __len__(self): return len(self.block_uids) - def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]): + def update_(self, new_block_infos: List[RemoteModuleInfo]): assert len(new_block_infos) == len(self.block_uids) for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)): - if info is None: - logger.debug(f"Found no block info for block {uid}") - continue - if not isinstance(info, RemoteModuleInfo): - logger.warning(f"Unexpected dht entry type for {uid}: {info}") - continue - if not info.servers: - logger.debug(f"Found no active peers for block {uid}") - continue - if info.uid != uid: - logger.warning(f"The DHT entry for {uid} actually points to {info.uid}") - continue + assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}" self.block_infos[block_index].servers = info.servers - self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos) + self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos) self.last_updated_time = time.perf_counter() @staticmethod - def compute_spans(block_infos: Sequence[RemoteModuleInfo]): - closed_spans = [] - active_spans = {} - for block_index, info in enumerate(block_infos): - if info is not None: - for peer_id, server_info in info.servers.items(): - if server_info.state != ServerState.ONLINE: - continue - if peer_id not in active_spans: - active_spans[peer_id] = RemoteSpanInfo( - peer_id=peer_id, - start=block_index, - end=block_index + 1, - server_info=server_info, - ) - else: # peer_id in active_spans - active_spans[peer_id].end = block_index + 1 - - for peer_id in list(active_spans.keys()): - if ( - info is None - or peer_id not in info.servers - or info.servers[peer_id].state != ServerState.ONLINE - or block_index == len(block_infos) - 1 - ): - closed_spans.append(active_spans.pop(peer_id)) - assert not active_spans, f"spans: {active_spans}" - - closed_spans.sort(key=lambda span: 
span.length, reverse=True) + def _sort_spans(block_infos: List[RemoteModuleInfo]): + spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values()) + spans_by_priority.sort(key=lambda span: span.length, reverse=True) - spans_containing_block = tuple(list() for _ in range(len(block_infos))) - for span in closed_spans: + spans_containing_block = tuple([] for _ in range(len(block_infos))) + for span in spans_by_priority: for block_index in range(span.start, span.end): spans_containing_block[block_index].append(span) - return closed_spans, spans_containing_block + return spans_by_priority, spans_containing_block diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 3e239b4..ed5224c 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -117,7 +117,6 @@ class RemoteSequenceManager: if state.sequence_info.last_updated_time is not None: assert block_uids == state.sequence_info.block_uids self._thread.ready.set() # no need to await the first dht fetch - self._need_latest_infos = True @staticmethod def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]: @@ -346,9 +345,6 @@ class RemoteSequenceManager: ) for block_info in new_block_infos: - if not block_info: - continue - # Apply allow and block lists block_info.servers = { peer_id: server_info diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 7c86f14..9cbbf76 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -11,18 +11,15 @@ UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. 
"bloom.layer3 bloom.layer4" -class ServerState(Enum): - OFFLINE = 0 - JOINING = 1 - ONLINE = 2 - - -RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) +def parse_uid(uid: ModuleUID) -> Tuple[str, int]: + assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs" + dht_prefix, index = uid.split(UID_DELIMITER) + return dht_prefix, int(index) @pydantic.dataclasses.dataclass class ModelInfo: - num_blocks: int + num_blocks: pydantic.conint(ge=1, strict=True) repository: Optional[str] = None def to_dict(self) -> dict: @@ -33,11 +30,23 @@ class ModelInfo: return cls(**source) +class ServerState(Enum): + OFFLINE = 0 + JOINING = 1 + ONLINE = 2 + + +RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) + + @pydantic.dataclasses.dataclass class ServerInfo: state: ServerState throughput: RPS + start_block: Optional[pydantic.conint(ge=0, strict=True)] = None + end_block: Optional[pydantic.conint(ge=0, strict=True)] = None + public_name: Optional[str] = None version: Optional[str] = None @@ -83,9 +92,17 @@ class RemoteSpanInfo: server_info: ServerInfo @property - def length(self): + def length(self) -> int: return self.end - self.start + @property + def state(self) -> ServerState: + return self.server_info.state + + @property + def throughput(self) -> float: + return self.server_info.throughput + RPCInfo = Dict[str, Any] diff --git a/src/petals/server/block_selection.py b/src/petals/server/block_selection.py index cc050d4..441c0cd 100644 --- a/src/petals/server/block_selection.py +++ b/src/petals/server/block_selection.py @@ -1,54 +1,23 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List import numpy as np from hivemind import PeerID, get_logger -from petals.data_structures import RemoteModuleInfo, ServerState - -__all__ = ["choose_best_blocks", "should_choose_other_blocks"] +from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState +from petals.utils.dht import compute_spans logger = get_logger(__name__) -@dataclass -class Span: - start: int - end: int - throughput: float - state: ServerState - - @property - def length(self): - return self.end - self.start - - def move_to(self, new_start: int) -> None: - self.start, self.end = new_start, new_start + self.length - - -def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]: - spans = {} - throughputs = np.zeros(len(module_infos)) - for block, module in enumerate(module_infos): - if module is None: - continue - - # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers. - # If the order were not defined, we would get slightly different values due to floating point errors, - # which may cause excess block replacements. - for peer_id, server in sorted(module.servers.items()): - if server.state == ServerState.OFFLINE: - continue +def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray: + # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers. + # If the order were not defined, we would get slightly different values due to floating point errors, + # which may cause excess block replacements. 
- if peer_id in spans: - spans[peer_id].start = min(spans[peer_id].start, block) - spans[peer_id].end = max(spans[peer_id].start, block + 1) - else: - spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state) - - throughputs[block] += server.throughput - - return spans, throughputs + throughputs = np.zeros(total_blocks) + for span in sorted(spans.values(), key=lambda span: span.peer_id): + throughputs[span.start : span.end] += span.throughput + return throughputs def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int: @@ -56,19 +25,26 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int: return min(options)[-1] -def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]: - _, throughputs = compute_spans(module_infos) +def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]: + spans = compute_spans(module_infos, min_state=ServerState.JOINING) + throughputs = compute_throughputs(spans, total_blocks=len(module_infos)) + start = _choose_best_start(throughputs, num_blocks) return list(range(start, start + num_blocks)) +def _move_span(span: RemoteSpanInfo, new_start: int): + span.start, span.end = new_start, new_start + span.length + + def should_choose_other_blocks( - local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float + local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float ) -> bool: if balance_quality > 1.0: return True # Forces rebalancing on each check (may be used for debugging purposes) - spans, throughputs = compute_spans(module_infos) + spans = compute_spans(module_infos, min_state=ServerState.JOINING) + throughputs = compute_throughputs(spans, total_blocks=len(module_infos)) initial_throughput = throughputs.min() eps = 1e-3 @@ -88,7 +64,7 @@ def should_choose_other_blocks( return False # This server is on its best place already throughputs[local_span.start : local_span.end] += local_span.throughput * eps - local_span.move_to(new_start) + _move_span(local_span, new_start) throughputs[local_span.start : local_span.end] += local_span.throughput moved = True @@ -105,7 +81,7 @@ def should_choose_other_blocks( throughputs[span.start : span.end] += span.throughput * eps if span.start != new_start: - span.move_to(new_start) + _move_span(span, new_start) moved = True throughputs[span.start : span.end] += span.throughput diff --git a/src/petals/server/server.py b/src/petals/server/server.py index fd9f766..82388aa 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -23,7 +23,7 @@ from transformers import PretrainedConfig import petals from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid from petals.server import block_selection from petals.server.backend import TransformerBackend, merge_inference_pools_inplace from petals.server.block_utils import get_block_size, resolve_block_dtype @@ -220,11 +220,10 @@ class Server: num_blocks = min(num_blocks, self.block_config.num_hidden_layers) if block_indices is not None: try: - first_block_index, last_block_index = block_indices.split(":") - first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index))) + start_block, end_block = 
[int(index.strip()) for index in block_indices.split(":")] except Exception as e: raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)") - block_indices = range(first_block_index, last_block_index) + block_indices = range(start_block, end_block) num_blocks = len(block_indices) self.strict_block_indices, self.num_blocks = block_indices, num_blocks @@ -703,11 +702,16 @@ class ModuleAnnouncerThread(threading.Thread): self.expiration = expiration self.trigger = threading.Event() + self.dht_prefix = parse_uid(module_uids[0])[0] + block_indices = [parse_uid(uid)[1] for uid in module_uids] + self.server_info.start_block = min(block_indices) + self.server_info.end_block = max(block_indices) + 1 + self.max_pinged = max_pinged - self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0] - block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids] - start_block, end_block = min(block_indices), max(block_indices) + 1 - self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)] + self.next_uids = [ + f"{self.dht_prefix}{UID_DELIMITER}{i}" + for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1) + ] self.ping_aggregator = PingAggregator(self.dht) def run(self) -> None: @@ -755,12 +759,11 @@ class ModuleAnnouncerThread(threading.Thread): def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]: module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True) - middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers} + middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers} pinged_servers = set(sample_up_to(middle_servers, self.max_pinged)) pinged_servers.discard(self.dht.peer_id) - if module_infos[-1] is not None: - # Sample servers hosting the block after the last one (most likely continuations) separately - pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged)) + # Sample servers hosting the block after the last one (most likely continuations) separately + pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged)) self.ping_aggregator.ping(list(pinged_servers)) diff --git a/src/petals/utils/dht.py b/src/petals/utils/dht.py index 0710f60..4faf74a 100644 --- a/src/petals/utils/dht.py +++ b/src/petals/utils/dht.py @@ -11,7 +11,16 @@ from hivemind.dht import DHT, DHTNode, DHTValue from hivemind.p2p import PeerID from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo +from petals.data_structures import ( + CHAIN_DELIMITER, + UID_DELIMITER, + ModuleUID, + RemoteModuleInfo, + RemoteSpanInfo, + ServerInfo, + ServerState, + parse_uid, +) logger = get_logger(__name__) @@ -70,7 +79,7 @@ def get_remote_module_infos( *, latest: bool = False, return_future: bool = False, -) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]: +) -> Union[List[RemoteModuleInfo], MPFuture]: return dht.run_coroutine( partial( _get_remote_module_infos, @@ -90,7 +99,7 @@ async def _get_remote_module_infos( active_adapter: Optional[str], expiration_time: Optional[DHTExpiration], latest: bool, -) -> List[Optional[RemoteModuleInfo]]: +) -> List[RemoteModuleInfo]: if latest: assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both" expiration_time = math.inf @@ -99,14 +108,14 @@ async def 
_get_remote_module_infos( num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers) - modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids) - for i, uid in enumerate(uids): - metadata = found[uid] + modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids] + for module_info in modules: + metadata = found[module_info.uid] if metadata is None or not isinstance(metadata.value, dict): if metadata is not None: - logger.warning(f"Incorrect metadata for {uid}: {metadata}") + logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}") continue - servers = {} + for peer_id, server_info in metadata.value.items(): try: peer_id = PeerID.from_base58(peer_id) @@ -116,9 +125,29 @@ async def _get_remote_module_infos( logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}") continue - servers[peer_id] = server_info + module_info.servers[peer_id] = server_info except (TypeError, ValueError) as e: - logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}") - if servers: - modules[i] = RemoteModuleInfo(uid, servers) + logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}") return modules + + +def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]: + block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0 + num_blocks = len(module_infos) + + spans = {} + for block_idx, module_info in enumerate(module_infos): + for peer_id, server_info in sorted(module_info.servers.items()): + if server_info.state.value < min_state.value: + continue + + if peer_id not in spans or spans[peer_id].state.value < server_info.state.value: + spans[peer_id] = RemoteSpanInfo( + peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info + ) + if server_info.start_block is not None and server_info.end_block is not None: + spans[peer_id].start = max(server_info.start_block - block_offset, 0) + spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks) + elif spans[peer_id].state == server_info.state: + spans[peer_id].end = max(spans[peer_id].end, block_idx + 1) + return spans
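
Illustrative notes (not part of the patch). A minimal sketch of what the new parse_uid() helper and the start_block/end_block fields of ServerInfo are for: a server can now derive its announced span directly from its module UIDs, the way ModuleAnnouncerThread does above, and publish the span boundaries inside every DHT record. The "bloom" prefix and the 4..7 block range below are made-up example values.

from petals.data_structures import UID_DELIMITER, ServerInfo, ServerState, parse_uid

# Module UIDs a server is about to announce (illustrative prefix and range).
module_uids = [f"bloom{UID_DELIMITER}{i}" for i in range(4, 8)]

dht_prefix = parse_uid(module_uids[0])[0]                   # "bloom"
block_indices = [parse_uid(uid)[1] for uid in module_uids]  # [4, 5, 6, 7]

# The span boundaries now travel with the server's DHT record, so a client
# reading any single block entry can reconstruct the full span.
server_info = ServerInfo(
    state=ServerState.ONLINE,
    throughput=100.0,  # RPS is a strict non-negative float, so pass 100.0, not 100
    start_block=min(block_indices),    # 4
    end_block=max(block_indices) + 1,  # 8
)
print(dht_prefix, server_info.start_block, server_info.end_block)  # bloom 4 8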
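
A sketch of how the relocated compute_spans() in petals/utils/dht.py uses those fields together with the min_state filter. Plain strings stand in for hivemind.PeerID dict keys purely to keep the example self-contained; the real callers pass PeerID objects.

from petals.data_structures import RemoteModuleInfo, ServerInfo, ServerState
from petals.utils.dht import compute_spans

online = ServerInfo(state=ServerState.ONLINE, throughput=100.0, start_block=0, end_block=3)
joining = ServerInfo(state=ServerState.JOINING, throughput=50.0)

# Three consecutive blocks of a hypothetical "bloom" model.
module_infos = [
    RemoteModuleInfo(uid="bloom.0", servers={"peerA": online}),
    RemoteModuleInfo(uid="bloom.1", servers={"peerA": online, "peerB": joining}),
    RemoteModuleInfo(uid="bloom.2", servers={"peerB": joining}),
]

# Client-side routing counts only fully online servers...
print(sorted(compute_spans(module_infos, min_state=ServerState.ONLINE)))   # ['peerA']
# ...while server-side block selection also counts servers that are still joining.
print(sorted(compute_spans(module_infos, min_state=ServerState.JOINING)))  # ['peerA', 'peerB']

# Because peerA announces (start_block=0, end_block=3), its span covers all three
# blocks even if only one of its DHT records had been fetched.
span_a = compute_spans(module_infos, min_state=ServerState.ONLINE)["peerA"]
print(span_a.start, span_a.end, span_a.length)  # 0 3 3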
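
Finally, a sketch of the refactored server-side selection path: choose_best_blocks() now builds spans via compute_spans(..., min_state=ServerState.JOINING), folds them into a per-block throughput vector with compute_throughputs(), and picks the least-covered window of blocks. The peer names, prefix, and throughput numbers are again hypothetical, and strings again stand in for PeerID keys.

from petals.data_structures import RemoteModuleInfo, ServerInfo, ServerState
from petals.server.block_selection import choose_best_blocks, compute_throughputs
from petals.utils.dht import compute_spans

# Four blocks; one online server already covers blocks 0..1.
busy = ServerInfo(state=ServerState.ONLINE, throughput=100.0, start_block=0, end_block=2)
module_infos = [
    RemoteModuleInfo(uid="bloom.0", servers={"peerA": busy}),
    RemoteModuleInfo(uid="bloom.1", servers={"peerA": busy}),
    RemoteModuleInfo(uid="bloom.2", servers={}),
    RemoteModuleInfo(uid="bloom.3", servers={}),
]

spans = compute_spans(module_infos, min_state=ServerState.JOINING)
print(compute_throughputs(spans, total_blocks=len(module_infos)))  # per-block RPS: 100, 100, 0, 0

# A new server with room for two blocks should pick the uncovered tail.
print(choose_best_blocks(2, module_infos))  # [2, 3]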