From a2066a4096b5ce776ebdcda57c3dd2a994356047 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 1 Dec 2022 10:25:55 +0300 Subject: [PATCH] Optimize RemoteSequenceManager (#106) - [x] made RemoteSequenceManager into a background thread that pre-fetches information instead of running just in time - [x] moved routing-related stuff to petals.client.routing - [x] extracted remote peer routing information to RemoteSequenceInfo - [x] made sure that the code survives continued use (e.g. one hour) - [x] updated every spot where update_ is called manually - [x] modified get_sequence to check that the thread is alive, warn if not - [x] removed max_retries, switched rpc_info to exponential backoff - [x] fixed a bug that causes RemoteSeq* to lose user-defined hyperparameters (e.g. timeout) upon subsequencing (sequential[3:5]) - [x] moved client-side points strategy to client.routing - [x] ensured that RemoteSequenceManager thread created in get_remote_module properly shuts down when the module is destroyed - [x] resolved minor affected todos - [x] modified tests to no longer use PYTHONPATH - [x] worked around protocol error in rpc_info Co-authored-by: Aleksandr Borzunov Co-authored-by: Artem Chumachenko --- .github/workflows/run-tests.yaml | 2 +- src/petals/client/__init__.py | 4 +- src/petals/client/inference_session.py | 14 +- src/petals/client/remote_model.py | 2 +- src/petals/client/remote_sequential.py | 15 +- src/petals/client/routing/__init__.py | 1 + src/petals/client/routing/sequence_info.py | 102 +++++++ src/petals/client/routing/sequence_manager.py | 265 ++++++++++++++++++ .../client/{ => routing}/spending_policy.py | 0 src/petals/client/sequence_manager.py | 179 ------------ src/petals/client/sequential_autograd.py | 11 +- src/petals/dht_utils.py | 11 +- src/petals/server/block_selection.py | 6 +- tests/test_remote_sequential.py | 16 +- tests/test_sequence_manager.py | 54 ++++ 15 files changed, 463 insertions(+), 219 deletions(-) create mode 100644 src/petals/client/routing/__init__.py create mode 100644 src/petals/client/routing/sequence_info.py create mode 100644 src/petals/client/routing/sequence_manager.py rename src/petals/client/{ => routing}/spending_policy.py (100%) delete mode 100644 src/petals/client/sequence_manager.py create mode 100644 tests/test_sequence_manager.py diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 4bf131f..acab39c 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -104,7 +104,7 @@ jobs: kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init - PYTHONPATH=.
pytest tests --durations=0 --durations-min=1.0 -v + pytest tests --durations=0 --durations-min=1.0 -v kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests diff --git a/src/petals/client/__init__.py b/src/petals/client/__init__.py index d1e16ae..93fc8a6 100644 --- a/src/petals/client/__init__.py +++ b/src/petals/client/__init__.py @@ -1,5 +1,5 @@ from petals.client.inference_session import InferenceSession from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock -from petals.client.sequence_manager import RemoteSequenceManager -from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase +from petals.client.routing.sequence_manager import RemoteSequenceManager +from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 6490d2a..100f704 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -20,7 +20,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.p2p import StubBase from hivemind.proto import runtime_pb2 -from petals.client.sequence_manager import RemoteSequenceManager +from petals.client.routing.sequence_manager import RemoteSequenceManager from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, is_dummy @@ -44,14 +44,14 @@ class _ServerInferenceSession: *, timeout: float, max_length: int, - points: int = 0, + **metadata, ): self.uid, self.rpc_info = uid, rpc_info self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter self.timeout = timeout - self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points)) + self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, **metadata)) self.stepped = False self.closed = False @@ -162,7 +162,7 @@ class InferenceSession: An interface to a multi-step *inference* session for a sequence of remote transformer blocks """ - def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata): + def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int): self._sequence_manager = sequence_manager self._p2p = p2p self._closed = False @@ -171,7 +171,6 @@ class InferenceSession: self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers self._position = 0 self._max_length = max_length - self._metadata = metadata def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]: server_sessions = [] @@ -179,6 +178,7 @@ class InferenceSession: for span in chosen_spans: stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id) span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end]) + metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id) session = RemoteExpertWorker.run_coroutine( _ServerInferenceSession.create( stub, @@ -186,7 +186,7 @@ class InferenceSession: 
rpc_info=self._sequence_manager.rpc_info, timeout=self._sequence_manager.request_timeout, max_length=self._max_length, - **self._metadata, + **metadata, ) ) server_sessions.append(session) @@ -237,7 +237,7 @@ class InferenceSession: logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}") try: if attempt_no >= 1: - self._sequence_manager.update_() + self._sequence_manager.update(wait=True) if not self._chosen_spans or not self._server_sessions or attempt_no >= 1: # If there is a failed server session, this code closes it self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1]) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index e95ff81..312b1b4 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -36,7 +36,7 @@ class DistributedBloomConfig(BloomConfig): chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU pre_seq_len: int = 0 # a number of tokens for prompt tuning. tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters'] - request_timeout: int = 20 # a number of seconds for waiting result from each node + request_timeout: int = 30 # a number of seconds for waiting result from each node original_register_parameter = nn.Module.register_parameter diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 933edcb..cb53979 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -9,7 +9,7 @@ from torch import nn import petals.client from petals.client.inference_session import InferenceSession -from petals.client.sequence_manager import RemoteSequenceManager +from petals.client.routing.sequence_manager import RemoteSequenceManager from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction from petals.data_structures import UID_DELIMITER from petals.utils.misc import DUMMY @@ -30,7 +30,7 @@ class RemoteSequential(nn.Module): dht_prefix: Optional[str] = None, p2p: Optional[P2P] = None, sequence_manager: Optional[RemoteSequenceManager] = None, - request_timeout: int = 20, + **kwargs, ): super().__init__() self.config = config @@ -39,16 +39,18 @@ class RemoteSequential(nn.Module): self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager) - block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)] + block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)) if sequence_manager is None: logger.debug(f"Creating new sequence manager for block uids: {block_uids}") - self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, request_timeout=request_timeout) + self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, start=True, **kwargs) self.is_subsequence = False else: logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules") + if kwargs: + logger.warning(f"Parameters {kwargs} are ignored because sequence_manager is explicitly provided") self.sequence_manager = sequence_manager - assert isinstance(sequence_manager.block_uids, list) - self.is_subsequence = self.sequence_manager.block_uids != block_uids + assert isinstance(sequence_manager.sequence_info.block_uids, tuple) + self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids def 
forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY): outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager) @@ -81,7 +83,6 @@ class RemoteSequential(nn.Module): return len(self.sequence_manager) def inference_session(self, **kwargs) -> InferenceSession: - self.sequence_manager.update_() return InferenceSession(self.sequence_manager, self.p2p, **kwargs) def extra_repr(self) -> str: diff --git a/src/petals/client/routing/__init__.py b/src/petals/client/routing/__init__.py new file mode 100644 index 0000000..62885d8 --- /dev/null +++ b/src/petals/client/routing/__init__.py @@ -0,0 +1 @@ +"""Client-side functions responsible for choosing the best server.""" diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py new file mode 100644 index 0000000..a756003 --- /dev/null +++ b/src/petals/client/routing/sequence_info.py @@ -0,0 +1,102 @@ +import dataclasses +import time +from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar + +from hivemind import get_logger, use_hivemind_log_handler + +from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState + +use_hivemind_log_handler("in_root_logger") +logger = get_logger(__file__) + + +T = TypeVar("T") + + +@dataclasses.dataclass +class RemoteSequenceInfo: + """ + A dataclass that stores general information about which servers hold any given layer; + - updated by RemoteSequenceManager in a background thread + - accessed by routing strategies in .on_update + :note: this class should *not* be modified by RoutingStrategy.on_update to avoid interference between strategies; + Any metadata specific to one routing strategy should be stored inside that strategy. Any information that + is used by most routing strategies should be moved from said strategies to this class. + """ + + block_uids: Tuple[ModuleUID, ...] + block_infos: Tuple[RemoteModuleInfo, ...] # note: the contents of RemoteModuleInfo can and will be updated + spans_by_priority: List[RemoteSpanInfo] + spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
+ last_updated_time: float + + @classmethod + def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T: + block_uids = tuple(block_uids) + empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids) + empty_spans = tuple([] for _ in range(len(block_uids))) + return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf")) + + def __getitem__(self, ix: slice): + assert isinstance(ix, slice) + block_uids, block_infos = self.block_uids[ix], self.block_infos[ix] + spans_by_priority, spans_containing_block = self.compute_spans(block_infos) + return RemoteSequenceInfo( + block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time + ) + + def __len__(self): + return len(self.block_uids) + + def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]): + assert len(new_block_infos) == len(self.block_uids) + for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)): + if info is None: + logger.debug(f"Found no block info for block {uid}") + continue + if not isinstance(info, RemoteModuleInfo): + logger.warning(f"Unexpected dht entry type for {uid}: {info}") + continue + if not info.servers: + logger.debug(f"Found no active peers for block {uid}") + continue + if info.uid != uid: + logger.warning(f"The DHT entry for {uid} actually points to {info.uid}") + continue + self.block_infos[block_index].servers = info.servers + + self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos) + self.last_updated_time = time.perf_counter() + + @staticmethod + def compute_spans(block_infos: Sequence[RemoteModuleInfo]): + closed_spans = [] + active_spans = {} + for block_index, info in enumerate(block_infos): + if info is not None: + for peer_id, server in info.servers.items(): + if server.state != ServerState.ONLINE: + continue + if peer_id not in active_spans: + active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id) + else: # peer_id in active_spans + active_spans[peer_id].end = block_index + 1 + + for peer_id in list(active_spans.keys()): + if ( + info is None + or peer_id not in info.servers + or info.servers[peer_id].state != ServerState.ONLINE + or block_index == len(block_infos) - 1 + ): + closed_spans.append(active_spans.pop(peer_id)) + assert not active_spans, f"spans: {active_spans}" + + closed_spans.sort(key=lambda span: span.end - span.start, reverse=True) + + spans_containing_block = tuple(list() for _ in range(len(block_infos))) + for span in closed_spans: + for block_index in range(span.start, span.end): + spans_containing_block[block_index].append(span) + + return closed_spans, spans_containing_block diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py new file mode 100644 index 0000000..0c684f9 --- /dev/null +++ b/src/petals/client/routing/sequence_manager.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +import itertools +import logging +import random +import threading +import time +from typing import Any, Dict, List, Optional, Sequence, Union +from weakref import WeakMethod + +from hivemind import DHT, P2P, MSGPackSerializer +from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker +from hivemind.proto import runtime_pb2 +from hivemind.utils.logging import get_logger, use_hivemind_log_handler + +import petals.dht_utils +from petals.client.routing.sequence_info import RemoteSequenceInfo +from petals.client.routing.spending_policy import
NoSpendingPolicy +from petals.data_structures import ModuleUID, RemoteSpanInfo +from petals.server.handler import TransformerConnectionHandler + +use_hivemind_log_handler("in_root_logger") +logger = get_logger(__file__) + + +class RemoteSequenceManager: + """ + Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks. + TL;DR: it tells you which peers you should ask to get a specific layer. It is used in RemoteSequential. + When created, RemoteSequenceManager looks up which servers serve necessary layers by reading from DHT. + Using this information, the sequence manager can form sequences of servers that collectively have the full sequence. + To form such a sequence, call .make_sequence with the appropriate optimization policy (see the make_sequence docstring). + + :param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks + :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers + :param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p()) + :param update_period: by default, refresh DHT information once in this many seconds + :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests + :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1) + :param sequence_info: optionally, specify pre-generated sequence info; by default, create a new one using dht + :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time + :param start: start the background thread (see the note below). If False, you will need to start it manually. + :note: RemoteSequenceManager takes up some CPU and network I/O to operate in the background. It is recommended to avoid + running redundant sequence managers for the same set of layers. + + """ + + def __init__( + self, + dht: DHT, + block_uids: Sequence[ModuleUID], + p2p: P2P, + update_period: float = 30, + request_timeout: float = 30, + min_backoff: float = 1, + sequence_info: Optional[RemoteSequenceInfo] = None, + rpc_info: Optional[dict] = None, + *, # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below) + start: bool, + ): + assert len(block_uids) > 0, "Sequences must contain at least one block" + self.dht, self.p2p = dht, p2p + self.request_timeout, self.min_backoff = request_timeout, min_backoff + self.lock_changes = threading.Lock() + self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update)) + self.policy = NoSpendingPolicy() + self._rpc_info = rpc_info + + if sequence_info is None: + self.sequence_info = RemoteSequenceInfo.make_empty(block_uids) + self.update(wait=False) + else: + self.sequence_info = sequence_info + assert block_uids == sequence_info.block_uids + self._thread.ready.set() # no need to await the first dht fetch + + if start: + self.run_in_background() + + def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None: + """ + Starts the updater thread in the background. If await_ready, this method will wait until the sequence manager + is ready to process incoming requests or for :timeout: seconds max.
+ """ + self._thread.start() + if await_ready: + self._thread.ready.wait(timeout) + + def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]: + """ + Form a sequence of remote servers that collectively serve all consecutive layers + + :param start_index: optional index of the first module in a sequence, default = the first of block_uids + :param end_index: optional index of the last module (non-inclusive), default = after last of block uids + """ + if not self.is_alive(): + logger.error("Using a sequence manager that is not running: it has either crashed or never started") + if not self.ready.is_set(): + logger.warning("Remote SequenceManager is still searching for routes, waiting for it to become ready") + self.ready.wait() + + end_index = end_index if end_index is not None else len(self) + span_sequence = [] + current_index = start_index + while current_index < end_index: + candidate_spans = self.sequence_info.spans_containing_block[current_index] + chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing + + assert chosen_span.start <= current_index < chosen_span.end + span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id)) + current_index = chosen_span.end + + return span_sequence + + def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager: + """Get a RemoteSequenceManager for a sub-sequence of blocks""" + assert isinstance(ix, (int, slice)) + if not isinstance(ix, slice): + ix = slice(int(ix), int(ix) + 1, 1) + return type(self)( + self.dht, + self.block_uids[ix], + self.p2p, + update_period=self._thread.update_period, + request_timeout=self.request_timeout, + min_backoff=self.min_backoff, + sequence_info=self.sequence_info[ix], + rpc_info=self._rpc_info, + start=True, + ) + + def update(self, *, wait: bool): + """Run an asynchronous update in background as soon as possible""" + self.ready.clear() # TODO this should be a separate event + self._thread.trigger.set() + if wait: + self.ready.wait() + + def _update(self): + """Perform an immediate and synchronous refresh, may take time""" + for attempt_no in itertools.count(): + try: + new_block_infos = petals.dht_utils.get_remote_module_infos( + self.dht, self.block_uids, expiration_time=float("inf") + ) + with self.lock_changes: + self.sequence_info.update_(new_block_infos) + missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]] + if missing_blocks: + raise MissingBlocksError(f"no servers holding blocks {missing_blocks}") + self.ready.set() # if there is an active server for every block, we may begin running + break + + except Exception as e: + delay = self.get_retry_delay(attempt_no) + logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)") + traceback_level = logging.DEBUG if str(e) else logging.WARNING + logger.log(traceback_level, "See detailed traceback below:", exc_info=True) + time.sleep(delay) + + def __len__(self): + return len(self.block_uids) + + @property + def is_alive(self): + return self._thread.is_alive + + @property + def ready(self) -> threading.Event: + return self._thread.ready + + @property + def block_uids(self): + return self.sequence_info.block_uids + + @property + def rpc_info(self): + """Return the rpc_info queried from one of the servers that hold the first block""" + if self._rpc_info is None: + for attempt_no in itertools.count(): + try: + self._update() + peer_id, _ = 
random.choice(list(self.sequence_info.block_infos[0].servers.items())) + stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id) + outputs = RemoteExpertWorker.run_coroutine( + stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0])) + ) + self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info) + break + except Exception as e: + delay = self.get_retry_delay(attempt_no) + logger.warning( + f"Caught exception when gathering information from peer {peer_id} " + f"(retry in {delay:.0f} sec): {repr(e)}" + ) + traceback_level = logging.DEBUG if str(e) else logging.WARNING + logger.log(traceback_level, "See detailed traceback below:", exc_info=True) + time.sleep(delay) + + return self._rpc_info + + def get_retry_delay(self, attempt_no: int) -> float: + if attempt_no == 0: + return 0 + return self.min_backoff * 2 ** (attempt_no - 1) + + def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]: + """ + :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" + :param args: request-specific inputs, typicall block uids and input tensors + :param kwargs: additional request context, such as remote peer ID + :returns: msgpack-serialized metadata dict that will be passed alongside a given request + """ + return dict(points=self.policy.get_points(protocol, *args, **kwargs)) + + def shutdown(self): + self._thread.shutdown() + + +class _SequenceManagerUpdateThread(threading.Thread): + def __init__(self, update_period: float, ref_update_manager: WeakMethod): + super().__init__(daemon=True) + self.ref_update_manager = ref_update_manager + self.ready = threading.Event() + self.trigger = threading.Event() + self.last_update_time = -float("inf") + self.update_period = update_period + self.should_shutdown = False + + def run(self) -> None: + while not self.should_shutdown: + self.trigger.wait(max(0.0, min(self.update_period, time.perf_counter() - self.last_update_time))) + + if self.should_shutdown: + logger.debug(f"{self.__class__.__name__} is shutting down") + break + + update_manager = self.ref_update_manager() + if update_manager is None: + logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists") + break + + try: + update_manager() + self.trigger.clear() + except Exception as e: + logger.exception(e) + finally: + del update_manager + + logger.debug(f"{self.__class__.__name__} thread exited") + + def shutdown(self, timeout: Optional[float] = None): + self.should_shutdown = True + self.trigger.set() + self.join(timeout) + + def __del__(self): + if self.is_alive(): + self.shutdown() + + +class MissingBlocksError(Exception): + def __repr__(self): + return self.args[0] diff --git a/src/petals/client/spending_policy.py b/src/petals/client/routing/spending_policy.py similarity index 100% rename from src/petals/client/spending_policy.py rename to src/petals/client/routing/spending_policy.py diff --git a/src/petals/client/sequence_manager.py b/src/petals/client/sequence_manager.py deleted file mode 100644 index 89ade87..0000000 --- a/src/petals/client/sequence_manager.py +++ /dev/null @@ -1,179 +0,0 @@ -from __future__ import annotations - -import random -import threading -from typing import List, Optional, Sequence, Tuple, Union - -from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer -from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.proto import runtime_pb2 -from hivemind.utils.logging import get_logger, use_hivemind_log_handler - -import petals.dht_utils -from 
petals.client.spending_policy import NoSpendingPolicy -from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState -from petals.server.handler import TransformerConnectionHandler - -use_hivemind_log_handler("in_root_logger") -logger = get_logger(__file__) - - -class RemoteSequenceManager: - """ - Keeps and updates the meta-information about which peers host which blocks. - In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc. - """ - - def __init__( - self, - dht: DHT, - block_uids: Sequence[ModuleUID], - p2p: P2P, - max_retries: int = 3, - request_timeout: float = 20, - min_backoff: float = 1, - ): - assert len(block_uids) > 0, "Sequences must contain at least one block" - self.dht, self.p2p = dht, p2p - self.block_uids: List[ModuleUID] = list(block_uids) - self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids) - self.spans_by_priority: List[RemoteSpanInfo] = [] # sorted from best to worst - self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids))) - self.last_update_time: DHTExpiration = -float("inf") - self.max_retries = max_retries - self.request_timeout, self.min_backoff = request_timeout, min_backoff - self._rpc_info = None - self.lock_changes = threading.Lock() - self.policy = NoSpendingPolicy() - self.update_() - - for uid, info in zip(self.block_uids, self.block_infos): - assert info is not None, f"Found no remote peers for block {uid}" - assert self.spans_by_priority and self.spans_containing_block - - def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]: - """ - Form a sequence of remote servers that collectively serve all consecutive layers - - :param start_index: optional index of the first module in a sequence, default = the first of block_uids - :param end_index: optional index of the last module (non-inclusive), default = after last of block uids - """ - end_index = end_index if end_index is not None else len(self.block_uids) - span_sequence = [] - current_index = start_index - while current_index < end_index: - candidate_spans = self.spans_containing_block[current_index] - chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing - - assert chosen_span.start <= current_index < chosen_span.end - span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id)) - current_index = chosen_span.end - - return span_sequence - - def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager: - """Get a RemoteSequenceManager for a sub-sequence of blocks""" - assert isinstance(ix, (int, slice)) - if not isinstance(ix, slice): - ix = slice(int(ix), int(ix) + 1, 1) - with self.lock_changes: - subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p) - subseq.block_infos = self.block_infos[ix] - subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos) - subseq.last_update_time = self.last_update_time - return subseq - - def update_(self): - with self.lock_changes: - self.update_block_infos_() - self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos) - - def update_block_infos_(self): - new_block_infos = petals.dht_utils.get_remote_module_infos( - self.dht, self.block_uids, expiration_time=float("inf") - ) - assert len(new_block_infos) == len(self.block_uids) - for block_index, (uid, info) in 
enumerate(zip(self.block_uids, new_block_infos)): - if info is None: - logger.warning(f"Found no block info for block {uid}") - continue - if not isinstance(info, RemoteModuleInfo): - logger.warning(f"Unexpected dht entry type for {uid}: {info}") - if not info.servers: - logger.warning(f"Found no active peers for block {uid}") - if info.uid != uid: - logger.warning(f"The DHT entry for {uid} actually points to {info.uid}") - self.block_infos[block_index] = info - - @staticmethod - def compute_spans(block_infos: Sequence[RemoteModuleInfo]): - closed_spans = [] - active_spans = {} - for block_index, info in enumerate(block_infos): - if info is not None: - for peer_id, server in info.servers.items(): - if server.state != ServerState.ONLINE: - continue - if peer_id not in active_spans: - active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id) - else: # peer_id in active_spans - active_spans[peer_id].end = block_index + 1 - - for peer_id in list(active_spans.keys()): - if ( - info is None - or peer_id not in info.servers - or info.servers[peer_id].state != ServerState.ONLINE - or block_index == len(block_infos) - 1 - ): - closed_spans.append(active_spans.pop(peer_id)) - assert not active_spans, f"spans: {active_spans}" - - closed_spans.sort(key=lambda span: span.end - span.start, reverse=True) - - spans_containing_block = tuple(list() for _ in range(len(block_infos))) - for span in closed_spans: - for block_index in range(span.start, span.end): - spans_containing_block[block_index].append(span) - - return closed_spans, spans_containing_block - - def __len__(self): - return len(self.block_uids) - - @property - def rpc_info(self): - """Return the rpc_info queried from one of the servers that hold the first block""" - if self._rpc_info is None: - retries = 0 - for i in range(self.max_retries): - try: - self.update_() - peer_id = random.choice(list(self.block_infos[0].servers.keys())) - stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id) - outputs = RemoteExpertWorker.run_coroutine( - stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0])) - ) - self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info) - break - except Exception as e: - retries += 1 - if retries >= self.max_retries: - raise e - else: - logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True) - return self._rpc_info - - def get_retry_delay(self, attempt_no: int) -> float: - if attempt_no == 0: - return 0 - return self.min_backoff * 2 ** (attempt_no - 1) - - def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[bytes]: - """ - :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" - :param args: request-specific inputs, typicall block uids and input tensors - :param kwargs: additional request context, such as remote peer ID - :returns: msgpack-serialized metadata dict that will be passed alongside a given request - """ - return MSGPackSerializer.dumps(dict(points=self.policy.get_points(protocol, *args, **kwargs))) diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 76b5c24..21ac4bc 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -8,11 +8,12 @@ from collections import deque from typing import List, Optional, Sequence, Tuple import torch +from hivemind import MSGPackSerializer from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.utils.logging import get_logger from 
petals.client.remote_forward_backward import run_remote_backward, run_remote_forward -from petals.client.sequence_manager import RemoteSequenceManager +from petals.client.routing.sequence_manager import RemoteSequenceManager from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, is_dummy @@ -58,7 +59,7 @@ async def sequential_forward( logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}") try: if attempt_no >= 1: - sequence_manager.update_() + sequence_manager._update() if not sequences or attempt_no >= 1: sequences = deque(sequence_manager.make_sequence(block_idx, end_index)) # make_sequence() could return a longer sequence @@ -78,7 +79,7 @@ async def sequential_forward( sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.request_timeout, - metadata=metadata, + metadata=MSGPackSerializer.dumps(metadata), ) assert isinstance(outputs, torch.Tensor) @@ -136,7 +137,7 @@ async def sequential_backward( logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}") try: if attempt_no >= 1: - sequence_manager.update_() + sequence_manager.update(wait=True) _, backup_inputs, backup_sequences = await sequential_forward( inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end ) @@ -162,7 +163,7 @@ async def sequential_backward( grad_outputs, prompts[span.start : span.end], timeout=sequence_manager.request_timeout, - metadata=metadata, + metadata=MSGPackSerializer.dumps(metadata), ) grad_outputs = [grad_outputs] grad_prompts_reversed.extend(span_grad_prompts) diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py index c8fd7af..71a6cd8 100644 --- a/src/petals/dht_utils.py +++ b/src/petals/dht_utils.py @@ -94,7 +94,7 @@ async def _get_remote_sequence( ) -> petals.client.RemoteSequential: uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)] p2p = await dht.replicate_p2p() - manager = petals.client.RemoteSequenceManager(dht, uids, p2p) + manager = petals.client.RemoteSequenceManager(dht, uids, p2p, start=True) return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager) @@ -125,7 +125,7 @@ async def _get_remote_module( single_uid = isinstance(uid_or_uids, ModuleUID) uids = [uid_or_uids] if single_uid else uid_or_uids p2p = await dht.replicate_p2p() - managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids) + managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p, start=True) for uid in uids) modules = [ petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers @@ -134,14 +134,13 @@ async def _get_remote_module( def get_remote_module_infos( - dht: DHT, - uid_or_uids: Union[ModuleUID, List[ModuleUID]], - expiration_time: Optional[DHTExpiration] = None, + dht: DHT, uid_or_uids: Union[ModuleUID, Sequence[ModuleUID]], expiration_time: Optional[DHTExpiration] = None ) -> List[Optional[RemoteModuleInfo]]: single_uid = isinstance(uid_or_uids, ModuleUID) uids = [uid_or_uids] if single_uid else uid_or_uids infos = dht.run_coroutine( - partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future=False + partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), + return_future=False, ) return infos[0] if single_uid else infos diff --git a/src/petals/server/block_selection.py b/src/petals/server/block_selection.py index 98842b6..203b781 100644 --- 
a/src/petals/server/block_selection.py +++ b/src/petals/server/block_selection.py @@ -25,7 +25,7 @@ class Span: self.start, self.end = new_start, new_start + self.length -def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]: +def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]: spans = {} throughputs = np.zeros(len(module_infos)) for block, module in enumerate(module_infos): @@ -56,7 +56,7 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int: def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]: - _, throughputs = _compute_spans(module_infos) + _, throughputs = compute_spans(module_infos) start = _choose_best_start(throughputs, num_blocks) return list(range(start, start + num_blocks)) @@ -67,7 +67,7 @@ def should_choose_other_blocks( if balance_quality > 1.0: return True # Forces rebalancing on each check (may be used for debugging purposes) - spans, throughputs = _compute_spans(module_infos) + spans, throughputs = compute_spans(module_infos) initial_throughput = throughputs.min() eps = 1e-3 diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 4237188..6b23a00 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -1,6 +1,6 @@ import pytest import torch -from hivemind import DHT, BatchTensorDescriptor, MSGPackSerializer, get_logger, use_hivemind_log_handler +from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler from hivemind.proto import runtime_pb2 from test_utils import * @@ -48,7 +48,7 @@ def test_remote_sequential(): # test RemoteSequential with lossy compression block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] lossy_sequential = RemoteSequential( - config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p) + config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p, start=True) ) test_inputs.grad = None @@ -58,7 +58,8 @@ def test_remote_sequential(): assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used" assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used" assert abs(approx_outputs - full_outputs).mean() < 0.01 - assert abs(test_inputs.grad - full_grad).mean() < 0.3 + absmax = abs(full_grad).max() + assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.01 class DummyCustomSequenceManager(RemoteSequenceManager): @@ -73,13 +74,12 @@ class DummyCustomSequenceManager(RemoteSequenceManager): return rpc_info def get_request_metadata(self, protocol: str, *args, **kwargs): + metadata = super().get_request_metadata(protocol, *args, **kwargs) if protocol == "rpc_forward": - return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,))) + metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,) elif protocol == "rpc_backward": - return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,))) - else: - assert protocol == "rpc_inference" - return super().get_request_metadata(protocol, *args, **kwargs) + metadata["output_compression"] = (runtime_pb2.CompressionType.BLOCKWISE_8BIT,) + return metadata @pytest.mark.forked diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py new file mode 100644 index 
0000000..e9c7239 --- /dev/null +++ b/tests/test_sequence_manager.py @@ -0,0 +1,54 @@ +import threading +import time + +import pytest +import torch +from hivemind import DHT, get_logger, use_hivemind_log_handler +from test_utils import * + +from petals.client import RemoteSequenceManager, RemoteSequential +from petals.client.remote_model import DistributedBloomConfig +from petals.data_structures import UID_DELIMITER + +use_hivemind_log_handler("in_root_logger") +logger = get_logger(__file__) + + +@pytest.mark.forked +def test_sequence_manager_shutdown(): + config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) + dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True) + sequential = RemoteSequential(config, dht) + shutdown_evt = threading.Event() + + # wrap RemoteSequential around a sequence manager that signals when it is shut down + block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] + sequential = RemoteSequential( + config, + dht, + sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True), + ) + + assert sequential.sequence_manager.is_alive() + assert sequential.sequence_manager._thread.ready.is_set() + assert not shutdown_evt.is_set() + sequential(torch.randn(1, 2, config.hidden_size)) + + sequential.sequence_manager.shutdown() + del sequential + time.sleep(1) + + assert shutdown_evt.is_set() + + +class TestSequenceManager(RemoteSequenceManager): + """A sequence manager that signals if it was shut down""" + + def __init__(self, *args, _was_shut_down: threading.Event, **kwargs): + super().__init__(*args, **kwargs) + self._was_shut_down = _was_shut_down + + def shutdown(self): + super().shutdown() + assert not self.is_alive() + self._was_shut_down.set()
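
For reviewers, here is a minimal usage sketch of the client-side API this patch introduces. It is not part of the patch: MODEL_NAME and INITIAL_PEERS are assumed to come from the test fixtures in test_utils (as in the tests above), and the hyperparameter values are illustrative.

```python
# Sketch: lifecycle of the new RemoteSequenceManager with its background updater thread.
from hivemind import DHT
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from test_utils import *  # provides MODEL_NAME and INITIAL_PEERS, as in the test suite

from petals.client import RemoteSequenceManager
from petals.client.remote_model import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER

config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())

block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))

# start=True launches the updater thread and, by default, waits for the first successful DHT fetch
manager = RemoteSequenceManager(dht, block_uids, p2p, update_period=30, request_timeout=30, start=True)

# pick a chain of servers that collectively cover all requested blocks
for span in manager.make_sequence():
    print(f"blocks [{span.start}:{span.end}) -> {span.peer_id}")

# slicing spawns a manager for a sub-sequence; after this patch it inherits
# update_period, request_timeout, min_backoff and the cached rpc_info
sub_manager = manager[3:5]

# retries back off exponentially: 0, min_backoff, 2 * min_backoff, 4 * min_backoff, ...
delays = [manager.get_retry_delay(attempt_no) for attempt_no in range(4)]

# each manager owns its own updater thread; shut both down when no longer needed
sub_manager.shutdown()
manager.shutdown()
dht.shutdown()
```

Because _SequenceManagerUpdateThread only holds a WeakMethod reference to the manager's _update, dropping the last reference to a manager also lets its updater thread exit on its own, which is the behavior the checklist item about get_remote_module refers to and test_sequence_manager_shutdown exercises via an explicit shutdown().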