from __future__ import annotations

import asyncio
import itertools
import time
import uuid
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple

import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten

from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs

logger = get_logger(__name__)


class _ServerInferenceSession:
    """
    An interface to a single multi-step *inference* session for a set of blocks on a specific server.

    :note: This class is *not* fault-tolerant out of the box.
    """

    def __init__(
        self,
        sequence_manager: RemoteSequenceManager,
        span: RemoteSpanInfo,
        span_uids: Sequence[ModuleUID],
        inputs_queue: asyncio.Queue,
        outputs_aiter: AsyncIterator,
        *block_kwargs,
        max_length: int,
    ):
        self.sequence_manager = sequence_manager
        self.span, self.span_uids = span, span_uids
        self.num_blocks = len(span_uids)
        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
        self.session_id = str(uuid.uuid4())
        self.max_length = max_length
        self.stepped = False  # set once the first request has been put into the queue (see _step)
        self.closed = False

        # Number of tokens this session has successfully processed; used only for client-side history bookkeeping
        self._position = 0
        self.history = None  # Used in case of server failures to regenerate attention caches on new servers
        self.next_session = None  # downstream session, linked by InferenceSession for server-to-server pushes

        self.block_kwargs = block_kwargs
        assert len(self.block_kwargs) in (0, self.num_blocks)

    @classmethod
    async def create(
        cls,
        sequence_manager: RemoteSequenceManager,
        span: RemoteSpanInfo,
        span_uids: Sequence[ModuleUID],
        *block_kwargs: Dict[str, Any],
        **kwargs,
    ) -> _ServerInferenceSession:
        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
        stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
        inputs_queue = asyncio.Queue()
        # Open the bidirectional stream; requests are fed lazily from inputs_queue by _read_inputs_from_queue
        outputs_stream = await asyncio.wait_for(
            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
            sequence_manager.config.connect_timeout,
        )
        return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs)

    @staticmethod
    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
        """Yield requests from ``queue`` until (and including) an empty request with no uid and no tensors."""
        while True:
            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
            yield next_input_message
            if not next_input_message.uid and not next_input_message.tensors:
                break  # this message means "done sending"

    def step(
        self,
        inputs: torch.Tensor,
        prompts: Optional[torch.Tensor] = None,
        *,
        hypo_ids: Optional[torch.Tensor] = None,
        step_id: str,
    ) -> torch.Tensor:
        """
        Inference step: send a chunk of input tensors and receive a chunk of outputs

        :param inputs: hidden states for this step; the second dimension is treated as the token dimension
        :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
          if specified, deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]
        :param hypo_ids: optional int64 tensor with one entry per row of ``inputs``, forwarded to the server
        :param step_id: identifier of this step, forwarded to the server in the request metadata
        :returns: the first tensor of the server response; its shape equals the shape of the inputs actually sent
          (the full history if this session has not stepped before, otherwise only the new tokens)
        :raises Exception: if the session is already closed
        """
        if self.closed:
            raise Exception("Session is closed, cannot perform step")

        n_input_tokens = inputs.shape[1]
        if self.history is None:
            self.history = inputs
        elif self.history.shape[1] == self._position:
            self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
        elif not self.stepped:
            # The history was inherited from a failed session (InferenceSession._update_sequence). That session
            # appended this step's tokens before its request failed, so the inherited history already ends with them.
            # This fresh server has no cache yet, so the whole history counts as new tokens for it.
            n_input_tokens = self.history.shape[1]
        # NOTE(review): `history` is never reordered by `hypo_ids`; if servers reorder their caches by `hypo_ids`
        # (e.g. during beam search), caches regenerated from `history` may not match - confirm.
        assert self.history.shape[1] == self._position + n_input_tokens, (
            f"Broken input cache: span={self.span} shape={self.history.shape} "
            f"position={self._position} n_input_tokens={n_input_tokens}"
        )

        if not self.stepped:
            inputs = self.history  # Pass full inputs including prefix
            block_kwargs = self.block_kwargs
        else:
            inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
            block_kwargs = []

        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
            assert prompts.shape[1] in (inputs.shape[0], 1)
            assert prompts.shape[2] <= inputs.shape[1]
            assert prompts.shape[3] == inputs.shape[2]

        if hypo_ids is None or is_dummy(hypo_ids):
            hypo_ids = DUMMY_INT64
        else:
            assert len(hypo_ids) == len(inputs)
            assert hypo_ids.dtype == torch.int64

        metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length)
        metadata.update(
            self.sequence_manager.get_request_metadata(
                self.span.peer_id,
                "rpc_inference",
                self.span_uids,
                inputs,
                prompts,
                *block_kwargs,
                max_length=self.max_length,
                session_id=self.session_id,
                step_id=step_id,
            )
        )
        # Downstream servers that already hold a session can receive this server's outputs directly
        if self.stepped and self.sequence_manager.config.use_server_to_server:
            next_servers = self._collect_next_servers()
            if next_servers:
                metadata["next_servers"] = next_servers

        codecs = self.sequence_manager.get_compression_codecs(
            self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs
        )

        # serialize inputs and put them into the queue
        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
        args_structure = metadata.setdefault("args_structure", args_structure)

        if codecs is None:
            codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors)
        else:
            codecs = list(nested_flatten(codecs))
            assert len(codecs) == len(
                input_tensors
            ), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs"

        outputs_serialized = RemoteExpertWorker.run_coroutine(
            self._step(
                runtime_pb2.ExpertRequest(
                    uid=CHAIN_DELIMITER.join(self.span_uids),
                    tensors=[
                        serialize_torch_tensor(tensor, compression)
                        for tensor, compression in zip(input_tensors, codecs)
                    ],
                    metadata=MSGPackSerializer.dumps(metadata),
                )
            )
        )
        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
        assert (
            outputs[0].shape == inputs.shape
        ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"

        self._position += n_input_tokens

        return outputs[0]

    def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:
        """Return (peer_id, session_id, span start, span end) for the chain of already-stepped downstream sessions."""
        next_servers = []
        session = self.next_session
        while session is not None and session.stepped:
            next_servers.append(
                (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end)
            )
            session = session.next_session
        return next_servers

    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
        await self._inputs_queue.put(inputs_serialized)
        self.stepped = True
        return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout)

    def close(self):
        """Finish a given inference session, close the underlying connection"""
        # getattr: __del__ may call this on an instance whose __init__ did not finish
        if getattr(self, "_outputs_stream", None) is None:
            return  # already closed
        try:
            RemoteExpertWorker.run_coroutine(self._aclose_stream())
        finally:
            # Mark the session closed even if the graceful shutdown failed, so it is not retried (e.g. from __del__)
            self._outputs_stream = self._inputs_queue = None
            self.closed = True

    async def _aclose_stream(self):
        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
        if self._outputs_stream is None:
            return  # already closed
        if self.stepped:
            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
            try:
                await anext(self._outputs_stream)
            except StopAsyncIteration:
                pass

    def __del__(self):
        self.close()

    def __enter__(self):
        assert not self.closed
        return self

    def __exit__(self, *exc_details):
        self.close()


