Store (start_block, end_block) in each DHT record for reliability (#510)

This PR fixes gaps in the DHT server info caused by unavailable DHT keys. Now a single DHT key is enough to recover the info about all blocks hosted by a server, so the server's info stays visible unless all of its keys become unavailable.
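For illustration, here is a minimal sketch (not from the PR itself) of the kind of record a server now announces under each of its DHT keys. The field names follow the `ServerInfo` dataclass changed below; the concrete numbers are made up.

```python
from petals.data_structures import ServerInfo, ServerState

# Hypothetical record announced under every DHT key of one server.
# Besides its state and throughput, the server now stores the full range of
# blocks it hosts, so any single surviving key describes all of them.
info = ServerInfo(
    state=ServerState.ONLINE,
    throughput=1000.0,  # requests per second (made-up value)
    start_block=16,     # first hosted block (inclusive)
    end_block=32,       # last hosted block + 1 (exclusive)
)

# A client that only fetched the key for, say, block 20 still learns that
# this server hosts the whole range [16, 32).
assert info.start_block <= 20 < info.end_block
```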

This PR also refactors the `petals.client.routing` and `petals.server.block_selection` modules to use the common `compute_spans()` function (defined in `petals.utils.dht`) and the `RemoteSpanInfo` class (defined in `petals.data_structures`).
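A rough usage sketch of the shared helpers (assuming a `dht` instance and a list of `module_uids` already exist in the calling code): `compute_spans()` turns the per-block DHT records into one `RemoteSpanInfo` per server.

```python
from petals.data_structures import ServerState
from petals.utils.dht import compute_spans, get_remote_module_infos

# `dht` and `module_uids` are assumed to be set up by the surrounding code.
module_infos = get_remote_module_infos(dht, module_uids, latest=True)

# Client routing keeps only ONLINE servers; server-side block selection
# passes min_state=ServerState.JOINING instead (see block_selection.py).
spans = compute_spans(module_infos, min_state=ServerState.ONLINE)
for peer_id, span in spans.items():
    print(peer_id, span.start, span.end, span.throughput)
```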
Alexander Borzunov (committed via GitHub)
parent 158621677b
commit 5ce4f1a159

@@ -1,17 +1,15 @@
import dataclasses
import time
from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
from typing import Iterable, List, Optional, Tuple
from hivemind import get_logger
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.utils.dht import compute_spans
logger = get_logger(__name__)
T = TypeVar("T")
@dataclasses.dataclass
class RemoteSequenceInfo:
"""
@@ -30,7 +28,7 @@ class RemoteSequenceInfo:
last_updated_time: Optional[float]
@classmethod
def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo":
block_uids = tuple(block_uids)
empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
empty_spans = tuple([] for _ in range(len(block_uids)))
@@ -39,7 +37,7 @@ class RemoteSequenceInfo:
def __getitem__(self, ix: slice):
assert isinstance(ix, slice)
block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
spans_by_priority, spans_containing_block = self._sort_spans(block_infos)
return RemoteSequenceInfo(
block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
)
@@ -47,60 +45,23 @@ class RemoteSequenceInfo:
def __len__(self):
return len(self.block_uids)
def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
def update_(self, new_block_infos: List[RemoteModuleInfo]):
assert len(new_block_infos) == len(self.block_uids)
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
if info is None:
logger.debug(f"Found no block info for block {uid}")
continue
if not isinstance(info, RemoteModuleInfo):
logger.warning(f"Unexpected dht entry type for {uid}: {info}")
continue
if not info.servers:
logger.debug(f"Found no active peers for block {uid}")
continue
if info.uid != uid:
logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
continue
assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}"
self.block_infos[block_index].servers = info.servers
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)
self.last_updated_time = time.perf_counter()
@staticmethod
def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
closed_spans = []
active_spans = {}
for block_index, info in enumerate(block_infos):
if info is not None:
for peer_id, server_info in info.servers.items():
if server_info.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = RemoteSpanInfo(
peer_id=peer_id,
start=block_index,
end=block_index + 1,
server_info=server_info,
)
else: # peer_id in active_spans
active_spans[peer_id].end = block_index + 1
for peer_id in list(active_spans.keys()):
if (
info is None
or peer_id not in info.servers
or info.servers[peer_id].state != ServerState.ONLINE
or block_index == len(block_infos) - 1
):
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans, f"spans: {active_spans}"
closed_spans.sort(key=lambda span: span.length, reverse=True)
def _sort_spans(block_infos: List[RemoteModuleInfo]):
spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())
spans_by_priority.sort(key=lambda span: span.length, reverse=True)
spans_containing_block = tuple(list() for _ in range(len(block_infos)))
for span in closed_spans:
spans_containing_block = tuple([] for _ in range(len(block_infos)))
for span in spans_by_priority:
for block_index in range(span.start, span.end):
spans_containing_block[block_index].append(span)
return closed_spans, spans_containing_block
return spans_by_priority, spans_containing_block

@@ -117,7 +117,6 @@ class RemoteSequenceManager:
if state.sequence_info.last_updated_time is not None:
assert block_uids == state.sequence_info.block_uids
self._thread.ready.set() # no need to await the first dht fetch
self._need_latest_infos = True
@staticmethod
def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
@@ -346,9 +345,6 @@ class RemoteSequenceManager:
)
for block_info in new_block_infos:
if not block_info:
continue
# Apply allow and block lists
block_info.servers = {
peer_id: server_info

@@ -11,18 +11,15 @@ UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
class ServerState(Enum):
OFFLINE = 0
JOINING = 1
ONLINE = 2
RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs"
dht_prefix, index = uid.split(UID_DELIMITER)
return dht_prefix, int(index)
@pydantic.dataclasses.dataclass
class ModelInfo:
num_blocks: int
num_blocks: pydantic.conint(ge=1, strict=True)
repository: Optional[str] = None
def to_dict(self) -> dict:
@@ -33,11 +30,23 @@ class ModelInfo:
return cls(**source)
class ServerState(Enum):
OFFLINE = 0
JOINING = 1
ONLINE = 2
RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
@pydantic.dataclasses.dataclass
class ServerInfo:
state: ServerState
throughput: RPS
start_block: Optional[pydantic.conint(ge=0, strict=True)] = None
end_block: Optional[pydantic.conint(ge=0, strict=True)] = None
public_name: Optional[str] = None
version: Optional[str] = None
@@ -83,9 +92,17 @@ class RemoteSpanInfo:
server_info: ServerInfo
@property
def length(self):
def length(self) -> int:
return self.end - self.start
@property
def state(self) -> ServerState:
return self.server_info.state
@property
def throughput(self) -> float:
return self.server_info.throughput
RPCInfo = Dict[str, Any]

@@ -1,54 +1,23 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List
import numpy as np
from hivemind import PeerID, get_logger
from petals.data_structures import RemoteModuleInfo, ServerState
__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.utils.dht import compute_spans
logger = get_logger(__name__)
@dataclass
class Span:
start: int
end: int
throughput: float
state: ServerState
@property
def length(self):
return self.end - self.start
def move_to(self, new_start: int) -> None:
self.start, self.end = new_start, new_start + self.length
def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
spans = {}
throughputs = np.zeros(len(module_infos))
for block, module in enumerate(module_infos):
if module is None:
continue
# We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
# If the order were not defined, we would get slightly different values due to floating point errors,
# which may cause excess block replacements.
for peer_id, server in sorted(module.servers.items()):
if server.state == ServerState.OFFLINE:
continue
def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:
# We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
# If the order were not defined, we would get slightly different values due to floating point errors,
# which may cause excess block replacements.
if peer_id in spans:
spans[peer_id].start = min(spans[peer_id].start, block)
spans[peer_id].end = max(spans[peer_id].start, block + 1)
else:
spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
throughputs[block] += server.throughput
return spans, throughputs
throughputs = np.zeros(total_blocks)
for span in sorted(spans.values(), key=lambda span: span.peer_id):
throughputs[span.start : span.end] += span.throughput
return throughputs
def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
@@ -56,19 +25,26 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
return min(options)[-1]
def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
_, throughputs = compute_spans(module_infos)
def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
spans = compute_spans(module_infos, min_state=ServerState.JOINING)
throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
start = _choose_best_start(throughputs, num_blocks)
return list(range(start, start + num_blocks))
def _move_span(span: RemoteSpanInfo, new_start: int):
span.start, span.end = new_start, new_start + span.length
def should_choose_other_blocks(
local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
) -> bool:
if balance_quality > 1.0:
return True # Forces rebalancing on each check (may be used for debugging purposes)
spans, throughputs = compute_spans(module_infos)
spans = compute_spans(module_infos, min_state=ServerState.JOINING)
throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
initial_throughput = throughputs.min()
eps = 1e-3
@@ -88,7 +64,7 @@ def should_choose_other_blocks(
return False # This server is on its best place already
throughputs[local_span.start : local_span.end] += local_span.throughput * eps
local_span.move_to(new_start)
_move_span(local_span, new_start)
throughputs[local_span.start : local_span.end] += local_span.throughput
moved = True
@@ -105,7 +81,7 @@ def should_choose_other_blocks(
throughputs[span.start : span.end] += span.throughput * eps
if span.start != new_start:
span.move_to(new_start)
_move_span(span, new_start)
moved = True
throughputs[span.start : span.end] += span.throughput

@@ -23,7 +23,7 @@ from transformers import PretrainedConfig
import petals
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -220,11 +220,10 @@ class Server:
num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
if block_indices is not None:
try:
first_block_index, last_block_index = block_indices.split(":")
first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
start_block, end_block = [int(index.strip()) for index in block_indices.split(":")]
except Exception as e:
raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
block_indices = range(first_block_index, last_block_index)
block_indices = range(start_block, end_block)
num_blocks = len(block_indices)
self.strict_block_indices, self.num_blocks = block_indices, num_blocks
@@ -703,11 +702,16 @@ class ModuleAnnouncerThread(threading.Thread):
self.expiration = expiration
self.trigger = threading.Event()
self.dht_prefix = parse_uid(module_uids[0])[0]
block_indices = [parse_uid(uid)[1] for uid in module_uids]
self.server_info.start_block = min(block_indices)
self.server_info.end_block = max(block_indices) + 1
self.max_pinged = max_pinged
self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
start_block, end_block = min(block_indices), max(block_indices) + 1
self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
self.next_uids = [
f"{self.dht_prefix}{UID_DELIMITER}{i}"
for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1)
]
self.ping_aggregator = PingAggregator(self.dht)
def run(self) -> None:
@@ -755,12 +759,11 @@ class ModuleAnnouncerThread(threading.Thread):
def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}
pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
pinged_servers.discard(self.dht.peer_id)
if module_infos[-1] is not None:
# Sample servers hosting the block after the last one (most likely continuations) separately
pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
# Sample servers hosting the block after the last one (most likely continuations) separately
pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
self.ping_aggregator.ping(list(pinged_servers))

@@ -11,7 +11,16 @@ from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
from petals.data_structures import (
CHAIN_DELIMITER,
UID_DELIMITER,
ModuleUID,
RemoteModuleInfo,
RemoteSpanInfo,
ServerInfo,
ServerState,
parse_uid,
)
logger = get_logger(__name__)
@@ -70,7 +79,7 @@ def get_remote_module_infos(
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
) -> Union[List[RemoteModuleInfo], MPFuture]:
return dht.run_coroutine(
partial(
_get_remote_module_infos,
@@ -90,7 +99,7 @@ async def _get_remote_module_infos(
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
) -> List[RemoteModuleInfo]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
expiration_time = math.inf
@@ -99,14 +108,14 @@ async def _get_remote_module_infos(
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
metadata = found[uid]
modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
for module_info in modules:
metadata = found[module_info.uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}")
continue
servers = {}
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
@@ -116,9 +125,29 @@ async def _get_remote_module_infos(
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
module_info.servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
return modules
def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
num_blocks = len(module_infos)
spans = {}
for block_idx, module_info in enumerate(module_infos):
for peer_id, server_info in sorted(module_info.servers.items()):
if server_info.state.value < min_state.value:
continue
if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
spans[peer_id] = RemoteSpanInfo(
peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
)
if server_info.start_block is not None and server_info.end_block is not None:
spans[peer_id].start = max(server_info.start_block - block_offset, 0)
spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
elif spans[peer_id].state == server_info.state:
spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
return spans
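
For illustration only (not part of the diff), here is a rough sketch of how the new `compute_spans()` recovers a full span from a single reachable record. The uids and the peer id are hypothetical, and a plain string stands in for a `hivemind.PeerID`.

```python
from petals.data_structures import RemoteModuleInfo, ServerInfo, ServerState
from petals.utils.dht import compute_spans

# The client queried the keys for blocks "model.0" .. "model.3", but only the
# record for block 2 of this server came back (a string stands in for PeerID).
peer = "server-A"
record = ServerInfo(state=ServerState.ONLINE, throughput=1000.0, start_block=1, end_block=4)

module_infos = [RemoteModuleInfo(uid=f"model.{i}", servers={}) for i in range(4)]
module_infos[2].servers[peer] = record  # the only key that was reachable

# Thanks to (start_block, end_block) stored in the record, the resulting span
# covers blocks 1..3 of the queried range, not just block 2.
spans = compute_spans(module_infos, min_state=ServerState.ONLINE)
assert (spans[peer].start, spans[peer].end) == (1, 4)
```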
