Refactor compute_spans() in petals.client.routing.sequence_info and petals.server.block_selection
pull/510/head
Aleksandr Borzunov 9 months ago
parent dcd4641b82
commit 4a302d0bd0

@@ -1,17 +1,15 @@
import dataclasses
import time
from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
from typing import Iterable, List, Optional, Tuple
from hivemind import get_logger
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState, parse_uid
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.utils.dht import compute_spans
logger = get_logger(__name__)
T = TypeVar("T")
@dataclasses.dataclass
class RemoteSequenceInfo:
"""
@@ -30,7 +28,7 @@ class RemoteSequenceInfo:
last_updated_time: Optional[float]
@classmethod
def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo":
block_uids = tuple(block_uids)
empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
empty_spans = tuple([] for _ in range(len(block_uids)))
@@ -39,7 +37,7 @@ class RemoteSequenceInfo:
def __getitem__(self, ix: slice):
assert isinstance(ix, slice)
block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
spans_by_priority, spans_containing_block = self._sort_spans(block_infos)
return RemoteSequenceInfo(
block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
)
@@ -53,48 +51,17 @@ class RemoteSequenceInfo:
assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}"
self.block_infos[block_index].servers = info.servers
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)
self.last_updated_time = time.perf_counter()
@staticmethod
def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
block_offset = parse_uid(block_infos[0].uid)[1] if block_infos else 0
num_blocks = len(block_infos)
closed_spans = []
active_spans = {}
for block_index, info in enumerate(block_infos):
for peer_id, server_info in info.servers.items():
if server_info.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = RemoteSpanInfo(
peer_id=peer_id,
start=block_index,
end=block_index + 1,
server_info=server_info,
)
if server_info.start_block is not None and server_info.end_block is not None:
active_spans[peer_id].start = max(server_info.start_block - block_offset, 0)
active_spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
else: # peer_id in active_spans
active_spans[peer_id].end = max(active_spans[peer_id].end, block_index + 1)
for peer_id in list(active_spans.keys()):
if (
info is None
or peer_id not in info.servers
or info.servers[peer_id].state != ServerState.ONLINE
or block_index == len(block_infos) - 1
):
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans, f"spans: {active_spans}"
closed_spans.sort(key=lambda span: span.length, reverse=True)
def _sort_spans(block_infos: List[RemoteModuleInfo]):
spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())
spans_by_priority.sort(key=lambda span: span.length, reverse=True)
spans_containing_block = tuple(list() for _ in range(len(block_infos)))
for span in closed_spans:
spans_containing_block = tuple([] for _ in range(len(block_infos)))
for span in spans_by_priority:
for block_index in range(span.start, span.end):
spans_containing_block[block_index].append(span)
return closed_spans, spans_containing_block
return spans_by_priority, spans_containing_block
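For readers skimming the new helper, a quick illustration of the data shapes _sort_spans returns (the spans below are invented for illustration, not taken from the commit):

# Suppose a 3-block slice is served by span A covering blocks [0, 3) and span B covering [1, 2).
# _sort_spans would then return:
#   spans_by_priority      == [A, B]               # sorted by span length, longest first
#   spans_containing_block == ([A], [A, B], [A])   # one list of spans per block index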

@@ -117,7 +117,6 @@ class RemoteSequenceManager:
if state.sequence_info.last_updated_time is not None:
assert block_uids == state.sequence_info.block_uids
self._thread.ready.set() # no need to await the first dht fetch
self._need_latest_infos = True
@staticmethod
def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:

@@ -92,9 +92,17 @@ class RemoteSpanInfo:
server_info: ServerInfo
@property
def length(self):
def length(self) -> int:
return self.end - self.start
@property
def state(self) -> ServerState:
return self.server_info.state
@property
def throughput(self) -> float:
return self.server_info.throughput
RPCInfo = Dict[str, Any]

@@ -1,57 +1,23 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List
import numpy as np
from hivemind import PeerID, get_logger
from petals.data_structures import RemoteModuleInfo, ServerState, parse_uid
__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.utils.dht import compute_spans
logger = get_logger(__name__)
@dataclass
class Span:
start: int
end: int
throughput: float
state: ServerState
@property
def length(self):
return self.end - self.start
def move_to(self, new_start: int) -> None:
self.start, self.end = new_start, new_start + self.length
def compute_spans(module_infos: List[RemoteModuleInfo]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
num_blocks = len(module_infos)
spans = {}
throughputs = np.zeros(num_blocks)
for block, module in enumerate(module_infos):
# We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
# If the order were not defined, we would get slightly different values due to floating point errors,
# which may cause excess block replacements.
for peer_id, server in sorted(module.servers.items()):
if server.state == ServerState.OFFLINE:
continue
def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:
# We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
# If the order were not defined, we would get slightly different values due to floating point errors,
# which may cause excess block replacements.
if peer_id not in spans or spans[peer_id].state.value < server.state.value:
spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
if server.start_block is not None and server.end_block is not None:
spans[peer_id].start = max(server.start_block - block_offset, 0)
spans[peer_id].end = min(server.end_block - block_offset, num_blocks)
elif spans[peer_id].state == server.state:
spans[peer_id].start = min(spans[peer_id].start, block)
spans[peer_id].end = max(spans[peer_id].end, block + 1)
throughputs[block] += server.throughput
return spans, throughputs
throughputs = np.zeros(total_blocks)
for span in sorted(spans.values(), key=lambda span: span.peer_id):
throughputs[span.start : span.end] += span.throughput
return throughputs
def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
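For intuition, here is a toy sketch (not part of the commit) of the accumulation that compute_throughputs performs over the spans it receives; the servers and numbers are invented:

import numpy as np

# Hypothetical swarm: server A covers blocks [0, 2) at 100 tok/s, server B covers [1, 4) at 40 tok/s.
throughputs = np.zeros(4)   # total_blocks == 4
throughputs[0:2] += 100.0   # span of server A
throughputs[1:4] += 40.0    # span of server B
print(throughputs)          # [100. 140.  40.  40.] -> block 1 is served by both peers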
@@ -60,18 +26,25 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
_, throughputs = compute_spans(module_infos)
spans = compute_spans(module_infos, min_state=ServerState.JOINING)
throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
start = _choose_best_start(throughputs, num_blocks)
return list(range(start, start + num_blocks))
def _move_span(span: RemoteSpanInfo, new_start: int):
span.start, span.end = new_start, new_start + span.length
def should_choose_other_blocks(
local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
) -> bool:
if balance_quality > 1.0:
return True # Forces rebalancing on each check (may be used for debugging purposes)
spans, throughputs = compute_spans(module_infos)
spans = compute_spans(module_infos, min_state=ServerState.JOINING)
throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
initial_throughput = throughputs.min()
eps = 1e-3
@@ -91,7 +64,7 @@ def should_choose_other_blocks(
return False # This server is on its best place already
throughputs[local_span.start : local_span.end] += local_span.throughput * eps
local_span.move_to(new_start)
_move_span(local_span, new_start)
throughputs[local_span.start : local_span.end] += local_span.throughput
moved = True
@@ -108,7 +81,7 @@ def should_choose_other_blocks(
throughputs[span.start : span.end] += span.throughput * eps
if span.start != new_start:
span.move_to(new_start)
_move_span(span, new_start)
moved = True
throughputs[span.start : span.end] += span.throughput

@@ -11,7 +11,16 @@ from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
from petals.data_structures import (
    CHAIN_DELIMITER,
    UID_DELIMITER,
    ModuleUID,
    RemoteModuleInfo,
    RemoteSpanInfo,
    ServerInfo,
    ServerState,
    parse_uid,
)
logger = get_logger(__name__)
@@ -120,3 +129,26 @@ async def _get_remote_module_infos(
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
return modules
def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
    block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
    num_blocks = len(module_infos)
    spans = {}
    for block_idx, module_info in enumerate(module_infos):
        for peer_id, server_info in sorted(module_info.servers.items()):
            if server_info.state.value < min_state.value:
                continue
            if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
                spans[peer_id] = RemoteSpanInfo(
                    peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
                )
                if server_info.start_block is not None and server_info.end_block is not None:
                    spans[peer_id].start = max(server_info.start_block - block_offset, 0)
                    spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
            elif spans[peer_id].state == server_info.state:
                spans[peer_id].start = min(spans[peer_id].start, block_idx)
                spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
    return spans
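Taken together, a rough sketch of how the two call sites consume the shared helper after this refactor (it assumes module_infos: List[RemoteModuleInfo] has already been fetched from the DHT; this is an illustration, not code from the commit):

from petals.data_structures import ServerState
from petals.server.block_selection import compute_throughputs
from petals.utils.dht import compute_spans

# Client-side routing (RemoteSequenceInfo._sort_spans): only fully ONLINE servers are considered.
online_spans = compute_spans(module_infos, min_state=ServerState.ONLINE)
spans_by_priority = sorted(online_spans.values(), key=lambda span: span.length, reverse=True)

# Server-side block selection: JOINING servers are counted too, and per-block
# throughputs are then derived from the resulting spans.
joining_spans = compute_spans(module_infos, min_state=ServerState.JOINING)
throughputs = compute_throughputs(joining_spans, total_blocks=len(module_infos))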
