@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import asyncio
 import contextlib
+import itertools
 import time
 from typing import AsyncIterator, List, Optional
 
 import torch
@@ -13,26 +14,25 @@ from hivemind import (
     get_logger,
     nested_flatten,
     serialize_torch_tensor,
     use_hivemind_log_handler,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import aiter_with_timeout
 
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class RemoteTransformerBlockInferenceSession:
+class _ServerInferenceSession:
     """
-    An interface to a single multi-step *inference* session for a specific remote module on a specific server
+    An interface to a single multi-step *inference* session for a set of blocks on a specific server.
 
-    :note: this inference session is *not* fault-tolerant out of the box
+    :note: This class is *not* fault-tolerant out of the box.
     """
 
     def __init__(
@@ -42,32 +42,35 @@ class RemoteTransformerBlockInferenceSession:
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
         *,
         timeout: float,
         max_length: int,
         points: int = 0,
     ):
         self.uid, self.rpc_info = uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
         self.timeout = timeout
         self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
         self.stepped = False
         self.closed = False
 
     @classmethod
-    async def _create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
-    ) -> RemoteTransformerBlockInferenceSession:
+    async def create(
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
+    ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
-        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
-        return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
+        outputs_stream = await asyncio.wait_for(
+            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
+            timeout,
+        )
+        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
 
     @staticmethod
-    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
         while True:
-            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
             yield next_input_message
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
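A note for readers on the pattern above: the request stream passed to `rpc_inference` is an async generator that drains an `asyncio.Queue`, and the empty message doubles as an end-of-stream sentinel (it is still yielded so the server sees it). A minimal self-contained sketch of the same idea, with plain strings standing in for `runtime_pb2` messages (all names here are illustrative, not from this PR):

```python
import asyncio
from typing import AsyncIterator, Optional

async def read_from_queue(queue: asyncio.Queue, timeout: Optional[float] = None) -> AsyncIterator[str]:
    while True:
        message = await asyncio.wait_for(queue.get(), timeout)
        yield message  # the sentinel is forwarded too, so the receiver knows we are done
        if not message:  # an empty message plays the role of the empty ExpertRequest
            break

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for item in ("step 1", "step 2", ""):
        queue.put_nowait(item)
    async for message in read_from_queue(queue, timeout=1.0):
        print(repr(message))

asyncio.run(main())
```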
@@ -77,7 +80,7 @@ class RemoteTransformerBlockInferenceSession:
         new_hidden_states: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
         :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
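For context on the `prompts` argument (deep prompts): the asserts later in this diff imply a 4D tensor holding one prefix per block. A hypothetical shape check, with dimension names that are my reading of those asserts rather than something stated in this file:

```python
import torch

n_blocks, batch_size, prefix_len, hidden_size = 4, 1, 16, 64  # illustrative sizes only
prompts = torch.zeros(n_blocks, batch_size, prefix_len, hidden_size)
assert prompts.ndim == 4 and prompts.shape[0] == n_blocks  # mirrors the check in step()
```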
@@ -122,7 +125,7 @@ class RemoteTransformerBlockInferenceSession:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await anext(self._outputs_stream)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""
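The change above moves the timeout from stream creation to every `anext()` call, so a server that accepts the connection but stalls mid-generation is also detected. A small sketch of guarding a single async-iterator step with `asyncio.wait_for`; the slow generator is a stand-in for a hung server:

```python
import asyncio

async def slow_stream():
    yield "first response"
    await asyncio.sleep(10)  # simulates a server that hangs after one response
    yield "second response"

async def main() -> None:
    stream = slow_stream()
    print(await asyncio.wait_for(stream.__anext__(), timeout=1.0))  # arrives in time
    try:
        await asyncio.wait_for(stream.__anext__(), timeout=1.0)
    except asyncio.TimeoutError:
        print("step timed out; the caller would close this session and retry elsewhere")

asyncio.run(main())
```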
@@ -154,60 +157,163 @@ class RemoteTransformerBlockInferenceSession:
         self.close()
 
 
-class RemoteSequentialInferenceSession:
+class InferenceSession:
     """
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
-        self.sequence_manager = sequence_manager
-        self.p2p = p2p
-        self.closed = False
-        self.chosen_spans: List[RemoteSpanInfo] = []
-        self.stack = contextlib.ExitStack()
-        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
-        self.metadata = metadata
-        self.timeout = timeout
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
+        self._sequence_manager = sequence_manager
+        self._p2p = p2p
+        self._closed = False
+        self._chosen_spans = []
+        self._server_sessions = []
+        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
+        self._position = 0
+        self._max_length = max_length
+        self._metadata = metadata
 
-    def __enter__(self):
-        assert not self.closed and not self.chosen_spans
-        self.stack.__enter__()
-        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        self.chosen_spans.extend(self.sequence_manager.make_sequence())
-
-        for chosen_span in self.chosen_spans:
-            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
-            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
-            inference_session = RemoteExpertWorker.run_coroutine(
-                RemoteTransformerBlockInferenceSession._create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
-                )
-            )
-            self.inference_sessions.append(inference_session)
-            self.stack.enter_context(inference_session)
+    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
+        server_sessions = []
+        try:
+            for span in chosen_spans:
+                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
+                session = RemoteExpertWorker.run_coroutine(
+                    _ServerInferenceSession.create(
+                        stub,
+                        span_uids,
+                        rpc_info=self._sequence_manager.rpc_info,
+                        timeout=self._sequence_manager.timeout,
+                        max_length=self._max_length,
+                        **self._metadata,
+                    )
+                )
+                server_sessions.append(session)
+                session.__enter__()
+            return server_sessions
+        except:
+            self._exit_server_sessions(server_sessions)
+            raise
+
+    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
+        for session in reversed(server_sessions):
+            try:
+                session.__exit__(None, None, None)
+            except Exception:
+                logger.debug("Caught exception while closing connection to server:", exc_info=True)
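`_enter_server_sessions` replaces the previous `contextlib.ExitStack` bookkeeping with explicit all-or-nothing semantics: if opening any session fails, everything opened so far is unwound in reverse order before the exception propagates (note that a session is appended *before* `__enter__`, so a failure during enter is also cleaned up). The same pattern in miniature, with a hypothetical `Resource` class that is not from this codebase:

```python
from typing import List

class Resource:
    def __init__(self, name: str, fail: bool = False):
        self.name, self.fail = name, fail

    def __enter__(self) -> "Resource":
        if self.fail:
            raise RuntimeError(f"could not open {self.name}")
        print(f"opened {self.name}")
        return self

    def __exit__(self, *exc) -> None:
        print(f"closed {self.name}")

def enter_all(resources: List[Resource]) -> List[Resource]:
    entered: List[Resource] = []
    try:
        for resource in resources:
            entered.append(resource)  # appended first, so a failing __enter__ is still unwound
            resource.__enter__()
        return entered
    except:
        for resource in reversed(entered):  # unwind in reverse, like _exit_server_sessions
            resource.__exit__(None, None, None)
        raise

try:
    enter_all([Resource("a"), Resource("b"), Resource("c", fail=True)])
except RuntimeError as e:
    print(e)  # everything appended so far was unwound before this propagated
```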
 
+    def __enter__(self) -> "InferenceSession":
+        assert not self._closed and not self._chosen_spans
+        return self
+
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
-        assert not self.closed
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+        assert not self._closed
+        if torch.is_grad_enabled():
+            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+
+        n_blocks = len(self._sequence_manager)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
-        for session in self.inference_sessions:
-            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
-            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
-            inputs = outputs
+            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+
+        n_input_tokens = inputs.shape[1]
+        if self._position + n_input_tokens > self._max_length:
+            raise ValueError(
+                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
+            )
+
+        server_idx = 0
+        block_idx = 0
+        recovery_until = -1  # Recovery mode is disabled until a failure happens
+        while block_idx < n_blocks:
+            for attempt_no in itertools.count():
+                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
+                try:
+                    if attempt_no >= 1:
+                        self._sequence_manager.update_()
+                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
+                        # If there is a failed server session, this code closes it
+                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+
+                        n_prev_spans = len(self._chosen_spans)
+                        update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
+                        if attempt_no >= 1 and update_end > recovery_until:
+                            logger.info(
+                                f"Due to a server failure, remote attention caches "
+                                f"from block {block_idx} to {update_end} will be regenerated"
+                            )
+                        recovery_until = max(recovery_until, update_end)
+
+                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
+                        # make_sequence() could return a longer sequence
+                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
+                        updated_sessions = self._enter_server_sessions(updated_spans)
+                        logger.debug(
+                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
+                        )
+
+                        # If there is a failed span, this code replaces it, otherwise it just adds new ones
+                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
+                        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
+                        recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
+                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
+                            len(updated_spans) - 1
+                        )
+                    assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
+                        f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
+                        f"{len(self._server_inputs)} inputs"
+                    )
+
+                    session = self._server_sessions[server_idx]
+                    span = self._chosen_spans[server_idx]
+
+                    if self._server_inputs[server_idx] is None:
+                        self._server_inputs[server_idx] = inputs
+                    elif self._server_inputs[server_idx].shape[1] == self._position:
+                        self._server_inputs[server_idx] = torch.cat(
+                            [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
+                        )
+                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
+                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
+                        f"position={self._position} n_input_tokens={n_input_tokens}"
+                    )
+
+                    if not session.stepped:
+                        inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
+                    else:
+                        inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+
+                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
+                    assert (
+                        inputs.shape == outputs.shape
+                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape}"
+
+                    inputs = outputs
+                    server_idx += 1
+                    block_idx = span.end
+                    break
+                except Exception as e:
+                    delay = self._sequence_manager.get_retry_delay(attempt_no)
+                    logger.warning(
+                        f"Caught exception when running inference from block {block_idx} "
+                        f"(retry in {delay:.0f} sec): {repr(e)}"
+                    )
+                    logger.debug("See detailed traceback below:", exc_info=True)
+                    time.sleep(delay)
+
+        self._position += n_input_tokens
         return inputs
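The control flow of the new `step()` is easier to see with the recovery bookkeeping stripped away: each span is processed inside an unbounded attempt loop that asks the sequence manager for a backoff delay and only breaks on success. Reduced to that skeleton (`fn` and `get_retry_delay` are placeholders, not names from this PR):

```python
import itertools
import time

def run_with_retries(fn, get_retry_delay=lambda attempt_no: 2.0**attempt_no):
    for attempt_no in itertools.count():
        try:
            return fn()  # success plays the role of the `break` in step()
        except Exception as e:
            delay = get_retry_delay(attempt_no)
            print(f"attempt {attempt_no} failed ({e!r}), retrying in {delay:.0f} sec")
            time.sleep(delay)
```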
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
-        if not self.closed:
-            self.stack.__exit__(*exc_details or (None, None, None))
-            self.inference_sessions.clear()
-            self.closed = True
+        if not self._closed:
+            self._server_inputs.clear()
+            self._exit_server_sessions(self._server_sessions)
+            self._server_sessions.clear()
+            self._chosen_spans.clear()
+            self._closed = True
 
     def __exit__(self, *exc_details):
         self.close(*exc_details)
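Taken together, the public surface after this change is just the constructor, `step()`, and the context-manager protocol. A hypothetical usage sketch, not taken from this PR; it assumes an already-constructed `RemoteSequenceManager` (`sequence_manager`), a `P2P` instance (`p2p`), and a `hidden_states` tensor of shape `[batch, seq_len, hidden_size]`:

```python
# hypothetical usage; sequence_manager, p2p, and hidden_states come from the surrounding client code
with InferenceSession(sequence_manager, p2p, max_length=2048) as session:
    for token_hidden_states in hidden_states.split(1, dim=1):  # feed one token at a time
        outputs = session.step(token_hidden_states)  # failed servers are retried transparently
```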