# petals/src/petals/client/inference_session.py

from __future__ import annotations

import asyncio
import itertools
import time
import uuid
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple

import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten

from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs

logger = get_logger(__name__)


class _ServerInferenceSession:
"""
An interface to a single multi-step *inference* session for a a set of blocks on a specific server.
:note: This class is *not* fault-tolerant out of the box.
"""
    def __init__(
        self,
        sequence_manager: RemoteSequenceManager,
        span: RemoteSpanInfo,
        span_uids: Sequence[ModuleUID],
        inputs_queue: asyncio.Queue,
        outputs_aiter: AsyncIterator,
        *block_kwargs,
        max_length: int,
    ):
        self.sequence_manager = sequence_manager
        self.span, self.span_uids = span, span_uids
        self.num_blocks = len(span_uids)
        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
        self.session_id = str(uuid.uuid4())
        self.max_length = max_length
        self.stepped = False
        self.closed = False
        self._position = 0
        self.history = None  # Used in case of server failures to regenerate attention caches on new servers
        self.next_session = None
        self.block_kwargs = block_kwargs
        assert len(self.block_kwargs) in (0, self.num_blocks)

    @classmethod
    async def create(
        cls,
        sequence_manager: RemoteSequenceManager,
        span: RemoteSpanInfo,
        span_uids: Sequence[ModuleUID],
        *block_kwargs: Dict[str, Any],
        **kwargs,
    ) -> _ServerInferenceSession:
        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
        stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
        inputs_queue = asyncio.Queue()
        outputs_stream = await asyncio.wait_for(
            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
            sequence_manager.config.connect_timeout,
        )
        return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs)

    @staticmethod
    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
        while True:
            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
            yield next_input_message
            if not next_input_message.uid and not next_input_message.tensors:
                break  # this message means "done sending"

    def step(
        self,
        inputs: torch.Tensor,
        prompts: Optional[torch.Tensor] = None,
        *,
        hypo_ids: Optional[torch.Tensor] = None,
        step_id: str,
    ) -> torch.Tensor:
        """
        Inference step: send a chunk of input tensors and receive a chunk of outputs
        :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs;
            if specified, deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]
"""
if self.closed:
raise Exception("Session is closed, cannot perform step")
n_input_tokens = inputs.shape[1]
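        # Keep a local copy of the full prefix: the first step sends it whole, later steps send only new tokens,
        # and the stored copy lets us replay the prefix on a replacement server if this one fails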
        if self.history is None:
            self.history = inputs
        elif self.history.shape[1] == self._position:
            self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
        assert self.history.shape[1] == self._position + n_input_tokens, (
            f"Broken input cache: span={self.span} shape={self.history.shape} "
            f"position={self._position} n_input_tokens={n_input_tokens}"
        )

        if not self.stepped:
            inputs = self.history  # Pass full inputs including prefix
            block_kwargs = self.block_kwargs
        else:
            inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
            block_kwargs = []

        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
            assert prompts.shape[1] in (inputs.shape[0], 1)
            assert prompts.shape[2] <= inputs.shape[1]
            assert prompts.shape[3] == inputs.shape[2]

        if hypo_ids is None or is_dummy(hypo_ids):
            hypo_ids = DUMMY_INT64
        else:
            assert len(hypo_ids) == len(inputs)
            assert hypo_ids.dtype == torch.int64

        metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length)
        metadata.update(
            self.sequence_manager.get_request_metadata(
                self.span.peer_id,
                "rpc_inference",
                self.span_uids,
                inputs,
                prompts,
                *block_kwargs,
                max_length=self.max_length,
                session_id=self.session_id,
                step_id=step_id,
            )
        )
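        # If server-to-server communication is enabled, tell this server which downstream sessions it may push
        # intermediate activations to directly (see the rpc_push() handshake mentioned in _update_sequence below)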
        if self.stepped and self.sequence_manager.config.use_server_to_server:
            next_servers = self._collect_next_servers()
            if next_servers:
                metadata["next_servers"] = next_servers

        codecs = self.sequence_manager.get_compression_codecs(
            self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs
        )

        # serialize inputs and put them into the queue
        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
        args_structure = metadata.setdefault("args_structure", args_structure)

        if codecs is None:
            codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors)
        else:
            codecs = list(nested_flatten(codecs))
        assert len(codecs) == len(
            input_tensors
        ), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs"

        outputs_serialized = RemoteExpertWorker.run_coroutine(
            self._step(
                runtime_pb2.ExpertRequest(
                    uid=CHAIN_DELIMITER.join(self.span_uids),
                    tensors=[
                        serialize_torch_tensor(tensor, compression)
                        for tensor, compression in zip(input_tensors, codecs)
                    ],
                    metadata=MSGPackSerializer.dumps(metadata),
                )
            )
        )
        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
        assert (
            outputs[0].shape == inputs.shape
        ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"
        self._position += n_input_tokens
        return outputs[0]

    def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:
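        """Collect (peer_id, session_id, span_start, span_end) of downstream sessions that have already started."""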
        next_servers = []
        session = self.next_session
        while session is not None and session.stepped:
            next_servers.append(
                (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end)
            )
            session = session.next_session
        return next_servers

    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
        await self._inputs_queue.put(inputs_serialized)
        self.stepped = True
        return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout)

    def close(self):
        """Finish a given inference session, close the underlying connection"""
        if self._outputs_stream is None:
            return  # already closed
        RemoteExpertWorker.run_coroutine(self._aclose_stream())
        self._outputs_stream = self._inputs_queue = None
        self.closed = True

    async def _aclose_stream(self):
        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
        if self._outputs_stream is None:
            return  # already closed
        if self.stepped:
            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
            try:
                await anext(self._outputs_stream)
            except StopAsyncIteration:
                pass

    def __del__(self):
        self.close()

    def __enter__(self):
        assert not self.closed
        return self

    def __exit__(self, *exc_details):
        self.close()


class InferenceSession:
    """
    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
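
    A rough usage sketch; the entry points assumed here (`model.inference_session(...)` and the `session=`
    argument of `generate()`) come from Petals' high-level generation API and may differ between versions:

        with model.inference_session(max_length=512) as sess:
            outputs = model.generate(input_ids, session=sess, max_new_tokens=16)

    The session can also be driven manually by passing hidden states to `sess.step(...)`.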
"""
def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]):
self._sequence_manager = sequence_manager
self._closed = False
self._server_sessions = []
self._position = 0
self._max_length = max_length
self.output_ids = None
num_blocks = len(self._sequence_manager)
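        # A single kwargs dict is broadcast to all blocks; otherwise, expect either no kwargs or one dict per block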
        if len(block_kwargs) == 1:
            block_kwargs = block_kwargs * num_blocks
        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
        self.block_kwargs = block_kwargs

    @property
    def num_blocks(self) -> int:
        return len(self._sequence_manager)

    @property
    def position(self) -> int:
        return self._position

    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
        server_sessions = []
        try:
            for span in chosen_spans:
                session = RemoteExpertWorker.run_coroutine(
                    _ServerInferenceSession.create(
                        self._sequence_manager,
                        span,
                        self._sequence_manager.block_uids[span.start : span.end],
                        *self.block_kwargs[span.start : span.end],
                        max_length=self._max_length,
                    )
                )
                server_sessions.append(session)
                session.__enter__()
            return server_sessions
        except:
            self._exit_server_sessions(server_sessions)
            raise

    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
        for session in reversed(server_sessions):
            try:
                session.__exit__(None, None, None)
            except Exception:
                logger.debug("Caught exception while closing connection to server:", exc_info=True)

    def __enter__(self) -> "InferenceSession":
        assert not self._closed and not self._server_sessions
        return self

    def step(
        self,
        inputs: torch.Tensor,
        prompts: Optional[torch.Tensor] = None,
        hypo_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert not self._closed
        if torch.is_grad_enabled():
            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")

        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
            assert prompts.shape[1] in (inputs.shape[0], 1)
            assert prompts.shape[2] <= inputs.shape[1]
            assert prompts.shape[3] == inputs.shape[2]

        if hypo_ids is None or is_dummy(hypo_ids):
            hypo_ids = DUMMY_INT64
        else:
            assert len(hypo_ids) == len(inputs)
            assert hypo_ids.dtype == torch.int64

        inputs_device = inputs.device
        inputs_dtype = inputs.dtype
        inputs = inputs.cpu()
        prompts = prompts.cpu()
        hypo_ids = hypo_ids.cpu()
        step_id = str(uuid.uuid4())

        n_input_tokens = inputs.shape[1]
        if self._position + n_input_tokens > self._max_length:
            raise ValueError(
                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
            )

        server_idx = 0
        block_idx = 0
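        # Route the inputs through each server's span of blocks; on failure, rebuild the remaining part of the
        # route (see _update_sequence) and retry from the first block the failed server was responsible for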
        while block_idx < self.num_blocks:
            for attempt_no in itertools.count():
                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                server_session = None
                try:
                    if not self._server_sessions or attempt_no >= 1:
                        self._update_sequence(server_idx, block_idx, attempt_no)

                    server_session = self._server_sessions[server_idx]
                    inputs = server_session.step(
                        inputs,
                        prompts[server_session.span.start : server_session.span.end],
                        hypo_ids=hypo_ids,
                        step_id=step_id,
                    )

                    server_idx += 1
                    block_idx = server_session.span.end
                    self._sequence_manager.on_request_success(server_session.span.peer_id)
                    break
                except Exception as e:
                    self._sequence_manager.on_request_failure(
                        server_session.span.peer_id if server_session is not None else None
                    )
                    if attempt_no + 1 == self._sequence_manager.config.max_retries:
                        raise
                    delay = self._sequence_manager.get_retry_delay(attempt_no)
                    logger.warning(
                        f"Caught exception when running inference via {server_session.span if server_session is not None else None} "
                        f"(retry in {delay:.0f} sec): {repr(e)}"
                    )
                    maybe_log_traceback(e)
                    time.sleep(delay)

        self._position += n_input_tokens

        outputs = inputs[:, -n_input_tokens:]
        outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
        return outputs

    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int):
        # If there is a failed server session, this code closes it
        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])

        n_prev_spans = len(self._server_sessions)
        update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
        if attempt_no >= 1:
            logger.debug(
                f"Due to a server failure, remote attention caches "
                f"from block {block_idx} to {update_end} will be regenerated"
            )

        updated_spans = self._sequence_manager.make_sequence(
            block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length
        )
        # make_sequence() could return a longer sequence
        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
        updated_sessions = self._enter_server_sessions(updated_spans)
        logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers")

        # If there is a failed span, this code replaces it, otherwise it just adds new ones
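        # When replacing, the locally cached inputs (history) are handed over so the new server can rebuild its attention cache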
        if server_idx < n_prev_spans:
            updated_sessions[0].history = self._server_sessions[server_idx].history
        self._server_sessions[server_idx : server_idx + 1] = updated_sessions

        # Update links to the next server session for direct server-to-server communication via rpc_push()
        for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)):
            self._server_sessions[i].next_session = self._server_sessions[i + 1]

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self._closed:
            self._exit_server_sessions(self._server_sessions)
            self._server_sessions.clear()
            self._closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()

    @property
    def last_token_id(self) -> Optional[torch.Tensor]:  # Backward compatibility with Petals < 2.1.0
        return self.output_ids[:, -1:] if self.output_ids is not None else None

    @last_token_id.setter
    def last_token_id(self, value: torch.Tensor):  # Backward compatibility with Petals < 2.1.0
        if self.output_ids is None:
            raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
        self.output_ids[:, -1:] = value