class InferenceSession:
    """
    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
    """

    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]):
        self._sequence_manager = sequence_manager
        self._closed = False
        self._server_sessions = []
        self._position = 0
        self._max_length = max_length
        self.output_ids = None

        num_blocks = len(self._sequence_manager)
        # A single kwargs dict is shared by every block
        if len(block_kwargs) == 1:
            block_kwargs = block_kwargs * num_blocks
        assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
        self.block_kwargs = block_kwargs

    @property
    def num_blocks(self) -> int:
        return len(self._sequence_manager)

    @property
    def position(self) -> int:
        """Number of tokens processed by this session so far."""
        return self._position

    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
        """Open one server session per span; if any of them fails, close the ones already opened and re-raise."""
        server_sessions = []
        try:
            for span in chosen_spans:
                session = RemoteExpertWorker.run_coroutine(
                    _ServerInferenceSession.create(
                        self._sequence_manager,
                        span,
                        self._sequence_manager.block_uids[span.start : span.end],
                        *self.block_kwargs[span.start : span.end],
                        max_length=self._max_length,
                    )
                )
                server_sessions.append(session)
                session.__enter__()
            return server_sessions
        except BaseException:  # also clean up on KeyboardInterrupt etc.; the exception is always re-raised
            self._exit_server_sessions(server_sessions)
            raise

    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
        """Close the given sessions in reverse order; errors are logged at debug level and not propagated."""
        for session in reversed(server_sessions):
            try:
                session.__exit__(None, None, None)
            except Exception:
                logger.debug("Caught exception while closing connection to server:", exc_info=True)

    def __enter__(self) -> "InferenceSession":
        assert not self._closed and not self._server_sessions
        return self

    def step(
        self,
        inputs: torch.Tensor,
        prompts: Optional[torch.Tensor] = None,
        hypo_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run one inference step through the whole chain of remote blocks, replacing failed servers on the fly.

        :param inputs: hidden states for the new tokens; the second dimension is treated as the token dimension
        :param prompts: optional DEEP prompts of shape [num_blocks, batch_size, prefix_len, hid_size]
        :param hypo_ids: optional int64 tensor with one entry per row of ``inputs``
        :returns: outputs for the last ``inputs.shape[1]`` tokens, on the device and with the dtype of ``inputs``
        :raises ValueError: if this step would exceed the ``max_length`` given at construction
        :note: after ``config.max_retries`` failed attempts for a span, the last exception is re-raised
        """
        assert not self._closed
        if torch.is_grad_enabled():
            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")

        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
            assert prompts.shape[1] in (inputs.shape[0], 1)
            assert prompts.shape[2] <= inputs.shape[1]
            assert prompts.shape[3] == inputs.shape[2]

        if hypo_ids is None or is_dummy(hypo_ids):
            hypo_ids = DUMMY_INT64
        else:
            assert len(hypo_ids) == len(inputs)
            assert hypo_ids.dtype == torch.int64

        inputs_device = inputs.device
        inputs_dtype = inputs.dtype
        inputs = inputs.cpu()
        prompts = prompts.cpu()
        hypo_ids = hypo_ids.cpu()
        step_id = str(uuid.uuid4())

        n_input_tokens = inputs.shape[1]
        if self._position + n_input_tokens > self._max_length:
            raise ValueError(
                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
            )

        server_idx = 0
        block_idx = 0
        while block_idx < self.num_blocks:
            for attempt_no in itertools.count():
                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                server_session = None
                try:
                    # Build the route on the first step, or rebuild it from block_idx after a failure
                    if not self._server_sessions or attempt_no >= 1:
                        self._update_sequence(server_idx, block_idx, attempt_no)

                    server_session = self._server_sessions[server_idx]
                    # After an upstream server was replaced during this step, `inputs` holds outputs for the whole
                    # prefix. Servers that already stepped keep a cache of that prefix and need only the new tokens;
                    # fresh servers need everything to build their caches.
                    span_inputs = inputs[:, -n_input_tokens:] if server_session.stepped else inputs
                    inputs = server_session.step(
                        span_inputs,
                        prompts[server_session.span.start : server_session.span.end],
                        hypo_ids=hypo_ids,
                        step_id=step_id,
                    )

                    server_idx += 1
                    block_idx = server_session.span.end
                    self._sequence_manager.on_request_success(server_session.span.peer_id)
                    break
                except Exception as e:
                    self._sequence_manager.on_request_failure(
                        server_session.span.peer_id if server_session is not None else None
                    )
                    if attempt_no + 1 == self._sequence_manager.config.max_retries:
                        raise
                    delay = self._sequence_manager.get_retry_delay(attempt_no)
                    logger.warning(
                        f"Caught exception when running inference via {server_session.span if server_session is not None else None} "
                        f"(retry in {delay:.0f} sec): {repr(e)}"
                    )
                    maybe_log_traceback(e)
                    time.sleep(delay)

        self._position += n_input_tokens
        outputs = inputs[:, -n_input_tokens:]  # drop any regenerated prefix, keep only this step's tokens
        outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
        return outputs

    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int):
        """(Re)build server sessions covering blocks from ``block_idx`` in place of the session at ``server_idx``."""
        # If there is a failed server session, this code closes it
        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])

        n_prev_spans = len(self._server_sessions)
        update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
        if attempt_no >= 1:
            logger.debug(
                f"Due to a server failure, remote attention caches "
                f"from block {block_idx} to {update_end} will be regenerated"
            )

        updated_spans = self._sequence_manager.make_sequence(
            block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length
        )
        # make_sequence() could return a longer sequence
        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
        updated_sessions = self._enter_server_sessions(updated_spans)
        logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers")

        # If there is a failed span, this code replaces it, otherwise it just adds new ones
        if server_idx < n_prev_spans:
            # The first replacement session re-sends this history to rebuild its attention cache
            updated_sessions[0].history = self._server_sessions[server_idx].history
        self._server_sessions[server_idx : server_idx + 1] = updated_sessions

        # Update links to the next server session for direct server-to-server communication via rpc_push()
        for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)):
            self._server_sessions[i].next_session = self._server_sessions[i + 1]

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self._closed:
            self._exit_server_sessions(self._server_sessions)
            self._server_sessions.clear()
            self._closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()

    @property
    def last_token_id(self) -> Optional[torch.Tensor]:  # Backward compatibility with Petals < 2.1.0
        return self.output_ids[:, -1:] if self.output_ids is not None else None

    @last_token_id.setter
    def last_token_id(self, value: torch.Tensor):  # Backward compatibility with Petals < 2.1.0
        if self.output_ids is None:
            raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
        self.output_ids[:, -1:] = value