# petals/src/petals/client/routing/sequence_info.py

import dataclasses
import time
from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
from hivemind import get_logger
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
# Module-level logger, obtained from hivemind's get_logger for this module's name.
logger = get_logger(__name__)
# Generic type variable; RemoteSequenceInfo.make_empty is annotated with it (cls: Type[T] -> T).
T = TypeVar("T")
@dataclasses.dataclass
class RemoteSequenceInfo:
    """
    A dataclass that stores general information about which servers hold any given layer;
    - updated by RemoteSequenceManager in a background thread
    - accessed by routing strategies in .on_update

    :note: this class should *not* be modified by RoutingStrategy.on_update to avoid interference between strategies;
     Any metadata specific to one routing strategy should be stored inside that strategy. Any information that
     is used by most routing strategies should be moved from said strategies to this class.
    """

    block_uids: Tuple[ModuleUID, ...]
    block_infos: Tuple[RemoteModuleInfo, ...]  # note: the contents of RemoteModuleInfo can and will be updated
    spans_by_priority: List[RemoteSpanInfo]  # as returned by compute_spans: ordered by span.length, longest first
    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]  # one list per block: the spans that cover it
    last_updated_time: Optional[float]  # time.perf_counter() value set by update_(); None from make_empty()

    @classmethod
    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
        """
        Build an instance for the given block uids that has no spans yet.

        Every uid gets its own RemoteModuleInfo constructed with a fresh empty dict, spans_by_priority is an
        empty list, each block gets its own empty span list and last_updated_time is None.

        :param block_uids: uids of the blocks in the sequence; consumed once and stored as a tuple
        :returns: a new instance of the class this method was called on
        """
        uids = tuple(block_uids)
        placeholder_infos = tuple(RemoteModuleInfo(uid, {}) for uid in uids)
        no_spans_per_block = tuple([] for _ in uids)
        return cls(uids, placeholder_infos, [], no_spans_per_block, last_updated_time=None)

    def __getitem__(self, ix: slice):
        """
        Return a RemoteSequenceInfo restricted to the blocks selected by a slice.

        Spans are recomputed from the selected block infos, so span indices are relative to the slice.
        The RemoteModuleInfo objects themselves are shared with this instance (tuple slicing does not copy
        them) and last_updated_time is carried over unchanged.

        :param ix: a slice object; any other index type fails the assertion
        """
        assert isinstance(ix, slice)
        selected_uids = self.block_uids[ix]
        selected_infos = self.block_infos[ix]
        by_priority, per_block = self.compute_spans(selected_infos)
        return RemoteSequenceInfo(
            block_uids=selected_uids,
            block_infos=selected_infos,
            spans_by_priority=by_priority,
            spans_containing_block=per_block,
            last_updated_time=self.last_updated_time,
        )

    def __len__(self):
        """Return the number of blocks in this sequence."""
        return len(self.block_uids)

    def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
        """
        Refresh the stored server lists in place, then recompute spans and the update timestamp.

        An entry is ignored (with a log message) when it is None, is not a RemoteModuleInfo, has no servers,
        or carries a uid different from the expected one. An ignored entry leaves the previously stored
        servers of that block untouched.

        :param new_block_infos: one entry per block, in the same order as self.block_uids
        """
        assert len(new_block_infos) == len(self.block_uids)

        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
            if info is None:
                logger.debug(f"Found no block info for block {uid}")
            elif not isinstance(info, RemoteModuleInfo):
                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
            elif not info.servers:
                logger.debug(f"Found no active peers for block {uid}")
            elif info.uid != uid:
                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
            else:
                # only the servers mapping is replaced; the RemoteModuleInfo object itself stays the same
                self.block_infos[block_index].servers = info.servers

        by_priority, per_block = self.compute_spans(self.block_infos)
        self.spans_by_priority = by_priority
        self.spans_containing_block = per_block
        self.last_updated_time = time.perf_counter()

    @staticmethod
    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
        """
        Group consecutive blocks served by the same peer into spans.

        Only servers whose state is ServerState.ONLINE contribute. A span for a peer starts at the first block
        where the peer is online and is closed at the first block where it is missing or not online (or after
        the final block). The span keeps the server_info seen at its first block.

        :param block_infos: per-block infos; a None entry is treated as a block with no servers
        :returns: a pair (spans, spans_containing_block) where spans is a list ordered by span.length from
          longest to shortest, and spans_containing_block is a tuple holding, for each block index, the list
          of spans that cover this block (in the same relative order as in spans)
        """
        last_index = len(block_infos) - 1
        finished_spans = []
        open_spans = {}  # peer_id -> span that is still being extended

        for block_index, info in enumerate(block_infos):
            if info is not None:
                for peer_id, server_info in info.servers.items():
                    if server_info.state != ServerState.ONLINE:
                        continue
                    if peer_id in open_spans:
                        open_spans[peer_id].end = block_index + 1
                    else:
                        open_spans[peer_id] = RemoteSpanInfo(
                            peer_id=peer_id,
                            start=block_index,
                            end=block_index + 1,
                            server_info=server_info,
                        )

            # iterate over a snapshot of the keys because entries are popped while looping
            for peer_id in tuple(open_spans):
                should_close = (
                    info is None
                    or peer_id not in info.servers
                    or info.servers[peer_id].state != ServerState.ONLINE
                    or block_index == last_index
                )
                if should_close:
                    finished_spans.append(open_spans.pop(peer_id))

        assert not open_spans, f"spans: {open_spans}"

        # stable sort: spans of equal length keep the order in which they were closed
        spans_by_length = sorted(finished_spans, key=lambda span: span.length, reverse=True)

        spans_containing_block = tuple([] for _ in block_infos)
        for span in spans_by_length:
            for covering in spans_containing_block[span.start : span.end]:
                covering.append(span)

        return spans_by_length, spans_containing_block