Implement direct server-to-server communication (#331)

Implement #226.
pull/340/head
Alexander Borzunov 10 months ago committed by GitHub
parent 4d9c26fe5c
commit 158013a671

@ -9,7 +9,7 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "1.2.0.dev0"
__version__ = "1.2.0.dev1"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

@ -27,8 +27,7 @@ def main():
parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
"use the same name as in the converted model.")
parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix")
parser.add_argument('--port', type=int, required=False,
help='Port this server listens to. '

@ -3,7 +3,8 @@ from __future__ import annotations
import asyncio
import itertools
import time
from typing import AsyncIterator, List, Optional
import uuid
from typing import AsyncIterator, List, Optional, Tuple
import torch
from hivemind import (
@ -15,10 +16,10 @@ from hivemind import (
serialize_torch_tensor,
)
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
@ -35,35 +36,48 @@ class _ServerInferenceSession:
def __init__(
self,
config: SequenceManagerConfig,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
inputs_queue: asyncio.Queue,
outputs_aiter: AsyncIterator,
*,
timeout: float,
max_length: int,
**metadata,
):
self.uid, self.rpc_info = uid, rpc_info
self.config = config
self.span, self.uid, self.rpc_info = span, uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.timeout = timeout
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, **metadata))
self.session_id = str(uuid.uuid4())
self.session_metadata = dict(max_length=max_length, **metadata)
self.stepped = False
self.closed = False
self._position = 0
self.history = None # Used in case of server failures to regenerate attention caches on new servers
self.next_session = None
@classmethod
async def create(
cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
cls,
config: SequenceManagerConfig,
p2p: P2P,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
**metadata,
) -> _ServerInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
timeout,
config.request_timeout,
)
return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@ -75,9 +89,11 @@ class _ServerInferenceSession:
def step(
self,
new_hidden_states: torch.Tensor,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
*,
step_id: str,
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
@ -86,44 +102,84 @@ class _ServerInferenceSession:
"""
if self.closed:
raise Exception("Session is closed, cannot perform step")
n_input_tokens = inputs.shape[1]
if self.history is None:
self.history = inputs
elif self.history.shape[1] == self._position:
self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
assert self.history.shape[1] == self._position + n_input_tokens, (
f"Broken input cache: span={self.span} shape={self.history.shape} "
f"position={self._position} n_input_tokens={n_input_tokens}"
)
if not self.stepped:
inputs = self.history # Pass full inputs including prefix
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
assert prompts.shape[2] <= new_hidden_states.shape[1]
assert prompts.shape[3] == new_hidden_states.shape[2]
assert prompts.shape[1] in (inputs.shape[0], 1)
assert prompts.shape[2] <= inputs.shape[1]
assert prompts.shape[3] == inputs.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY
else:
assert len(hypo_ids) == len(new_hidden_states)
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
inputs = (new_hidden_states, prompts, hypo_ids)
input_tensors = (inputs, prompts, hypo_ids)
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
request_metadata.update(self.session_metadata)
elif self.config.use_server_to_server:
next_servers = self._collect_next_servers()
if next_servers:
request_metadata["next_servers"] = next_servers
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
],
metadata=self._serialized_metadata if not self.stepped else None,
metadata=MSGPackSerializer.dumps(request_metadata),
)
)
)
outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
assert (
outputs[0].shape == inputs.shape
), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"
self._position += n_input_tokens
return outputs[0]
def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:
next_servers = []
session = self.next_session
while session is not None and session.stepped:
next_servers.append(
(session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end)
)
session = session.next_session
return next_servers
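For reference, a minimal sketch (with hypothetical peer IDs and UUIDs) of the per-step metadata the client serializes with MSGPackSerializer once the session has stepped and use_server_to_server is enabled; the next_servers entries mirror the (peer_id, session_id, span_start, span_end) tuples collected above:
request_metadata = {
    "session_id": "c0ffee00-0000-0000-0000-000000000001",  # this server's session (uuid4)
    "step_id": "deadbeef-0000-0000-0000-000000000002",  # shared by all servers for this step
    "next_servers": [
        # (peer_id_base58, session_id, span_start, span_end) for each downstream server
        ("QmExamplePeerB", "c0ffee00-0000-0000-0000-000000000003", 8, 16),
        ("QmExamplePeerC", "c0ffee00-0000-0000-0000-000000000004", 16, 24),
    ],
}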
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
def close(self):
"""Finish a given inference session, close the underlying connection"""
@ -163,13 +219,15 @@ class InferenceSession:
def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
self._sequence_manager = sequence_manager
self._closed = False
self._chosen_spans = []
self._server_sessions = []
self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers
self._position = 0
self._max_length = max_length
self.last_token_id = None
@property
def num_blocks(self) -> int:
return len(self._sequence_manager)
@property
def position(self) -> int:
return self._position
@ -178,15 +236,15 @@ class InferenceSession:
server_sessions = []
try:
for span in chosen_spans:
stub = TransformerConnectionHandler.get_stub(self._sequence_manager.state.p2p, span.peer_id)
span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
session = RemoteExpertWorker.run_coroutine(
_ServerInferenceSession.create(
stub,
self._sequence_manager.config,
self._sequence_manager.state.p2p,
span,
span_uids,
rpc_info=self._sequence_manager.rpc_info,
timeout=self._sequence_manager.config.request_timeout,
max_length=self._max_length,
**metadata,
)
@ -206,7 +264,7 @@ class InferenceSession:
logger.debug("Caught exception while closing connection to server:", exc_info=True)
def __enter__(self) -> "InferenceSession":
assert not self._closed and not self._chosen_spans
assert not self._closed and not self._server_sessions
return self
def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
@ -214,16 +272,17 @@ class InferenceSession:
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
n_blocks = len(self._sequence_manager)
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
inputs_device = inputs.device
inputs_dtype = inputs.dtype
inputs = inputs.cpu()
prompts = prompts.cpu()
step_id = str(uuid.uuid4())
n_input_tokens = inputs.shape[1]
if self._position + n_input_tokens > self._max_length:
@ -233,97 +292,74 @@ class InferenceSession:
server_idx = 0
block_idx = 0
recovery_until = -1 # Recovery mode is disabled until a failure happens
while block_idx < n_blocks:
while block_idx < self.num_blocks:
for attempt_no in itertools.count():
logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
span = None
server_session = None
try:
if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
# If there is a failed server session, this code closes it
self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
n_prev_spans = len(self._chosen_spans)
update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
if attempt_no >= 1 and update_end > recovery_until:
logger.info(
f"Due to a server failure, remote attention caches "
f"from block {block_idx} to {update_end} will be regenerated"
)
recovery_until = max(recovery_until, update_end)
updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
# make_sequence() could return a longer sequence
updated_spans[-1].end = min(updated_spans[-1].end, update_end)
updated_sessions = self._enter_server_sessions(updated_spans)
logger.debug(
f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
)
# If there is a failed span, this code replaces it, otherwise it just adds new ones
self._chosen_spans[server_idx : server_idx + 1] = updated_spans
self._server_sessions[server_idx : server_idx + 1] = updated_sessions
recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
len(updated_spans) - 1
)
assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
f"{len(self._server_inputs)} inputs"
)
session = self._server_sessions[server_idx]
span = self._chosen_spans[server_idx]
if self._server_inputs[server_idx] is None:
self._server_inputs[server_idx] = inputs
elif self._server_inputs[server_idx].shape[1] == self._position:
self._server_inputs[server_idx] = torch.cat(
[self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
)
assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
f"position={self._position} n_input_tokens={n_input_tokens}"
)
if not session.stepped:
inputs = self._server_inputs[server_idx] # Pass full inputs including prefix
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
if not self._server_sessions or attempt_no >= 1:
self._update_sequence(server_idx, block_idx, attempt_no)
outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
assert (
inputs.shape == outputs.shape
), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
server_session = self._server_sessions[server_idx]
inputs = server_session.step(
inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
)
inputs = outputs
server_idx += 1
block_idx = span.end
self._sequence_manager.on_request_success(span.peer_id)
block_idx = server_session.span.end
self._sequence_manager.on_request_success(server_session.span.peer_id)
break
except Exception as e:
self._sequence_manager.on_request_failure(span.peer_id if span is not None else None)
self._sequence_manager.on_request_failure(
server_session.span.peer_id if server_session is not None else None
)
if attempt_no + 1 == self._sequence_manager.config.max_retries:
raise
delay = self._sequence_manager.get_retry_delay(attempt_no)
logger.warning(
f"Caught exception when running inference via {span} (retry in {delay:.0f} sec): {repr(e)}"
f"Caught exception when running inference via {server_session.span if server_session is not None else None} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
maybe_log_traceback(e)
time.sleep(delay)
self._position += n_input_tokens
inputs = inputs[:, -n_input_tokens:]
outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
outputs = inputs[:, -n_input_tokens:]
outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
return outputs
def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
# If there is a failed server session, this code closes it
self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
n_prev_spans = len(self._server_sessions)
update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
if attempt_no >= 1:
logger.info(
f"Due to a server failure, remote attention caches "
f"from block {block_idx} to {update_end} will be regenerated"
)
updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
# make_sequence() could return a longer sequence
updated_spans[-1].end = min(updated_spans[-1].end, update_end)
updated_sessions = self._enter_server_sessions(updated_spans)
logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers")
# If there is a failed span, this code replaces it, otherwise it just adds new ones
if server_idx < n_prev_spans:
updated_sessions[0].history = self._server_sessions[server_idx].history
self._server_sessions[server_idx : server_idx + 1] = updated_sessions
# Update links to the next server session for direct server-to-server communication via rpc_push()
for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)):
self._server_sessions[i].next_session = self._server_sessions[i + 1]
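A standalone sketch (using SimpleNamespace stand-ins for real sessions) of the invariant the loop above maintains: each session links to the session serving the next span of blocks, which is exactly the chain that _collect_next_servers() walks:
from types import SimpleNamespace
# Three hypothetical sessions covering blocks [0, 8), [8, 16), [16, 24)
session_a = SimpleNamespace(span=(0, 8), next_session=None)
session_b = SimpleNamespace(span=(8, 16), next_session=None)
session_c = SimpleNamespace(span=(16, 24), next_session=None)
sessions = [session_a, session_b, session_c]
for left, right in zip(sessions, sessions[1:]):
    left.next_session = right  # session_a -> session_b -> session_c; the last one keeps None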
def close(self, *exc_details):
"""Finish a given inference session, close the underlying connection"""
if not self._closed:
self._server_inputs.clear()
self._exit_server_sessions(self._server_sessions)
self._server_sessions.clear()
self._chosen_spans.clear()
self._closed = True
def __exit__(self, *exc_details):

@ -34,6 +34,7 @@ class SequenceManagerConfig:
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
use_server_to_server: bool = True # Use direct server-to-server communication
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds
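The new flag defaults to True, so direct pushes are used automatically. A minimal sketch of opting out on the client side, assuming (as the fields shown above suggest) that every SequenceManagerConfig field has a default value; how the config object reaches RemoteSequenceManager is outside this hunk:
from petals.client.routing.sequence_manager import SequenceManagerConfig
config = SequenceManagerConfig()
config.use_server_to_server = False  # fall back to relaying every step through the client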

@ -2,6 +2,9 @@ from __future__ import annotations
import asyncio
import contextlib
import multiprocessing.managers
import sys
from concurrent.futures import ThreadPoolExecutor
from itertools import chain
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
@ -11,6 +14,7 @@ from hivemind import (
DHT,
MSGPackSerializer,
P2PContext,
PeerID,
deserialize_tensor_stream,
deserialize_torch_tensor,
nested_flatten,
@ -25,7 +29,7 @@ from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
import petals
from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
@ -34,6 +38,23 @@ from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__name__)
# Fix pickling protobufs, see https://stackoverflow.com/a/74873028
sys.modules["runtime_pb2"] = runtime_pb2
# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256
_OriginalAutoProxy = multiprocessing.managers.AutoProxy
def patched_autoproxy(*args, manager_owned=True, **kwargs):
# Calling original AutoProxy without the unwanted key argument
return _OriginalAutoProxy(*args, **kwargs)
multiprocessing.managers.AutoProxy = patched_autoproxy
CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
@ -47,6 +68,9 @@ class TransformerConnectionHandler(ConnectionHandler):
dht: DHT,
module_backends: Dict[str, TransformerBackend],
*,
dht_prefix: str,
push_manager: multiprocessing.managers.SyncManager,
session_queues: Dict[str, multiprocessing.managers.BaseProxy], # BaseProxy for queue.Queue
inference_max_length: int,
request_timeout: float,
session_timeout: float,
@ -56,6 +80,11 @@ class TransformerConnectionHandler(ConnectionHandler):
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
assert isinstance(module_backend, TransformerBackend)
self.dht_prefix = dht_prefix
self._push_manager = push_manager
self._session_queues = session_queues
self._executor = ThreadPoolExecutor(max_workers=float("inf")) # For waiting on self.session_queues
self.inference_max_length = inference_max_length
self.request_timeout = request_timeout
self.session_timeout, self.step_timeout = session_timeout, step_timeout
@ -96,7 +125,7 @@ class TransformerConnectionHandler(ConnectionHandler):
self,
requests: AsyncIterator[runtime_pb2.ExpertRequest],
context: P2PContext,
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
"""Compute a single step of inference using attention cache; update attention cache accordingly."""
async with timeout(self.session_timeout):
@ -113,6 +142,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
@ -133,7 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler):
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
while request.tensors: # iterate while user is willing to supply tensors
first_request = request
background_tasks = set()
async for request, metadata in self._iterate_inference_steps(
first_request, requests, session_id, requested_uids, context
):
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
# Cast inputs to backend dtype
@ -141,7 +175,8 @@ class TransformerConnectionHandler(ConnectionHandler):
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
if prompts is None or is_dummy(prompts):
has_prompts = prompts is not None and not is_dummy(prompts)
if not has_prompts:
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
@ -180,25 +215,136 @@ class TransformerConnectionHandler(ConnectionHandler):
)
# serialize and send last layer outputs
yield runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
]
)
output_tensors = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
]
if not has_prompts:
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
background_tasks.add(task) # Keep reference until it is done to save it from GC
task.add_done_callback(background_tasks.discard)
yield runtime_pb2.ExpertResponse(tensors=output_tensors)
# prepare for next step
prefix_length += hidden_states.shape[1]
try:
request = await asyncio.wait_for(anext(requests), self.step_timeout)
except asyncio.TimeoutError:
self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
return
prefix_length += length_increment
finally:
self._log_request("rpc_inference.close", requested_uids, context)
async def _iterate_inference_steps(
self,
first_request: runtime_pb2.ExpertRequest,
requests: AsyncIterator[runtime_pb2.ExpertRequest],
session_id: Optional[str],
requested_uids: Sequence[str],
context: P2PContext,
) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:
loop = asyncio.get_event_loop()
if session_id is not None:
push_queue = self._push_manager.Queue()
self._session_queues[session_id] = push_queue
processed_step_ids = set()
n_pushes = n_late_pushes = 0
request = first_request
anext_task = get_push_task = None
try:
while request.tensors: # iterate while user is willing to supply tensors
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
step_id = metadata.get("step_id")
pushed = metadata.get("pushed")
if pushed:
n_pushes += 1
if step_id is None or step_id not in processed_step_ids:
yield request, metadata
if step_id is not None:
processed_step_ids.add(step_id)
elif pushed:
n_late_pushes += 1
self._log_request(
"rpc_inference.push",
requested_uids,
context,
warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
)
# Wait for the next request, coming either from the `requests` iterator or `push_queue`
if anext_task is None:
anext_task = asyncio.create_task(anext(requests))
if get_push_task is None:
if session_id is not None:
get_push_task = loop.run_in_executor(self._executor, push_queue.get)
else:
get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task
done, _ = await asyncio.wait(
[anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
)
if anext_task in done:
request = await anext_task
anext_task = None
elif get_push_task in done:
request = await get_push_task
get_push_task = None
else:
self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
anext_task.cancel()
get_push_task.cancel()
return
except:
logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
raise
finally:
if session_id is not None:
push_queue.put(None) # Stop thread for get_push_task
del self._session_queues[session_id]
async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
"""Directly push activation tensors from one server to another"""
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_push", requested_uids, context)
metadata = MSGPackSerializer.loads(request.metadata)
session_id = metadata["session_id"]
self._session_queues[session_id].put(request)
return runtime_pb2.ExpertResponse()
async def _push_outputs(
self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict
) -> None:
try:
next_servers = metadata.get("next_servers")
if not next_servers:
return
next_peer_id, next_session_id, next_start, next_end = next_servers[0]
next_peer_id = PeerID.from_base58(next_peer_id)
next_uid = CHAIN_DELIMITER.join(f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(next_start, next_end))
# Sending hidden states serialized with output_schema to avoid double serialization
next_tensors = [serialized_outputs] + request.tensors[1:]
next_metadata = metadata.copy()
next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True)
stub = self.get_stub(self._p2p, next_peer_id)
await stub.rpc_push(
runtime_pb2.ExpertRequest(
uid=next_uid,
tensors=next_tensors,
metadata=MSGPackSerializer.dumps(next_metadata),
),
timeout=self.request_timeout,
)
except Exception:
logger.debug(
f"Failed to push outputs to peer_id={next_peer_id}, session_id={next_session_id}, blocks={next_start}:{next_end}:",
exc_info=True,
)
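Continuing the hypothetical metadata from the client-side sketch above, this is roughly what arrives at the first downstream server via rpc_push(): session_id is swapped for the receiver's own session, the consumed hop is dropped from next_servers, and pushed=True marks the request so _iterate_inference_steps() can deduplicate it against the copy relayed by the client:
pushed_metadata = {
    "session_id": "c0ffee00-0000-0000-0000-000000000003",  # the receiver's session, not the sender's
    "step_id": "deadbeef-0000-0000-0000-000000000002",  # unchanged; used for deduplication
    "next_servers": [("QmExamplePeerC", "c0ffee00-0000-0000-0000-000000000004", 16, 24)],
    "pushed": True,
}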
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
async with timeout(self.request_timeout):
# Parse request and prepare backends
@ -348,7 +494,7 @@ class TransformerConnectionHandler(ConnectionHandler):
@contextlib.asynccontextmanager
async def _allocate_cache(
self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
) -> Sequence[Sequence[Handle, ...]]:
) -> Sequence[Sequence[Handle]]:
"""
Allocate memory cache for all transformer blocks, return cache handle
:returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
@ -358,7 +504,13 @@ class TransformerConnectionHandler(ConnectionHandler):
yield nested_pack(handles, descriptors)
def _log_request(
self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None
self,
method: str,
uids: Optional[Sequence[ModuleUID]],
context: P2PContext,
*,
debug: Optional[str] = None,
warning: Optional[str] = None,
) -> None:
if uids is not None:
friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
@ -370,10 +522,12 @@ class TransformerConnectionHandler(ConnectionHandler):
friendly_remote_id = "..." + str(context.remote_id)[-6:]
message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})"
if warning is None:
logger.info(message)
else:
if warning is not None:
logger.warning(f"{message}: {warning}")
elif debug is not None:
logger.debug(f"{message}: {debug}")
else:
logger.info(message)
async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
"""Return metadata about stored block uids and current load"""

@ -45,7 +45,7 @@ class Server:
self,
*,
initial_peers: List[str],
prefix: Optional[str],
dht_prefix: Optional[str],
converted_model_name_or_path: str,
throughput: Union[float, str],
num_blocks: Optional[int] = None,
@ -105,13 +105,13 @@ class Server:
revision=revision,
)
if prefix is None:
prefix = self.block_config.dht_prefix
assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
if dht_prefix is None:
dht_prefix = self.block_config.dht_prefix
assert UID_DELIMITER not in dht_prefix and CHAIN_DELIMITER not in dht_prefix, (
f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. "
f"Please specify another --prefix manually when starting a server"
f"Please specify another --dht_prefix manually when starting a server"
)
self.prefix = prefix
self.dht_prefix = dht_prefix
if expiration is None:
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
@ -121,7 +121,8 @@ class Server:
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self.module_uids = [
f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers)
f"{self.dht_prefix}{UID_DELIMITER}{block_index}"
for block_index in range(self.block_config.num_hidden_layers)
]
if dht_client_mode is None:
@ -258,7 +259,7 @@ class Server:
block_indices = self._choose_blocks()
self.module_container = ModuleContainer.create(
dht=self.dht,
prefix=self.prefix,
dht_prefix=self.dht_prefix,
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
attn_cache_bytes=self.attn_cache_bytes,
@ -359,7 +360,7 @@ class ModuleContainer(threading.Thread):
cls,
*,
dht: DHT,
prefix: str,
dht_prefix: str,
converted_model_name_or_path: str,
block_config: PretrainedConfig,
attn_cache_bytes: int,
@ -382,7 +383,7 @@ class ModuleContainer(threading.Thread):
should_validate_reachability: bool,
**kwargs,
) -> ModuleContainer:
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
joining_announcer = ModuleAnnouncerThread(
module_uids,
dht,
@ -459,6 +460,7 @@ class ModuleContainer(threading.Thread):
return cls(
dht,
dht_prefix,
blocks,
throughput=throughput,
update_period=update_period,
@ -469,6 +471,7 @@ class ModuleContainer(threading.Thread):
def __init__(
self,
dht: DHT,
dht_prefix: str,
module_backends: Dict[str, TransformerBackend],
*,
inference_max_length: int,
@ -486,10 +489,17 @@ class ModuleContainer(threading.Thread):
self.dht, self.module_backends = dht, module_backends
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
self.push_manager = mp.Manager()
self.push_manager.__enter__()
session_queues = self.push_manager.dict()
self.conn_handlers = [
TransformerConnectionHandler(
dht,
self.module_backends,
dht_prefix=dht_prefix,
push_manager=self.push_manager,
session_queues=session_queues,
inference_max_length=inference_max_length,
request_timeout=request_timeout,
session_timeout=session_timeout,
@ -497,6 +507,7 @@ class ModuleContainer(threading.Thread):
)
for _ in range(num_handlers)
]
self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
# note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
self.online_announcer = ModuleAnnouncerThread(
@ -577,6 +588,7 @@ class ModuleContainer(threading.Thread):
logger.debug("Shutting down connection handlers")
for handler in self.conn_handlers:
handler.shutdown()
self.push_manager.__exit__(None, None, None)
logger.debug(f"Shutting down pools")
for pool in self.runtime.pools:
