Optimize RemoteSequenceManager (#106)
- [x] made RemoteSequenceManager into a background thread that pre-fetches information instead of running just in time
- [x] moved routing-related code to petals.client.routing
- [x] extracted remote peer routing information to RemoteSequenceInfo
- [x] made sure that the code survives continued use (e.g. one hour)
- [x] updated every spot where update_ is called manually
- [x] modified get_sequence to check that the thread is alive, warn if not
- [x] removed max_retries, switched rpc_info to exponential backoff
- [x] fixed a bug that caused RemoteSeq* to lose user-defined hyperparameters (e.g. timeout) upon subsequencing (sequential[3:5])
- [x] moved the client-side points strategy to client.routing
- [x] ensured that the RemoteSequenceManager thread created in get_remote_module shuts down properly when the module is destroyed
- [x] resolved minor affected todos
- [x] modified tests to no longer use PYTHONPATH
- [x] worked around a protocol error in rpc_info

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Artem Chumachenko <artek.chumak@gmail.com>
parent 7d859a947b
commit a2066a4096
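For orientation before the diff, here is a minimal usage sketch of the reworked manager. It is not part of this change: `INITIAL_PEERS` and the block UIDs are illustrative placeholders, and obtaining `p2p` via `RemoteExpertWorker` is an assumption based on how the rest of the client code drives coroutines.

```python
# A minimal usage sketch (not from this diff). INITIAL_PEERS and the block UIDs
# are hypothetical placeholders.
from hivemind import DHT
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker

from petals.client.routing.sequence_manager import RemoteSequenceManager

dht = DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)  # INITIAL_PEERS: your bootstrap multiaddrs
p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())  # assumed setup, mirroring the client code

block_uids = [f"bloom.{i}" for i in range(4)]  # hypothetical DHT keys for 4 blocks
manager = RemoteSequenceManager(dht, block_uids, p2p, start=True)  # spawns the background updater

spans = manager.make_sequence()  # a chain of servers that collectively cover blocks 0..3
submanager = manager[1:3]        # a sub-sequence; keeps request_timeout/min_backoff (the bug fixed above)
submanager.shutdown()
manager.shutdown()               # stops the background updater thread
```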
@@ -1,5 +1,5 @@
 from petals.client.inference_session import InferenceSession
 from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
-from petals.client.sequence_manager import RemoteSequenceManager
-from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
+from petals.client.routing.sequence_manager import RemoteSequenceManager
+from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
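This hunk looks like the `petals.client` package `__init__` (the file path is not shown, so this is an inference). If so, only deep imports need updating; imports through the package root keep working, which the new test below relies on. A sketch:

```python
# Deep imports must use the new routing package after this change:
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy

# Imports through the package root are unaffected, assuming the hunk above is
# petals/client/__init__.py re-exporting the moved names:
import petals.client
assert petals.client.RemoteSequenceManager is RemoteSequenceManager
```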
@@ -0,0 +1 @@
"""Client-side functions responsible for choosing the best server, """
@@ -0,0 +1,102 @@
import dataclasses
import time
from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar

from hivemind import get_logger, use_hivemind_log_handler

from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


T = TypeVar("T")


@dataclasses.dataclass
class RemoteSequenceInfo:
    """
    A dataclass that stores general information about which servers hold any given layer;
    - updated by RemoteSequenceManager in a background thread
    - accessed by routing strategies in .on_update
    :note: this class should *not* be modified by RoutingStrategy.on_update to avoid interference between strategies;
        Any metadata specific to one routing strategy should be stored inside that strategy. Any information that
        is used by most routing strategies should be moved from said strategies to this class.
    """

    block_uids: Tuple[ModuleUID, ...]
    block_infos: Tuple[RemoteModuleInfo, ...]  # note: the contents of RemoteModuleInfo can and will be updated
    spans_by_priority: List[RemoteSpanInfo]
    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
    last_updated_time: float

    @classmethod
    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
        block_uids = tuple(block_uids)
        empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
        empty_spans = tuple([] for _ in range(len(block_uids)))
        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf"))

    def __getitem__(self, ix: slice):
        assert isinstance(ix, slice)
        block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
        spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
        return RemoteSequenceInfo(
            block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
        )

    def __len__(self):
        return len(self.block_uids)

    def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
        assert len(new_block_infos) == len(self.block_uids)
        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
            if info is None:
                logger.debug(f"Found no block info for block {uid}")
                continue
            if not isinstance(info, RemoteModuleInfo):
                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
                continue
            if not info.servers:
                logger.debug(f"Found no active peers for block {uid}")
                continue
            if info.uid != uid:
                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
                continue
            self.block_infos[block_index].servers = info.servers

        self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
        self.last_updated_time = time.perf_counter()

    @staticmethod
    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
        closed_spans = []
        active_spans = {}
        for block_index, info in enumerate(block_infos):
            if info is not None:
                for peer_id, server in info.servers.items():
                    if server.state != ServerState.ONLINE:
                        continue
                    if peer_id not in active_spans:
                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
                    else:  # peer_id in active_spans
                        active_spans[peer_id].end = block_index + 1

            for peer_id in list(active_spans.keys()):
                if (
                    info is None
                    or peer_id not in info.servers
                    or info.servers[peer_id].state != ServerState.ONLINE
                    or block_index == len(block_infos) - 1
                ):
                    closed_spans.append(active_spans.pop(peer_id))
        assert not active_spans, f"spans: {active_spans}"

        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)

        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
        for span in closed_spans:
            for block_index in range(span.start, span.end):
                spans_containing_block[block_index].append(span)

        return closed_spans, spans_containing_block
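To see what `compute_spans` produces, here is a self-contained rendition of the same greedy pass with simplified stand-ins for petals' dataclasses (only the server state matters to the algorithm; `State`, `Span`, and the dict-of-dicts input are not petals' real types):

```python
# Standalone illustration of the compute_spans logic above, using simplified
# stand-ins for ServerState / RemoteSpanInfo.
from dataclasses import dataclass
from enum import Enum

class State(Enum):
    OFFLINE = 0
    ONLINE = 1

@dataclass
class Span:
    start: int    # first block served (inclusive)
    end: int      # last block served (exclusive)
    peer_id: str

def compute_spans(servers_per_block):
    """servers_per_block[i]: dict mapping peer_id -> State for block i."""
    closed, active = [], {}
    for i, servers in enumerate(servers_per_block):
        for peer_id, state in servers.items():
            if state != State.ONLINE:
                continue
            if peer_id not in active:
                active[peer_id] = Span(i, i + 1, peer_id)  # open a new span
            else:
                active[peer_id].end = i + 1                # extend the running span
        for peer_id in list(active):
            # close a span once its server goes missing/offline, or at the last block
            if servers.get(peer_id) != State.ONLINE or i == len(servers_per_block) - 1:
                closed.append(active.pop(peer_id))
    closed.sort(key=lambda s: s.end - s.start, reverse=True)  # longest (best) spans first
    return closed

# Peer A serves blocks 0..2, peer B serves blocks 1..3:
print(compute_spans([
    {"A": State.ONLINE},
    {"A": State.ONLINE, "B": State.ONLINE},
    {"A": State.ONLINE, "B": State.ONLINE},
    {"B": State.ONLINE},
]))  # [Span(start=0, end=3, peer_id='A'), Span(start=1, end=4, peer_id='B')]
```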
@@ -0,0 +1,265 @@
from __future__ import annotations

import itertools
import logging
import random
import threading
import time
from typing import Any, Dict, List, Optional, Sequence, Union
from weakref import WeakMethod

from hivemind import DHT, P2P, MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import petals.dht_utils
from petals.client.routing.sequence_info import RemoteSequenceInfo
from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.data_structures import ModuleUID, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteSequenceManager:
    """
    Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks.
    TL;DR it tells you which peers you should ask to get a specific layer. It is used in RemoteSequential.
    When created, RemoteSequenceManager looks up which servers serve necessary layers by reading from DHT.
    Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
    To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).

    :param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks
    :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
    :param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p())
    :param update_period: by default, refresh DHT information once in this many seconds
    :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
    :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
    :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
    :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
    :param start: start the background thread (see the note below). If false, you will need to start it manually.
    :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
        running redundant sequence managers for the same set of layers.
    """

    def __init__(
        self,
        dht: DHT,
        block_uids: Sequence[ModuleUID],
        p2p: P2P,
        update_period: float = 30,
        request_timeout: float = 30,
        min_backoff: float = 1,
        sequence_info: Optional[RemoteSequenceInfo] = None,
        rpc_info: Optional[dict] = None,
        *,  # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
        start: bool,
    ):
        assert len(block_uids) > 0, "Sequences must contain at least one block"
        self.dht, self.p2p = dht, p2p
        self.request_timeout, self.min_backoff = request_timeout, min_backoff
        self.lock_changes = threading.Lock()
        self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
        self.policy = NoSpendingPolicy()
        self._rpc_info = rpc_info

        if sequence_info is None:
            self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
            self.update(wait=False)
        else:
            self.sequence_info = sequence_info
            assert block_uids == sequence_info.block_uids
            self._thread.ready.set()  # no need to await the first dht fetch

        if start:
            self.run_in_background()

    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
        """
        Starts the updater thread in the background. If await_ready, this method will wait until the sequence
        manager is ready to process incoming requests, or for :timeout: seconds max.
        """
        self._thread.start()
        if await_ready:
            self._thread.ready.wait(timeout)

    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
        """
        Form a sequence of remote servers that collectively serve all consecutive layers

        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
        """
        if not self.is_alive():
            logger.error("Using a sequence manager that is not running: it has either crashed or never started")
        if not self.ready.is_set():
            logger.warning("RemoteSequenceManager is still searching for routes, waiting for it to become ready")
            self.ready.wait()

        end_index = end_index if end_index is not None else len(self)
        span_sequence = []
        current_index = start_index
        while current_index < end_index:
            candidate_spans = self.sequence_info.spans_containing_block[current_index]
            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing

            assert chosen_span.start <= current_index < chosen_span.end
            span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
            current_index = chosen_span.end

        return span_sequence

    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
        """Get a RemoteSequenceManager for a sub-sequence of blocks"""
        assert isinstance(ix, (int, slice))
        if not isinstance(ix, slice):
            ix = slice(int(ix), int(ix) + 1, 1)
        return type(self)(
            self.dht,
            self.block_uids[ix],
            self.p2p,
            update_period=self._thread.update_period,
            request_timeout=self.request_timeout,
            min_backoff=self.min_backoff,
            sequence_info=self.sequence_info[ix],
            rpc_info=self._rpc_info,
            start=True,
        )

    def update(self, *, wait: bool):
        """Run an asynchronous update in background as soon as possible"""
        self.ready.clear()  # TODO this should be a separate event
        self._thread.trigger.set()
        if wait:
            self.ready.wait()

    def _update(self):
        """Perform an immediate and synchronous refresh, may take time"""
        for attempt_no in itertools.count():
            try:
                new_block_infos = petals.dht_utils.get_remote_module_infos(
                    self.dht, self.block_uids, expiration_time=float("inf")
                )
                with self.lock_changes:
                    self.sequence_info.update_(new_block_infos)
                missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
                if missing_blocks:
                    raise MissingBlocksError(f"no servers holding blocks {missing_blocks}")
                self.ready.set()  # if there is an active server for every block, we may begin running
                break

            except Exception as e:
                delay = self.get_retry_delay(attempt_no)
                logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
                traceback_level = logging.DEBUG if str(e) else logging.WARNING
                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                time.sleep(delay)

    def __len__(self):
        return len(self.block_uids)

    @property
    def is_alive(self):
        return self._thread.is_alive

    @property
    def ready(self) -> threading.Event:
        return self._thread.ready

    @property
    def block_uids(self):
        return self.sequence_info.block_uids

    @property
    def rpc_info(self):
        """Return the rpc_info queried from one of the servers that hold the first block"""
        if self._rpc_info is None:
            for attempt_no in itertools.count():
                try:
                    self._update()
                    peer_id, _ = random.choice(list(self.sequence_info.block_infos[0].servers.items()))
                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
                    outputs = RemoteExpertWorker.run_coroutine(
                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
                    )
                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
                    break
                except Exception as e:
                    delay = self.get_retry_delay(attempt_no)
                    logger.warning(
                        f"Caught exception when gathering information from peer {peer_id} "
                        f"(retry in {delay:.0f} sec): {repr(e)}"
                    )
                    traceback_level = logging.DEBUG if str(e) else logging.WARNING
                    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                    time.sleep(delay)

        return self._rpc_info

    def get_retry_delay(self, attempt_no: int) -> float:
        if attempt_no == 0:
            return 0
        return self.min_backoff * 2 ** (attempt_no - 1)

    def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
        """
        :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
        :param args: request-specific inputs, typically block uids and input tensors
        :param kwargs: additional request context, such as remote peer ID
        :returns: msgpack-serializable metadata dict that will be passed alongside a given request
        """
        return dict(points=self.policy.get_points(protocol, *args, **kwargs))

    def shutdown(self):
        self._thread.shutdown()


class _SequenceManagerUpdateThread(threading.Thread):
    def __init__(self, update_period: float, ref_update_manager: WeakMethod):
        super().__init__(daemon=True)
        self.ref_update_manager = ref_update_manager
        self.ready = threading.Event()
        self.trigger = threading.Event()
        self.last_update_time = -float("inf")
        self.update_period = update_period
        self.should_shutdown = False

    def run(self) -> None:
        while not self.should_shutdown:
            self.trigger.wait(max(0.0, min(self.update_period, time.perf_counter() - self.last_update_time)))

            if self.should_shutdown:
                logger.debug(f"{self.__class__.__name__} is shutting down")
                break

            update_manager = self.ref_update_manager()
            if update_manager is None:
                logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists")
                break

            try:
                update_manager()
                self.trigger.clear()
            except Exception as e:
                logger.exception(e)
            finally:
                del update_manager

        logger.debug(f"{self.__class__.__name__} thread exited")

    def shutdown(self, timeout: Optional[float] = None):
        self.should_shutdown = True
        self.trigger.set()
        self.join(timeout)

    def __del__(self):
        if self.is_alive():
            self.shutdown()


class MissingBlocksError(Exception):
    def __repr__(self):
        return self.args[0]
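A detail worth flagging in `_SequenceManagerUpdateThread`: it holds its owner only through a `WeakMethod`, so a `RemoteSequenceManager` that goes out of scope can be garbage-collected, and its thread exits on the next wakeup. This is what lets the thread created in `get_remote_module` shut down when the module is destroyed, per the checklist. A minimal standalone sketch of that pattern, with illustrative names:

```python
# Sketch of the WeakMethod pattern used above: the worker thread holds only a
# weak reference to its owner's method, so dropping the last strong reference
# to the owner lets it be collected and makes the thread exit.
import threading
import time
from weakref import WeakMethod

class Worker(threading.Thread):
    def __init__(self, ref_update: WeakMethod, period: float = 0.1):
        super().__init__(daemon=True)
        self.ref_update, self.period = ref_update, period
        self.should_shutdown = False

    def run(self):
        while not self.should_shutdown:
            update = self.ref_update()  # try to upgrade the weak reference
            if update is None:
                break  # the owner was garbage-collected; exit quietly
            update()
            del update  # don't keep the owner alive while sleeping
            time.sleep(self.period)

class Owner:
    def __init__(self):
        self._thread = Worker(WeakMethod(self._update))
        self._thread.start()

    def _update(self):
        pass  # this is where fresh routing info would be fetched

owner = Owner()
time.sleep(0.25)
del owner  # no strong refs remain; the worker exits on its next wakeup
```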
@@ -1,179 +0,0 @@
from __future__ import annotations

import random
import threading
from typing import List, Optional, Sequence, Tuple, Union

from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import petals.dht_utils
from petals.client.spending_policy import NoSpendingPolicy
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteSequenceManager:
    """
    Keeps and updates the meta-information about which peers host which blocks.
    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
    """

    def __init__(
        self,
        dht: DHT,
        block_uids: Sequence[ModuleUID],
        p2p: P2P,
        max_retries: int = 3,
        request_timeout: float = 20,
        min_backoff: float = 1,
    ):
        assert len(block_uids) > 0, "Sequences must contain at least one block"
        self.dht, self.p2p = dht, p2p
        self.block_uids: List[ModuleUID] = list(block_uids)
        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
        self.last_update_time: DHTExpiration = -float("inf")
        self.max_retries = max_retries
        self.request_timeout, self.min_backoff = request_timeout, min_backoff
        self._rpc_info = None
        self.lock_changes = threading.Lock()
        self.policy = NoSpendingPolicy()
        self.update_()

        for uid, info in zip(self.block_uids, self.block_infos):
            assert info is not None, f"Found no remote peers for block {uid}"
        assert self.spans_by_priority and self.spans_containing_block

    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
        """
        Form a sequence of remote servers that collectively serve all consecutive layers

        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
        """
        end_index = end_index if end_index is not None else len(self.block_uids)
        span_sequence = []
        current_index = start_index
        while current_index < end_index:
            candidate_spans = self.spans_containing_block[current_index]
            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing

            assert chosen_span.start <= current_index < chosen_span.end
            span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
            current_index = chosen_span.end

        return span_sequence

    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
        """Get a RemoteSequenceManager for a sub-sequence of blocks"""
        assert isinstance(ix, (int, slice))
        if not isinstance(ix, slice):
            ix = slice(int(ix), int(ix) + 1, 1)
        with self.lock_changes:
            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
            subseq.block_infos = self.block_infos[ix]
            subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
            subseq.last_update_time = self.last_update_time
        return subseq

    def update_(self):
        with self.lock_changes:
            self.update_block_infos_()
            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)

    def update_block_infos_(self):
        new_block_infos = petals.dht_utils.get_remote_module_infos(
            self.dht, self.block_uids, expiration_time=float("inf")
        )
        assert len(new_block_infos) == len(self.block_uids)
        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
            if info is None:
                logger.warning(f"Found no block info for block {uid}")
                continue
            if not isinstance(info, RemoteModuleInfo):
                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
            if not info.servers:
                logger.warning(f"Found no active peers for block {uid}")
            if info.uid != uid:
                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
            self.block_infos[block_index] = info

    @staticmethod
    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
        closed_spans = []
        active_spans = {}
        for block_index, info in enumerate(block_infos):
            if info is not None:
                for peer_id, server in info.servers.items():
                    if server.state != ServerState.ONLINE:
                        continue
                    if peer_id not in active_spans:
                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
                    else:  # peer_id in active_spans
                        active_spans[peer_id].end = block_index + 1

            for peer_id in list(active_spans.keys()):
                if (
                    info is None
                    or peer_id not in info.servers
                    or info.servers[peer_id].state != ServerState.ONLINE
                    or block_index == len(block_infos) - 1
                ):
                    closed_spans.append(active_spans.pop(peer_id))
        assert not active_spans, f"spans: {active_spans}"

        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)

        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
        for span in closed_spans:
            for block_index in range(span.start, span.end):
                spans_containing_block[block_index].append(span)

        return closed_spans, spans_containing_block

    def __len__(self):
        return len(self.block_uids)

    @property
    def rpc_info(self):
        """Return the rpc_info queried from one of the servers that hold the first block"""
        if self._rpc_info is None:
            retries = 0
            for i in range(self.max_retries):
                try:
                    self.update_()
                    peer_id = random.choice(list(self.block_infos[0].servers.keys()))
                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
                    outputs = RemoteExpertWorker.run_coroutine(
                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
                    )
                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
                    break
                except Exception as e:
                    retries += 1
                    if retries >= self.max_retries:
                        raise e
                    else:
                        logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
        return self._rpc_info

    def get_retry_delay(self, attempt_no: int) -> float:
        if attempt_no == 0:
            return 0
        return self.min_backoff * 2 ** (attempt_no - 1)

    def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[bytes]:
        """
        :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
        :param args: request-specific inputs, typicall block uids and input tensors
        :param kwargs: additional request context, such as remote peer ID
        :returns: msgpack-serialized metadata dict that will be passed alongside a given request
        """
        return MSGPackSerializer.dumps(dict(points=self.policy.get_points(protocol, *args, **kwargs)))
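The deleted class above retried `rpc_info` a fixed `max_retries` times and then raised; the replacement retries indefinitely with exponential backoff. A quick check of the schedule implied by `get_retry_delay` (logic copied from the new code; the default `min_backoff = 1.0` is assumed):

```python
# The backoff schedule from the new RemoteSequenceManager: retry immediately
# once, then sleep min_backoff * 2 ** (attempt_no - 1) seconds.
def get_retry_delay(attempt_no: int, min_backoff: float = 1.0) -> float:
    if attempt_no == 0:
        return 0
    return min_backoff * 2 ** (attempt_no - 1)

print([get_retry_delay(n) for n in range(6)])  # [0, 1.0, 2.0, 4.0, 8.0, 16.0]
```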
@@ -0,0 +1,54 @@
import threading
import time

import pytest
import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from test_utils import *

from petals.client import RemoteSequenceManager, RemoteSequential
from petals.client.remote_model import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


@pytest.mark.forked
def test_sequence_manager_shutdown():
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    sequential = RemoteSequential(config, dht)
    shutdown_evt = threading.Event()

    # wrap RemoteSequential in a sequence manager that signals when it is shut down
    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
    sequential = RemoteSequential(
        config,
        dht,
        sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True),
    )

    assert sequential.sequence_manager.is_alive()
    assert sequential.sequence_manager._thread.ready.is_set()
    assert not shutdown_evt.is_set()
    sequential(torch.randn(1, 2, config.hidden_size))

    sequential.sequence_manager.shutdown()
    del sequential
    time.sleep(1)

    assert shutdown_evt.is_set()


class TestSequenceManager(RemoteSequenceManager):
    """A sequence manager that signals if it was shut down"""

    def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):
        super().__init__(*args, **kwargs)
        self._was_shut_down = _was_shut_down

    def shutdown(self):
        super().shutdown()
        assert not self.is_alive()
        self._was_shut_down.set()