Optimize RemoteSequenceManager (#106)

- [x] made RemoteSequenceManager a background thread that pre-fetches routing information instead of fetching it just in time (see the usage sketch below)
- [x] moved routing-related code to petals.client.routing
- [x] extracted remote peer routing information into RemoteSequenceInfo
- [x] made sure that the code survives continued use (e.g. one hour)
- [x] updated every call site where update_ was invoked manually
- [x] modified make_sequence to check that the thread is alive and warn if not
- [x] removed max_retries, switched rpc_info to exponential backoff
- [x] fixed a bug that caused RemoteSeq* to lose user-defined hyperparameters (e.g. timeout) upon subsequencing (sequential[3:5])
- [x] moved client-side points strategy to client.routing
- [x] ensured that RemoteSequenceManager thread created in get_remote_module properly shuts down when the module is destroyed
- [x] resolved minor affected todos
- [x] modified tests to no longer use PYTHONPATH
- [x] worked around protocol error in rpc_info


Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Artem Chumachenko <artek.chumak@gmail.com>
Committed by justheuristic (commit a2066a4096, parent 7d859a947b)
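
For orientation (not part of the diff): a minimal sketch of the new sequence manager lifecycle, assuming a running Petals swarm; the peer address and block uids below are placeholders.

    from hivemind import DHT
    from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
    from petals.client.routing.sequence_manager import RemoteSequenceManager

    initial_peers = ["/ip4/127.0.0.1/tcp/31337/p2p/QmExamplePeerID"]  # placeholder multiaddr
    dht = DHT(initial_peers=initial_peers, client_mode=True, start=True)
    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())

    # start=True launches the background thread that pre-fetches routing info from the DHT
    manager = RemoteSequenceManager(dht, ("prefix.0", "prefix.1"), p2p, start=True)
    spans = manager.make_sequence()   # waits until the first DHT fetch is ready, then picks servers
    manager.update(wait=True)         # force a synchronous refresh, e.g. after a server failure
    manager.shutdown()                # stops the background updater thread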

@ -104,7 +104,7 @@ jobs:
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init
PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v
pytest tests --durations=0 --durations-min=1.0 -v
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests

@ -1,5 +1,5 @@
from petals.client.inference_session import InferenceSession
from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from petals.client.sequence_manager import RemoteSequenceManager
from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

@ -20,7 +20,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2
from petals.client.sequence_manager import RemoteSequenceManager
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
@ -44,14 +44,14 @@ class _ServerInferenceSession:
*,
timeout: float,
max_length: int,
points: int = 0,
**metadata,
):
self.uid, self.rpc_info = uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.timeout = timeout
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, **metadata))
self.stepped = False
self.closed = False
@ -162,7 +162,7 @@ class InferenceSession:
An interface to a multi-step *inference* session for a sequence of remote transformer blocks
"""
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int):
self._sequence_manager = sequence_manager
self._p2p = p2p
self._closed = False
@ -171,7 +171,6 @@ class InferenceSession:
self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers
self._position = 0
self._max_length = max_length
self._metadata = metadata
def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
server_sessions = []
@ -179,6 +178,7 @@ class InferenceSession:
for span in chosen_spans:
stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
session = RemoteExpertWorker.run_coroutine(
_ServerInferenceSession.create(
stub,
@ -186,7 +186,7 @@ class InferenceSession:
rpc_info=self._sequence_manager.rpc_info,
timeout=self._sequence_manager.request_timeout,
max_length=self._max_length,
**self._metadata,
**metadata,
)
)
server_sessions.append(session)
@ -237,7 +237,7 @@ class InferenceSession:
logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
try:
if attempt_no >= 1:
self._sequence_manager.update_()
self._sequence_manager.update(wait=True)
if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
# If there is a failed server session, this code closes it
self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])

@ -36,7 +36,7 @@ class DistributedBloomConfig(BloomConfig):
chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
request_timeout: int = 20 # a number of seconds for waiting result from each node
request_timeout: int = 30  # timeout, in seconds, for waiting for a result from each node
original_register_parameter = nn.Module.register_parameter

@ -9,7 +9,7 @@ from torch import nn
import petals.client
from petals.client.inference_session import InferenceSession
from petals.client.sequence_manager import RemoteSequenceManager
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
from petals.utils.misc import DUMMY
@ -30,7 +30,7 @@ class RemoteSequential(nn.Module):
dht_prefix: Optional[str] = None,
p2p: Optional[P2P] = None,
sequence_manager: Optional[RemoteSequenceManager] = None,
request_timeout: int = 20,
**kwargs,
):
super().__init__()
self.config = config
@ -39,16 +39,18 @@ class RemoteSequential(nn.Module):
self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks))
if sequence_manager is None:
logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, request_timeout=request_timeout)
self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, start=True, **kwargs)
self.is_subsequence = False
else:
logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
if kwargs:
logger.warning(f"Parameters {kwargs} are ignored because sequence_manager is explicitly provided")
self.sequence_manager = sequence_manager
assert isinstance(sequence_manager.block_uids, list)
self.is_subsequence = self.sequence_manager.block_uids != block_uids
assert isinstance(sequence_manager.sequence_info.block_uids, tuple)
self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids
def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
@ -81,7 +83,6 @@ class RemoteSequential(nn.Module):
return len(self.sequence_manager)
def inference_session(self, **kwargs) -> InferenceSession:
self.sequence_manager.update_()
return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
def extra_repr(self) -> str:

@ -0,0 +1 @@
"""Client-side functions responsible for choosing the best server, """

@ -0,0 +1,102 @@
import dataclasses
import time
from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
from hivemind import get_logger, use_hivemind_log_handler
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
T = TypeVar("T")
@dataclasses.dataclass
class RemoteSequenceInfo:
"""
A dataclass that stores general information about which servers hold any given layer;
- updated by RemoteSequenceManager in a background thread
- accessed by routing strategies in .on_update
:note: this class should *not* be modified by RoutingStrategy.on_update to avoid interference between strategies;
Any metadata specific to one routing strategy should be stored inside that strategy. Any information that
is used by most routing strategies should be moved from said strategies to this class.
"""
block_uids: Tuple[ModuleUID, ...]
block_infos: Tuple[RemoteModuleInfo, ...] # note: the contents of RemoteModuleInfo can and will be updated
spans_by_priority: List[RemoteSpanInfo]
spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
last_updated_time: float
@classmethod
def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
block_uids = tuple(block_uids)
empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
empty_spans = tuple([] for _ in range(len(block_uids)))
return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf"))
def __getitem__(self, ix: slice):
assert isinstance(ix, slice)
block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
return RemoteSequenceInfo(
block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
)
def __len__(self):
return len(self.block_uids)
def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
assert len(new_block_infos) == len(self.block_uids)
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
if info is None:
logger.debug(f"Found no block info for block {uid}")
continue
if not isinstance(info, RemoteModuleInfo):
logger.warning(f"Unexpected dht entry type for {uid}: {info}")
continue
if not info.servers:
logger.debug(f"Found no active peers for block {uid}")
continue
if info.uid != uid:
logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
continue
self.block_infos[block_index].servers = info.servers
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
self.last_updated_time = time.perf_counter()
@staticmethod
def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
closed_spans = []
active_spans = {}
for block_index, info in enumerate(block_infos):
if info is not None:
for peer_id, server in info.servers.items():
if server.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
else: # peer_id in active_spans
active_spans[peer_id].end = block_index + 1
for peer_id in list(active_spans.keys()):
if (
info is None
or peer_id not in info.servers
or info.servers[peer_id].state != ServerState.ONLINE
or block_index == len(block_infos) - 1
):
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans, f"spans: {active_spans}"
closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
spans_containing_block = tuple(list() for _ in range(len(block_infos)))
for span in closed_spans:
for block_index in range(span.start, span.end):
spans_containing_block[block_index].append(span)
return closed_spans, spans_containing_block
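
Illustrative only (not part of the diff): a toy check of the new RemoteSequenceInfo container; the block uids below are placeholders.

    from petals.client.routing.sequence_info import RemoteSequenceInfo

    block_uids = ("prefix.0", "prefix.1", "prefix.2")    # placeholder uids
    info = RemoteSequenceInfo.make_empty(block_uids)     # empty block_infos, no spans yet
    assert len(info) == 3 and info.last_updated_time == -float("inf")

    # slicing yields a RemoteSequenceInfo for a sub-range of blocks (used when subsequencing, e.g. sequential[3:5])
    sub_info = info[1:3]
    assert sub_info.block_uids == block_uids[1:3]

    # update_(new_block_infos) is called from the sequence manager's background thread:
    # it refreshes block_infos in place and recomputes spans_by_priority / spans_containing_block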

@ -0,0 +1,265 @@
from __future__ import annotations
import itertools
import logging
import random
import threading
import time
from typing import Any, Dict, List, Optional, Sequence, Union
from weakref import WeakMethod
from hivemind import DHT, P2P, MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
import petals.dht_utils
from petals.client.routing.sequence_info import RemoteSequenceInfo
from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.data_structures import ModuleUID, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class RemoteSequenceManager:
"""
Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks.
TL;DR: it tells you which peers you should ask to get a specific layer. It is used in RemoteSequential.
When created, RemoteSequenceManager looks up which servers serve necessary layers by reading from DHT.
Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).
:param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks
:param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
:param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p())
:param update_period: by default, refresh DHT information once in this many seconds
:param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
:param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
:param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
:param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
:param start: start the background thread (see the note below). If false, you will need to start it manually.
:note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
running redundant sequence managers for the same set of layers.
"""
def __init__(
self,
dht: DHT,
block_uids: Sequence[ModuleUID],
p2p: P2P,
update_period: float = 30,
request_timeout: float = 30,
min_backoff: float = 1,
sequence_info: Optional[RemoteSequenceInfo] = None,
rpc_info: Optional[dict] = None,
*, # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
start: bool,
):
assert len(block_uids) > 0, "Sequences must contain at least one block"
self.dht, self.p2p = dht, p2p
self.request_timeout, self.min_backoff = request_timeout, min_backoff
self.lock_changes = threading.Lock()
self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
self.policy = NoSpendingPolicy()
self._rpc_info = rpc_info
if sequence_info is None:
self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
self.update(wait=False)
else:
self.sequence_info = sequence_info
assert block_uids == sequence_info.block_uids
self._thread.ready.set() # no need to await the first dht fetch
if start:
self.run_in_background()
def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
"""
Starts the updater thread in the background. If await_ready, this method will wait until the sequence manager
is ready to process incoming requests or for :timeout: seconds max.
"""
self._thread.start()
if await_ready:
self._thread.ready.wait(timeout)
def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
"""
Form a sequence of remote servers that collectively serve all consecutive layers
:param start_index: optional index of the first module in a sequence, default = the first of block_uids
:param end_index: optional index of the last module (non-inclusive), default = after last of block uids
"""
if not self.is_alive():
logger.error("Using a sequence manager that is not running: it has either crashed or never started")
if not self.ready.is_set():
logger.warning("Remote SequenceManager is still searching for routes, waiting for it to become ready")
self.ready.wait()
end_index = end_index if end_index is not None else len(self)
span_sequence = []
current_index = start_index
while current_index < end_index:
candidate_spans = self.sequence_info.spans_containing_block[current_index]
chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
assert chosen_span.start <= current_index < chosen_span.end
span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
current_index = chosen_span.end
return span_sequence
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
"""Get a RemoteSequenceManager for a sub-sequence of blocks"""
assert isinstance(ix, (int, slice))
if not isinstance(ix, slice):
ix = slice(int(ix), int(ix) + 1, 1)
return type(self)(
self.dht,
self.block_uids[ix],
self.p2p,
update_period=self._thread.update_period,
request_timeout=self.request_timeout,
min_backoff=self.min_backoff,
sequence_info=self.sequence_info[ix],
rpc_info=self._rpc_info,
start=True,
)
def update(self, *, wait: bool):
"""Run an asynchronous update in background as soon as possible"""
self.ready.clear() # TODO this should be a separate event
self._thread.trigger.set()
if wait:
self.ready.wait()
def _update(self):
"""Perform an immediate and synchronous refresh, may take time"""
for attempt_no in itertools.count():
try:
new_block_infos = petals.dht_utils.get_remote_module_infos(
self.dht, self.block_uids, expiration_time=float("inf")
)
with self.lock_changes:
self.sequence_info.update_(new_block_infos)
missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
if missing_blocks:
raise MissingBlocksError(f"no servers holding blocks {missing_blocks}")
self.ready.set() # if there is an active server for every block, we may begin running
break
except Exception as e:
delay = self.get_retry_delay(attempt_no)
logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
traceback_level = logging.DEBUG if str(e) else logging.WARNING
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
time.sleep(delay)
def __len__(self):
return len(self.block_uids)
@property
def is_alive(self):
return self._thread.is_alive
@property
def ready(self) -> threading.Event:
return self._thread.ready
@property
def block_uids(self):
return self.sequence_info.block_uids
@property
def rpc_info(self):
"""Return the rpc_info queried from one of the servers that hold the first block"""
if self._rpc_info is None:
for attempt_no in itertools.count():
try:
self._update()
peer_id, _ = random.choice(list(self.sequence_info.block_infos[0].servers.items()))
stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
outputs = RemoteExpertWorker.run_coroutine(
stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
)
self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
break
except Exception as e:
delay = self.get_retry_delay(attempt_no)
logger.warning(
f"Caught exception when gathering information from peer {peer_id} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
traceback_level = logging.DEBUG if str(e) else logging.WARNING
logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
time.sleep(delay)
return self._rpc_info
def get_retry_delay(self, attempt_no: int) -> float:
if attempt_no == 0:
return 0
return self.min_backoff * 2 ** (attempt_no - 1)
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
"""
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args: request-specific inputs, typically block uids and input tensors
:param kwargs: additional request context, such as remote peer ID
:returns: a metadata dict that the caller msgpack-serializes and passes alongside a given request
"""
return dict(points=self.policy.get_points(protocol, *args, **kwargs))
def shutdown(self):
self._thread.shutdown()
class _SequenceManagerUpdateThread(threading.Thread):
def __init__(self, update_period: float, ref_update_manager: WeakMethod):
super().__init__(daemon=True)
self.ref_update_manager = ref_update_manager
self.ready = threading.Event()
self.trigger = threading.Event()
self.last_update_time = -float("inf")
self.update_period = update_period
self.should_shutdown = False
def run(self) -> None:
while not self.should_shutdown:
self.trigger.wait(max(0.0, min(self.update_period, time.perf_counter() - self.last_update_time)))
if self.should_shutdown:
logger.debug(f"{self.__class__.__name__} is shutting down")
break
update_manager = self.ref_update_manager()
if update_manager is None:
logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists")
break
try:
update_manager()
self.trigger.clear()
except Exception as e:
logger.exception(e)
finally:
del update_manager
logger.debug(f"{self.__class__.__name__} thread exited")
def shutdown(self, timeout: Optional[float] = None):
self.should_shutdown = True
self.trigger.set()
self.join(timeout)
def __del__(self):
if self.is_alive():
self.shutdown()
class MissingBlocksError(Exception):
def __repr__(self):
return self.args[0]
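
For reference (not part of the diff): the retry schedule produced by get_retry_delay above, which replaces the old max_retries logic.

    min_backoff = 1.0  # the default from the constructor above
    delays = [0 if attempt == 0 else min_backoff * 2 ** (attempt - 1) for attempt in range(6)]
    print(delays)  # [0, 1.0, 2.0, 4.0, 8.0, 16.0] -- retry immediately, then back off exponentially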

@ -1,179 +0,0 @@
from __future__ import annotations
import random
import threading
from typing import List, Optional, Sequence, Tuple, Union
from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
import petals.dht_utils
from petals.client.spending_policy import NoSpendingPolicy
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class RemoteSequenceManager:
"""
Keeps and updates the meta-information about which peers host which blocks.
In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
"""
def __init__(
self,
dht: DHT,
block_uids: Sequence[ModuleUID],
p2p: P2P,
max_retries: int = 3,
request_timeout: float = 20,
min_backoff: float = 1,
):
assert len(block_uids) > 0, "Sequences must contain at least one block"
self.dht, self.p2p = dht, p2p
self.block_uids: List[ModuleUID] = list(block_uids)
self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
self.spans_by_priority: List[RemoteSpanInfo] = [] # sorted from best to worst
self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
self.last_update_time: DHTExpiration = -float("inf")
self.max_retries = max_retries
self.request_timeout, self.min_backoff = request_timeout, min_backoff
self._rpc_info = None
self.lock_changes = threading.Lock()
self.policy = NoSpendingPolicy()
self.update_()
for uid, info in zip(self.block_uids, self.block_infos):
assert info is not None, f"Found no remote peers for block {uid}"
assert self.spans_by_priority and self.spans_containing_block
def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
"""
Form a sequence of remote servers that collectively serve all consecutive layers
:param start_index: optional index of the first module in a sequence, default = the first of block_uids
:param end_index: optional index of the last module (non-inclusive), default = after last of block uids
"""
end_index = end_index if end_index is not None else len(self.block_uids)
span_sequence = []
current_index = start_index
while current_index < end_index:
candidate_spans = self.spans_containing_block[current_index]
chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
assert chosen_span.start <= current_index < chosen_span.end
span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
current_index = chosen_span.end
return span_sequence
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
"""Get a RemoteSequenceManager for a sub-sequence of blocks"""
assert isinstance(ix, (int, slice))
if not isinstance(ix, slice):
ix = slice(int(ix), int(ix) + 1, 1)
with self.lock_changes:
subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
subseq.block_infos = self.block_infos[ix]
subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
subseq.last_update_time = self.last_update_time
return subseq
def update_(self):
with self.lock_changes:
self.update_block_infos_()
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
def update_block_infos_(self):
new_block_infos = petals.dht_utils.get_remote_module_infos(
self.dht, self.block_uids, expiration_time=float("inf")
)
assert len(new_block_infos) == len(self.block_uids)
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
if info is None:
logger.warning(f"Found no block info for block {uid}")
continue
if not isinstance(info, RemoteModuleInfo):
logger.warning(f"Unexpected dht entry type for {uid}: {info}")
if not info.servers:
logger.warning(f"Found no active peers for block {uid}")
if info.uid != uid:
logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
self.block_infos[block_index] = info
@staticmethod
def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
closed_spans = []
active_spans = {}
for block_index, info in enumerate(block_infos):
if info is not None:
for peer_id, server in info.servers.items():
if server.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
else: # peer_id in active_spans
active_spans[peer_id].end = block_index + 1
for peer_id in list(active_spans.keys()):
if (
info is None
or peer_id not in info.servers
or info.servers[peer_id].state != ServerState.ONLINE
or block_index == len(block_infos) - 1
):
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans, f"spans: {active_spans}"
closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
spans_containing_block = tuple(list() for _ in range(len(block_infos)))
for span in closed_spans:
for block_index in range(span.start, span.end):
spans_containing_block[block_index].append(span)
return closed_spans, spans_containing_block
def __len__(self):
return len(self.block_uids)
@property
def rpc_info(self):
"""Return the rpc_info queried from one of the servers that hold the first block"""
if self._rpc_info is None:
retries = 0
for i in range(self.max_retries):
try:
self.update_()
peer_id = random.choice(list(self.block_infos[0].servers.keys()))
stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
outputs = RemoteExpertWorker.run_coroutine(
stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
)
self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
break
except Exception as e:
retries += 1
if retries >= self.max_retries:
raise e
else:
logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
return self._rpc_info
def get_retry_delay(self, attempt_no: int) -> float:
if attempt_no == 0:
return 0
return self.min_backoff * 2 ** (attempt_no - 1)
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[bytes]:
"""
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args: request-specific inputs, typicall block uids and input tensors
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
"""
return MSGPackSerializer.dumps(dict(points=self.policy.get_points(protocol, *args, **kwargs)))

@ -8,11 +8,12 @@ from collections import deque
from typing import List, Optional, Sequence, Tuple
import torch
from hivemind import MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.sequence_manager import RemoteSequenceManager
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
@ -58,7 +59,7 @@ async def sequential_forward(
logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
try:
if attempt_no >= 1:
sequence_manager.update_()
sequence_manager._update()
if not sequences or attempt_no >= 1:
sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
# make_sequence() could return a longer sequence
@ -78,7 +79,7 @@ async def sequential_forward(
sequence_manager.rpc_info,
*inputs_and_prompts,
timeout=sequence_manager.request_timeout,
metadata=metadata,
metadata=MSGPackSerializer.dumps(metadata),
)
assert isinstance(outputs, torch.Tensor)
@ -136,7 +137,7 @@ async def sequential_backward(
logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
try:
if attempt_no >= 1:
sequence_manager.update_()
sequence_manager.update(wait=True)
_, backup_inputs, backup_sequences = await sequential_forward(
inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
)
@ -162,7 +163,7 @@ async def sequential_backward(
grad_outputs,
prompts[span.start : span.end],
timeout=sequence_manager.request_timeout,
metadata=metadata,
metadata=MSGPackSerializer.dumps(metadata),
)
grad_outputs = [grad_outputs]
grad_prompts_reversed.extend(span_grad_prompts)
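
A note on the metadata change above (illustrative, not part of the diff): get_request_metadata now returns a plain dict that the caller serializes once per request, so custom sequence managers can simply add keys; the extra key below is hypothetical.

    from hivemind import MSGPackSerializer

    metadata = dict(points=0)                       # what the base class returns (points from the spending policy)
    metadata["extra_option"] = 42                   # hypothetical key a subclass might add (cf. the test changes below)
    serialized = MSGPackSerializer.dumps(metadata)  # the client serializes right before each RPC
    assert isinstance(serialized, bytes)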

@ -94,7 +94,7 @@ async def _get_remote_sequence(
) -> petals.client.RemoteSequential:
uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
p2p = await dht.replicate_p2p()
manager = petals.client.RemoteSequenceManager(dht, uids, p2p)
manager = petals.client.RemoteSequenceManager(dht, uids, p2p, start=True)
return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager)
@ -125,7 +125,7 @@ async def _get_remote_module(
single_uid = isinstance(uid_or_uids, ModuleUID)
uids = [uid_or_uids] if single_uid else uid_or_uids
p2p = await dht.replicate_p2p()
managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p, start=True) for uid in uids)
modules = [
petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m)
for m in managers
@ -134,14 +134,13 @@ async def _get_remote_module(
def get_remote_module_infos(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
expiration_time: Optional[DHTExpiration] = None,
dht: DHT, uid_or_uids: Union[ModuleUID, Sequence[ModuleUID]], expiration_time: Optional[DHTExpiration] = None
) -> List[Optional[RemoteModuleInfo]]:
single_uid = isinstance(uid_or_uids, ModuleUID)
uids = [uid_or_uids] if single_uid else uid_or_uids
infos = dht.run_coroutine(
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future=False
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time),
return_future=False,
)
return infos[0] if single_uid else infos

@ -25,7 +25,7 @@ class Span:
self.start, self.end = new_start, new_start + self.length
def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
spans = {}
throughputs = np.zeros(len(module_infos))
for block, module in enumerate(module_infos):
@ -56,7 +56,7 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
_, throughputs = _compute_spans(module_infos)
_, throughputs = compute_spans(module_infos)
start = _choose_best_start(throughputs, num_blocks)
return list(range(start, start + num_blocks))
@ -67,7 +67,7 @@ def should_choose_other_blocks(
if balance_quality > 1.0:
return True # Forces rebalancing on each check (may be used for debugging purposes)
spans, throughputs = _compute_spans(module_infos)
spans, throughputs = compute_spans(module_infos)
initial_throughput = throughputs.min()
eps = 1e-3

@ -1,6 +1,6 @@
import pytest
import torch
from hivemind import DHT, BatchTensorDescriptor, MSGPackSerializer, get_logger, use_hivemind_log_handler
from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
from hivemind.proto import runtime_pb2
from test_utils import *
@ -48,7 +48,7 @@ def test_remote_sequential():
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
lossy_sequential = RemoteSequential(
config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p, start=True)
)
test_inputs.grad = None
@ -58,7 +58,8 @@ def test_remote_sequential():
assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
assert abs(approx_outputs - full_outputs).mean() < 0.01
assert abs(test_inputs.grad - full_grad).mean() < 0.3
absmax = abs(full_grad).max()
assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.01
class DummyCustomSequenceManager(RemoteSequenceManager):
@ -73,13 +74,12 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
return rpc_info
def get_request_metadata(self, protocol: str, *args, **kwargs):
metadata = super().get_request_metadata(protocol, *args, **kwargs)
if protocol == "rpc_forward":
return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
elif protocol == "rpc_backward":
return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,)))
else:
assert protocol == "rpc_inference"
return super().get_request_metadata(protocol, *args, **kwargs)
metadata["output_compression"] = (runtime_pb2.CompressionType.BLOCKWISE_8BIT,)
return metadata
@pytest.mark.forked

@ -0,0 +1,54 @@
import threading
import time
import pytest
import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from test_utils import *
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.client.remote_model import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@pytest.mark.forked
def test_sequence_manager_shutdown():
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
sequential = RemoteSequential(config, dht)
shutdown_evt = threading.Event()
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
sequential = RemoteSequential(
config,
dht,
sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True),
)
assert sequential.sequence_manager.is_alive()
assert sequential.sequence_manager._thread.ready.is_set()
assert not shutdown_evt.is_set()
sequential(torch.randn(1, 2, config.hidden_size))
sequential.sequence_manager.shutdown()
del sequential
time.sleep(1)
assert shutdown_evt.is_set()
class TestSequenceManager(RemoteSequenceManager):
"""A sequence manager that signals if it was shut down"""
def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):
super().__init__(*args, **kwargs)
self._was_shut_down = _was_shut_down
def shutdown(self):
super().shutdown()
assert not self.is_alive()
self._was_shut_down.set()