From 158013a6715a97686493bddf1904ba6455a2a6b6 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 11 Jul 2023 17:29:34 +0400 Subject: [PATCH] Implement direct server-to-server communication (#331) Implement #226. --- src/petals/__init__.py | 2 +- src/petals/cli/run_server.py | 3 +- src/petals/client/inference_session.py | 228 ++++++++++-------- src/petals/client/routing/sequence_manager.py | 1 + src/petals/server/handler.py | 200 +++++++++++++-- src/petals/server/server.py | 32 ++- 6 files changed, 334 insertions(+), 132 deletions(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 26aa3ab..f007d11 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -9,7 +9,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "1.2.0.dev0" +__version__ = "1.2.0.dev1" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 3c28709..1d3c438 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -27,8 +27,7 @@ def main(): parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve") parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve") - parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default," - "use the same name as in the converted model.") + parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix") parser.add_argument('--port', type=int, required=False, help='Port this server listens to. ' diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 168dd40..8c2dfc9 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -3,7 +3,8 @@ from __future__ import annotations import asyncio import itertools import time -from typing import AsyncIterator, List, Optional +import uuid +from typing import AsyncIterator, List, Optional, Tuple import torch from hivemind import ( @@ -15,10 +16,10 @@ from hivemind import ( serialize_torch_tensor, ) from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import StubBase +from hivemind.p2p import P2P from hivemind.proto import runtime_pb2 -from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback +from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, is_dummy @@ -35,35 +36,48 @@ class _ServerInferenceSession: def __init__( self, + config: SequenceManagerConfig, + span: RemoteSpanInfo, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator, *, - timeout: float, max_length: int, **metadata, ): - self.uid, self.rpc_info = uid, rpc_info + self.config = config + self.span, self.uid, self.rpc_info = span, uid, rpc_info self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter - self.timeout = timeout - self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, 
**metadata)) + self.session_id = str(uuid.uuid4()) + self.session_metadata = dict(max_length=max_length, **metadata) self.stepped = False self.closed = False + self._position = 0 + self.history = None # Used in case of server failures to regenerate attention caches on new servers + self.next_session = None + @classmethod async def create( - cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata + cls, + config: SequenceManagerConfig, + p2p: P2P, + span: RemoteSpanInfo, + uid: ModuleUID, + rpc_info: RPCInfo, + **metadata, ) -> _ServerInferenceSession: """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker""" + stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id) inputs_queue = asyncio.Queue() outputs_stream = await asyncio.wait_for( stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), - timeout, + config.request_timeout, ) - return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata) + return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata) @staticmethod async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator: @@ -75,9 +89,11 @@ class _ServerInferenceSession: def step( self, - new_hidden_states: torch.Tensor, + inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None, + *, + step_id: str, ) -> torch.Tensor: """ Inference step: send a chunk of input tesors and receive a chunk of outputs @@ -86,44 +102,84 @@ class _ServerInferenceSession: """ if self.closed: raise Exception("Session is closed, cannot perform step") + + n_input_tokens = inputs.shape[1] + if self.history is None: + self.history = inputs + elif self.history.shape[1] == self._position: + self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1) + assert self.history.shape[1] == self._position + n_input_tokens, ( + f"Broken input cache: span={self.span} shape={self.history.shape} " + f"position={self._position} n_input_tokens={n_input_tokens}" + ) + + if not self.stepped: + inputs = self.history # Pass full inputs including prefix + else: + inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further + if prompts is None or is_dummy(prompts): prompts = DUMMY else: - assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]" + assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]" assert prompts.shape[0] == self.num_blocks - assert prompts.shape[1] in (new_hidden_states.shape[0], 1) - assert prompts.shape[2] <= new_hidden_states.shape[1] - assert prompts.shape[3] == new_hidden_states.shape[2] + assert prompts.shape[1] in (inputs.shape[0], 1) + assert prompts.shape[2] <= inputs.shape[1] + assert prompts.shape[3] == inputs.shape[2] if hypo_ids is None or is_dummy(hypo_ids): hypo_ids = DUMMY else: - assert len(hypo_ids) == len(new_hidden_states) + assert len(hypo_ids) == len(inputs) assert hypo_ids.dtype == torch.int64 # serialize inputs and put them into the queue - inputs = (new_hidden_states, prompts, hypo_ids) + input_tensors = (inputs, prompts, hypo_ids) + + request_metadata = dict(session_id=self.session_id, step_id=step_id) + if not self.stepped: + request_metadata.update(self.session_metadata) + elif self.config.use_server_to_server: + next_servers = self._collect_next_servers() + if next_servers: + request_metadata["next_servers"] = next_servers + 
outputs_serialized = RemoteExpertWorker.run_coroutine( self._step( runtime_pb2.ExpertRequest( uid=self.uid, tensors=[ serialize_torch_tensor(tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"])) + for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"])) ], - metadata=self._serialized_metadata if not self.stepped else None, + metadata=MSGPackSerializer.dumps(request_metadata), ) ) ) outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors)) - assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}" + assert ( + outputs[0].shape == inputs.shape + ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}" + + self._position += n_input_tokens + return outputs[0] + def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]: + next_servers = [] + session = self.next_session + while session is not None and session.stepped: + next_servers.append( + (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end) + ) + session = session.next_session + return next_servers + async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse: """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker""" await self._inputs_queue.put(inputs_serialized) self.stepped = True - return await asyncio.wait_for(anext(self._outputs_stream), self.timeout) + return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout) def close(self): """Finish a given inference session, close the underlying connection""" @@ -163,13 +219,15 @@ class InferenceSession: def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int): self._sequence_manager = sequence_manager self._closed = False - self._chosen_spans = [] self._server_sessions = [] - self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers self._position = 0 self._max_length = max_length self.last_token_id = None + @property + def num_blocks(self) -> int: + return len(self._sequence_manager) + @property def position(self) -> int: return self._position @@ -178,15 +236,15 @@ class InferenceSession: server_sessions = [] try: for span in chosen_spans: - stub = TransformerConnectionHandler.get_stub(self._sequence_manager.state.p2p, span.peer_id) span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end]) metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id) session = RemoteExpertWorker.run_coroutine( _ServerInferenceSession.create( - stub, + self._sequence_manager.config, + self._sequence_manager.state.p2p, + span, span_uids, rpc_info=self._sequence_manager.rpc_info, - timeout=self._sequence_manager.config.request_timeout, max_length=self._max_length, **metadata, ) @@ -206,7 +264,7 @@ class InferenceSession: logger.debug("Caught exception while closing connection to server:", exc_info=True) def __enter__(self) -> "InferenceSession": - assert not self._closed and not self._chosen_spans + assert not self._closed and not self._server_sessions return self def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: @@ -214,16 +272,17 @@ class InferenceSession: if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. 
Gradients will *not* be propagated correctly.") - n_blocks = len(self._sequence_manager) if prompts is None or is_dummy(prompts): prompts = DUMMY else: - assert prompts.ndim == 4 and prompts.shape[0] == n_blocks + assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]" + assert prompts.shape[0] == self.num_blocks inputs_device = inputs.device inputs_dtype = inputs.dtype inputs = inputs.cpu() prompts = prompts.cpu() + step_id = str(uuid.uuid4()) n_input_tokens = inputs.shape[1] if self._position + n_input_tokens > self._max_length: @@ -233,97 +292,74 @@ class InferenceSession: server_idx = 0 block_idx = 0 - recovery_until = -1 # Recovery mode is disabled until a failure happens - while block_idx < n_blocks: + while block_idx < self.num_blocks: for attempt_no in itertools.count(): logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}") - span = None + server_session = None try: - if not self._chosen_spans or not self._server_sessions or attempt_no >= 1: - # If there is a failed server session, this code closes it - self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1]) - - n_prev_spans = len(self._chosen_spans) - update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks - if attempt_no >= 1 and update_end > recovery_until: - logger.info( - f"Due to a server failure, remote attention caches " - f"from block {block_idx} to {update_end} will be regenerated" - ) - recovery_until = max(recovery_until, update_end) - - updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency") - # make_sequence() could return a longer sequence - updated_spans[-1].end = min(updated_spans[-1].end, update_end) - updated_sessions = self._enter_server_sessions(updated_spans) - logger.debug( - f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers" - ) - - # If there is a failed span, this code replaces it, otherwise it just adds new ones - self._chosen_spans[server_idx : server_idx + 1] = updated_spans - self._server_sessions[server_idx : server_idx + 1] = updated_sessions - recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None - self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * ( - len(updated_spans) - 1 - ) - assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), ( - f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, " - f"{len(self._server_inputs)} inputs" - ) - - session = self._server_sessions[server_idx] - span = self._chosen_spans[server_idx] - - if self._server_inputs[server_idx] is None: - self._server_inputs[server_idx] = inputs - elif self._server_inputs[server_idx].shape[1] == self._position: - self._server_inputs[server_idx] = torch.cat( - [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1 - ) - assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, ( - f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} " - f"position={self._position} n_input_tokens={n_input_tokens}" - ) - - if not session.stepped: - inputs = self._server_inputs[server_idx] # Pass full inputs including prefix - else: - inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further + if not self._server_sessions or attempt_no >= 1: + self._update_sequence(server_idx, block_idx, attempt_no) - outputs = session.step(inputs, 
prompts[span.start : span.end], **kwargs) - assert ( - inputs.shape == outputs.shape - ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})" + server_session = self._server_sessions[server_idx] + inputs = server_session.step( + inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs + ) - inputs = outputs server_idx += 1 - block_idx = span.end - self._sequence_manager.on_request_success(span.peer_id) + block_idx = server_session.span.end + self._sequence_manager.on_request_success(server_session.span.peer_id) break except Exception as e: - self._sequence_manager.on_request_failure(span.peer_id if span is not None else None) + self._sequence_manager.on_request_failure( + server_session.span.peer_id if server_session is not None else None + ) if attempt_no + 1 == self._sequence_manager.config.max_retries: raise delay = self._sequence_manager.get_retry_delay(attempt_no) logger.warning( - f"Caught exception when running inference via {span} (retry in {delay:.0f} sec): {repr(e)}" + f"Caught exception when running inference via {server_session.span if server_session is not None else None} " + f"(retry in {delay:.0f} sec): {repr(e)}" ) maybe_log_traceback(e) time.sleep(delay) self._position += n_input_tokens - inputs = inputs[:, -n_input_tokens:] - outputs = inputs.to(device=inputs_device, dtype=inputs_dtype) + outputs = inputs[:, -n_input_tokens:] + outputs = outputs.to(device=inputs_device, dtype=inputs_dtype) return outputs + def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int: + # If there is a failed server session, this code closes it + self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1]) + + n_prev_spans = len(self._server_sessions) + update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks + if attempt_no >= 1: + logger.info( + f"Due to a server failure, remote attention caches " + f"from block {block_idx} to {update_end} will be regenerated" + ) + + updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency") + # make_sequence() could return a longer sequence + updated_spans[-1].end = min(updated_spans[-1].end, update_end) + updated_sessions = self._enter_server_sessions(updated_spans) + logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers") + + # If there is a failed span, this code replaces it, otherwise it just adds new ones + if server_idx < n_prev_spans: + updated_sessions[0].history = self._server_sessions[server_idx].history + self._server_sessions[server_idx : server_idx + 1] = updated_sessions + + # Update links to the next server session for direct server-to-server communication via rpc_push() + for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)): + self._server_sessions[i].next_session = self._server_sessions[i + 1] + def close(self, *exc_details): """Finish a given inference session, close the underlying connection""" if not self._closed: - self._server_inputs.clear() self._exit_server_sessions(self._server_sessions) self._server_sessions.clear() - self._chosen_spans.clear() self._closed = True def __exit__(self, *exc_details): diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 1a31d66..88d6d16 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ 
-34,6 +34,7 @@ class SequenceManagerConfig: daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers + use_server_to_server: bool = True # Use direct server-to-server communication request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests update_period: float = 60 # refresh DHT information once in this many seconds diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 79376f8..65ee5c6 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -2,6 +2,9 @@ from __future__ import annotations import asyncio import contextlib +import multiprocessing.managers +import sys +from concurrent.futures import ThreadPoolExecutor from itertools import chain from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union @@ -11,6 +14,7 @@ from hivemind import ( DHT, MSGPackSerializer, P2PContext, + PeerID, deserialize_tensor_stream, deserialize_torch_tensor, nested_flatten, @@ -25,7 +29,7 @@ from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming import petals -from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID from petals.server.backend import TransformerBackend from petals.server.memory_cache import Handle from petals.server.task_pool import PrioritizedTaskPool @@ -34,6 +38,23 @@ from petals.utils.misc import DUMMY, is_dummy logger = get_logger(__name__) + +# Fix pickling protobufs, see https://stackoverflow.com/a/74873028 +sys.modules["runtime_pb2"] = runtime_pb2 + +# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256 + +_OriginalAutoProxy = multiprocessing.managers.AutoProxy + + +def patched_autoproxy(*args, manager_owned=True, **kwargs): + # Calling original AutoProxy without the unwanted key argument + return _OriginalAutoProxy(*args, **kwargs) + + +multiprocessing.managers.AutoProxy = patched_autoproxy + + CACHE_TOKENS_AVAILABLE = "cache_tokens_available" @@ -47,6 +68,9 @@ class TransformerConnectionHandler(ConnectionHandler): dht: DHT, module_backends: Dict[str, TransformerBackend], *, + dht_prefix: str, + push_manager: multiprocessing.managers.SyncManager, + session_queues: Dict[str, multiprocessing.managers.BaseProxy], # BaseProxy for queue.Queue inference_max_length: int, request_timeout: float, session_timeout: float, @@ -56,6 +80,11 @@ class TransformerConnectionHandler(ConnectionHandler): super().__init__(dht, module_backends) for module_backend in self.module_backends.values(): assert isinstance(module_backend, TransformerBackend) + self.dht_prefix = dht_prefix + self._push_manager = push_manager + self._session_queues = session_queues + self._executor = ThreadPoolExecutor(max_workers=float("inf")) # For waiting on self.session_queues + self.inference_max_length = inference_max_length self.request_timeout = request_timeout self.session_timeout, self.step_timeout = session_timeout, step_timeout @@ -96,7 +125,7 @@ class TransformerConnectionHandler(ConnectionHandler): self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext, - ) -> AsyncIterator[runtime_pb2.ExpertRequest]: + ) -> AsyncIterator[runtime_pb2.ExpertResponse]: """Compute a single step of inference using attention cache; update attention cache 
accordingly.""" async with timeout(self.session_timeout): @@ -113,6 +142,7 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") points = metadata.get("points", 0) + session_id = metadata.get("session_id") if not requested_uids: raise ValueError("User must specify at least one block for inference, but got none") @@ -133,7 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler): async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles: assert len(cache_handles) == len(requested_backends) - while request.tensors: # iterate while user is willing to supply tensors + first_request = request + background_tasks = set() + async for request, metadata in self._iterate_inference_steps( + first_request, requests, session_id, requested_uids, context + ): hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) # Cast inputs to backend dtype @@ -141,7 +175,8 @@ class TransformerConnectionHandler(ConnectionHandler): assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" # parse deep prompts (optional argument) - if prompts is None or is_dummy(prompts): + has_prompts = prompts is not None and not is_dummy(prompts) + if not has_prompts: prompts = [None] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] @@ -180,25 +215,136 @@ class TransformerConnectionHandler(ConnectionHandler): ) # serialize and send last layer outputs - yield runtime_pb2.ExpertResponse( - tensors=[ - serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) - for result, proto in zip( - (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema) - ) - ] - ) + output_tensors = [ + serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) + for result, proto in zip( + (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema) + ) + ] + if not has_prompts: + task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata)) + background_tasks.add(task) # Keep reference until it is done to save it from GC + task.add_done_callback(background_tasks.discard) + yield runtime_pb2.ExpertResponse(tensors=output_tensors) # prepare for next step - prefix_length += hidden_states.shape[1] - try: - request = await asyncio.wait_for(anext(requests), self.step_timeout) - except asyncio.TimeoutError: - self._log_request("rpc_inference.step", requested_uids, context, warning="timed out") - return + prefix_length += length_increment finally: self._log_request("rpc_inference.close", requested_uids, context) + async def _iterate_inference_steps( + self, + first_request: runtime_pb2.ExpertRequest, + requests: AsyncIterator[runtime_pb2.ExpertRequest], + session_id: Optional[str], + requested_uids: Sequence[str], + context: P2PContext, + ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]: + loop = asyncio.get_event_loop() + if session_id is not None: + push_queue = self._push_manager.Queue() + self._session_queues[session_id] = push_queue + + processed_step_ids = set() + n_pushes = n_late_pushes = 0 + request = first_request + anext_task = get_push_task = None + try: + while request.tensors: # iterate while user is willing to supply tensors + metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} + step_id = metadata.get("step_id") + + pushed 
= metadata.get("pushed") + if pushed: + n_pushes += 1 + + if step_id is None or step_id not in processed_step_ids: + yield request, metadata + if step_id is not None: + processed_step_ids.add(step_id) + elif pushed: + n_late_pushes += 1 + self._log_request( + "rpc_inference.push", + requested_uids, + context, + warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time", + ) + + # Wait for the next request, coming either from the `requests` iterator or `push_queue` + if anext_task is None: + anext_task = asyncio.create_task(anext(requests)) + if get_push_task is None: + if session_id is not None: + get_push_task = loop.run_in_executor(self._executor, push_queue.get) + else: + get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task + done, _ = await asyncio.wait( + [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED + ) + + if anext_task in done: + request = await anext_task + anext_task = None + elif get_push_task in done: + request = await get_push_task + get_push_task = None + else: + self._log_request("rpc_inference.step", requested_uids, context, warning="timed out") + anext_task.cancel() + get_push_task.cancel() + return + except: + logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True) + raise + finally: + if session_id is not None: + push_queue.put(None) # Stop thread for get_push_task + del self._session_queues[session_id] + + async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: + """Directly push activation tensors from one server to another""" + + requested_uids = self._check_uids(request.uid) + self._log_request("rpc_push", requested_uids, context) + + metadata = MSGPackSerializer.loads(request.metadata) + session_id = metadata["session_id"] + self._session_queues[session_id].put(request) + return runtime_pb2.ExpertResponse() + + async def _push_outputs( + self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict + ) -> None: + try: + next_servers = metadata.get("next_servers") + if not next_servers: + return + + next_peer_id, next_session_id, next_start, next_end = next_servers[0] + next_peer_id = PeerID.from_base58(next_peer_id) + next_uid = CHAIN_DELIMITER.join(f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(next_start, next_end)) + + # Sending hidden states serialized with output_schema to avoid double serialization + next_tensors = [serialized_outputs] + request.tensors[1:] + next_metadata = metadata.copy() + next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True) + + stub = self.get_stub(self._p2p, next_peer_id) + await stub.rpc_push( + runtime_pb2.ExpertRequest( + uid=next_uid, + tensors=next_tensors, + metadata=MSGPackSerializer.dumps(next_metadata), + ), + timeout=self.request_timeout, + ) + except Exception: + logger.debug( + f"Failed to push outputs to peer_id={next_peer_id}, session_id={next_session_id}, blocks={next_start}:{next_end}:", + exc_info=True, + ) + async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: async with timeout(self.request_timeout): # Parse request and prepare backends @@ -348,7 +494,7 @@ class TransformerConnectionHandler(ConnectionHandler): @contextlib.asynccontextmanager async def _allocate_cache( self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int - ) -> Sequence[Sequence[Handle, ...]]: + ) -> 
Sequence[Sequence[Handle]]: """ Allocate memory cache for all transformer blocks, return cache handle :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend @@ -358,7 +504,13 @@ class TransformerConnectionHandler(ConnectionHandler): yield nested_pack(handles, descriptors) def _log_request( - self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None + self, + method: str, + uids: Optional[Sequence[ModuleUID]], + context: P2PContext, + *, + debug: Optional[str] = None, + warning: Optional[str] = None, ) -> None: if uids is not None: friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid] @@ -370,10 +522,12 @@ class TransformerConnectionHandler(ConnectionHandler): friendly_remote_id = "..." + str(context.remote_id)[-6:] message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})" - if warning is None: - logger.info(message) - else: + if warning is not None: logger.warning(f"{message}: {warning}") + elif debug is not None: + logger.debug(f"{message}: {debug}") + else: + logger.info(message) async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo: """Return metadata about stored block uids and current load""" diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 2fbaad2..894e9ea 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -45,7 +45,7 @@ class Server: self, *, initial_peers: List[str], - prefix: Optional[str], + dht_prefix: Optional[str], converted_model_name_or_path: str, throughput: Union[float, str], num_blocks: Optional[int] = None, @@ -105,13 +105,13 @@ class Server: revision=revision, ) - if prefix is None: - prefix = self.block_config.dht_prefix - assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, ( + if dht_prefix is None: + dht_prefix = self.block_config.dht_prefix + assert UID_DELIMITER not in dht_prefix and CHAIN_DELIMITER not in dht_prefix, ( f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. 
" - f"Please specify another --prefix manually when starting a server" + f"Please specify another --dht_prefix manually when starting a server" ) - self.prefix = prefix + self.dht_prefix = dht_prefix if expiration is None: expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) @@ -121,7 +121,8 @@ class Server: self.session_timeout, self.step_timeout = session_timeout, step_timeout self.module_uids = [ - f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers) + f"{self.dht_prefix}{UID_DELIMITER}{block_index}" + for block_index in range(self.block_config.num_hidden_layers) ] if dht_client_mode is None: @@ -258,7 +259,7 @@ class Server: block_indices = self._choose_blocks() self.module_container = ModuleContainer.create( dht=self.dht, - prefix=self.prefix, + dht_prefix=self.dht_prefix, converted_model_name_or_path=self.converted_model_name_or_path, block_config=self.block_config, attn_cache_bytes=self.attn_cache_bytes, @@ -359,7 +360,7 @@ class ModuleContainer(threading.Thread): cls, *, dht: DHT, - prefix: str, + dht_prefix: str, converted_model_name_or_path: str, block_config: PretrainedConfig, attn_cache_bytes: int, @@ -382,7 +383,7 @@ class ModuleContainer(threading.Thread): should_validate_reachability: bool, **kwargs, ) -> ModuleContainer: - module_uids = [f"{prefix}.{block_index}" for block_index in block_indices] + module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] joining_announcer = ModuleAnnouncerThread( module_uids, dht, @@ -459,6 +460,7 @@ class ModuleContainer(threading.Thread): return cls( dht, + dht_prefix, blocks, throughput=throughput, update_period=update_period, @@ -469,6 +471,7 @@ class ModuleContainer(threading.Thread): def __init__( self, dht: DHT, + dht_prefix: str, module_backends: Dict[str, TransformerBackend], *, inference_max_length: int, @@ -486,10 +489,17 @@ class ModuleContainer(threading.Thread): self.dht, self.module_backends = dht, module_backends self.throughput, self.update_period, self.expiration = throughput, update_period, expiration + + self.push_manager = mp.Manager() + self.push_manager.__enter__() + session_queues = self.push_manager.dict() self.conn_handlers = [ TransformerConnectionHandler( dht, self.module_backends, + dht_prefix=dht_prefix, + push_manager=self.push_manager, + session_queues=session_queues, inference_max_length=inference_max_length, request_timeout=request_timeout, session_timeout=session_timeout, @@ -497,6 +507,7 @@ class ModuleContainer(threading.Thread): ) for _ in range(num_handlers) ] + self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs) # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed. self.online_announcer = ModuleAnnouncerThread( @@ -577,6 +588,7 @@ class ModuleContainer(threading.Thread): logger.debug("Shutting down connection handlers") for handler in self.conn_handlers: handler.shutdown() + self.push_manager.__exit__(None, None, None) logger.debug(f"Shutting down pools") for pool in self.runtime.pools: