from __future__ import annotations import asyncio import contextlib import multiprocessing as mp import sys from enum import Enum from itertools import chain from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from async_timeout import timeout from hivemind import ( DHT, MSGPackSerializer, P2PContext, PeerID, deserialize_tensor_stream, deserialize_torch_tensor, nested_flatten, nested_pack, serialize_torch_tensor, ) from hivemind.moe.server.connection_handler import ConnectionHandler from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE from hivemind.proto import runtime_pb2 from hivemind.utils.asyncio import amap_in_executor, anext from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming import petals from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID from petals.server.backend import TransformerBackend from petals.server.memory_cache import Handle from petals.server.task_pool import PrioritizedTaskPool from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase from petals.utils.misc import DUMMY, is_dummy logger = get_logger(__name__) # Fix pickling protobufs, see https://stackoverflow.com/a/74873028 sys.modules["runtime_pb2"] = runtime_pb2 CACHE_TOKENS_AVAILABLE = "cache_tokens_available" class Event(Enum): NEW_SESSION = 0 END_SESSION = 1 PUSH = 2 SHUTDOWN = 3 class TransformerConnectionHandler(ConnectionHandler): """Handles three request types: forward, backward and forward-incremental (inference)""" module_backends: Dict[ModuleUID, TransformerBackend] def __init__( self, dht: DHT, module_backends: Dict[str, TransformerBackend], *, adapters: Optional[Sequence[str]], dht_prefix: str, handler_event_queues: Sequence[mp.Queue], handler_index: int, inference_max_length: int, request_timeout: float, session_timeout: float, step_timeout: float, task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(), ): super().__init__(dht, module_backends) for module_backend in self.module_backends.values(): assert isinstance(module_backend, TransformerBackend) self.dht_prefix = dht_prefix self.adapters = adapters self._handler_event_queues = handler_event_queues self._handler_index = handler_index self._own_event_queue = handler_event_queues[handler_index] self._listener_task: Optional[asyncio.Task] = None self._session_queues: Dict[str, asyncio.Queue] = {} self._session_handlers: Dict[str, int] = {} self.inference_max_length = inference_max_length self.request_timeout = request_timeout self.session_timeout, self.step_timeout = session_timeout, step_timeout self._prioritizer = task_prioritizer async def add_p2p_handlers(self, *args, **kwargs) -> None: if self._listener_task is None: # Start listening to our own event queue before we accept any requests self._listener_task = asyncio.create_task(self._listen_to_event_queue()) await super().add_p2p_handlers(*args, **kwargs) def shutdown(self): if self.is_alive(): self._outer_pipe.send("_shutdown") self._own_event_queue.put((Event.SHUTDOWN, None, None)) self.join(self.shutdown_timeout) if self.is_alive(): logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM") self.terminate() async def _gather_inputs( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext ) -> Tuple[str, List[torch.Tensor], Dict]: block_uid, metadata = None, None def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]: nonlocal block_uid, metadata if block_uid is None: block_uid = req.uid elif block_uid != req.uid: raise ValueError("Block uids differ in one request") if metadata is None: metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {} return req.tensors tensors_stream = amap_in_executor(_unpack, requests) inputs = await deserialize_tensor_stream(tensors_stream) assert isinstance(block_uid, str) and isinstance(metadata, dict) return block_uid, inputs, metadata async def rpc_inference( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext, ) -> AsyncIterator[runtime_pb2.ExpertResponse]: """Compute a single step of inference using attention cache; update attention cache accordingly.""" async with timeout(self.session_timeout): try: request = await asyncio.wait_for(anext(requests), self.step_timeout) except asyncio.TimeoutError: self._log_request("rpc_inference.open", None, context, warning="timed out") return requested_uids = self._check_uids(request.uid) self._log_request("rpc_inference.open", requested_uids, context) try: metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) session_id = metadata.get("session_id") alloc_timeout = float(metadata.get("alloc_timeout", 0.0)) if not requested_uids: raise ValueError("User must specify at least one block for inference, but got none") assert isinstance( max_length, int ), f"rpc_inference metadata must contain int max_length, got {max_length}" assert isinstance( points, (float, int) ), f"rpc_inference should have number of points as a number or None, got {points}" if not 0 <= max_length <= self.inference_max_length: raise ValueError( f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}" ) point_per_piece = points / max_length if max_length > 0 else 0.0 batch_size = request.tensors[0].size[0] if request.tensors else 1 prefix_length = 0 async with self._allocate_cache(requested_backends, batch_size, max_length, alloc_timeout) as cache_handles: assert len(cache_handles) == len(requested_backends) first_request = request background_tasks = set() async for request, metadata in self._iterate_inference_steps( first_request, requests, session_id, requested_uids, context ): hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors) # Cast inputs to backend dtype hidden_states = hidden_states.to(requested_backends[0].dtype) assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" # parse deep prompts (optional argument) has_prompts = prompts is not None and not is_dummy(prompts) if not has_prompts: prompts = [None] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts] if not (len(requested_backends) == len(prompts)): raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq) if prefix_length + length_increment > max_length: raise ValueError( f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" f" exceeds pre-allocated maximum {max_length}" ) priority = self._prioritizer.prioritize( hidden_states, hypo_ids, points=point_per_piece, requested_uids=requested_uids, type="inference", ) inference_infos = tuple( InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) for uid, handles in zip(requested_uids, cache_handles) ) if hidden_states.numel() == 0: pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g. # when user wants to pre-allocate cache or check that server *can* allocate that cache else: assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task( hidden_states, hypo_ids, inference_infos, *prompts, priority=priority ) # serialize and send last layer outputs output_tensors = [ serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) for result, proto in zip( (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema) ) ] if not has_prompts: task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata)) background_tasks.add(task) # Keep reference until it is done to save it from GC task.add_done_callback(background_tasks.discard) yield runtime_pb2.ExpertResponse(tensors=output_tensors) # prepare for next step prefix_length += length_increment finally: self._log_request("rpc_inference.close", requested_uids, context) @contextlib.contextmanager def _managed_session(self, session_id: str): assert session_id not in self._session_queues, f"session id {session_id} is not unique" try: self._session_queues[session_id] = asyncio.Queue() self._session_handlers[session_id] = self._handler_index for other_index, other_queue in enumerate(self._handler_event_queues): if other_index != self._handler_index: other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index)) yield finally: self._session_queues.pop(session_id).put_nowait(None) # put None so that the get task will not hang del self._session_handlers[session_id] for other_index, other_queue in enumerate(self._handler_event_queues): if other_index != self._handler_index: other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index)) def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest): handler_index = self._session_handlers.get(session_id) if handler_index is None: logger.debug(f"Ignored rpc_push to unknown session ID: {session_id}") elif handler_index == self._handler_index: self._session_queues[session_id].put_nowait(request) else: self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request)) async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]: assert self._session_handlers[session_id] == self._handler_index, "session belongs to another handler" return await self._session_queues[session_id].get() async def _listen_to_event_queue(self): loop = asyncio.get_event_loop() while True: try: event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get) if event == Event.SHUTDOWN: break elif event == Event.NEW_SESSION: self._session_handlers[session_id] = payload # index of the handler that owns that session elif event == Event.END_SESSION: self._session_handlers.pop(session_id, None) elif event == Event.PUSH: maybe_session_queue = self._session_queues.get(session_id) if maybe_session_queue is not None: maybe_session_queue.put_nowait(payload) else: raise RuntimeError(f"Unexpected event: {event}") except Exception as e: logger.exception(e) async def _iterate_inference_steps( self, first_request: runtime_pb2.ExpertRequest, requests: AsyncIterator[runtime_pb2.ExpertRequest], session_id: Optional[str], requested_uids: Sequence[str], context: P2PContext, ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]: processed_step_ids = set() n_pushes = n_late_pushes = 0 request = first_request anext_task = get_push_task = None try: with self._managed_session(session_id) if session_id is not None else contextlib.nullcontext(): while request.tensors: # iterate while user is willing to supply tensors metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} step_id = metadata.get("step_id") pushed = metadata.get("pushed") if pushed: n_pushes += 1 self._log_request("rpc_inference.push", requested_uids, context, debug=f"session received push") if step_id is None or step_id not in processed_step_ids: yield request, metadata if step_id is not None: processed_step_ids.add(step_id) elif pushed: n_late_pushes += 1 self._log_request( "rpc_inference.push", requested_uids, context, warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time", ) # Wait for the next request, coming either from the `requests` iterator or `push_queue` if anext_task is None: anext_task = asyncio.create_task(anext(requests)) if get_push_task is None: if session_id is not None: get_push_task = asyncio.create_task(self._get_from_session_queue(session_id)) else: get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task done, _ = await asyncio.wait( [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED ) if anext_task in done: request = await anext_task anext_task = None elif get_push_task in done: request = await get_push_task get_push_task = None else: self._log_request("rpc_inference.step", requested_uids, context, warning="timed out") anext_task.cancel() get_push_task.cancel() return except: logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True) raise async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: """Directly push activation tensors from one server to another""" requested_uids = self._check_uids(request.uid) metadata = MSGPackSerializer.loads(request.metadata) session_id = metadata["session_id"] self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}") self._put_into_session_queue(session_id, request) return runtime_pb2.ExpertResponse() async def _push_outputs( self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict ) -> None: try: next_servers = metadata.get("next_servers") if not next_servers: return next_peer_id, next_session_id, next_start, next_end = next_servers[0] next_peer_id = PeerID.from_base58(next_peer_id) next_uid = CHAIN_DELIMITER.join(f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(next_start, next_end)) # Sending hidden states serialized with output_schema to avoid double serialization next_tensors = [serialized_outputs] + request.tensors[1:] next_metadata = metadata.copy() next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True) stub = self.get_stub(self._p2p, next_peer_id) await stub.rpc_push( runtime_pb2.ExpertRequest( uid=next_uid, tensors=next_tensors, metadata=MSGPackSerializer.dumps(next_metadata), ), timeout=self.request_timeout, ) except Exception: logger.debug( f"Failed to push outputs to peer_id={next_peer_id}, session_id={next_session_id}, blocks={next_start}:{next_end}:", exc_info=True, ) async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: async with timeout(self.request_timeout): # Parse request and prepare backends flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors] requested_uids = self._check_uids(request.uid) self._log_request("rpc_forward", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_forward should have number of points as number or None, got {points}" hidden_states = await _rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, ) return runtime_pb2.ExpertResponse( tensors=self._serialize_outputs(hidden_states, requested_backends, metadata) ) async def rpc_forward_stream( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext ) -> AsyncIterator[runtime_pb2.ExpertRequest]: async with timeout(self.request_timeout): # Parse requests and prepare backends uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context) requested_uids = self._check_uids(uid_str) self._log_request("rpc_forward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_forward_stream should have number of points as number or None, got {points}" hidden_states = await _rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, ) # Split the serialized_output for streaming and respond to client for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata): for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): yield runtime_pb2.ExpertResponse(tensors=[part]) def _serialize_outputs( self, hidden_states: torch.Tensor, requested_backends: Sequence[TransformerBackend], metadata: Dict[str, Any], ) -> Sequence[runtime_pb2.Tensor]: """Serialize forward outputs using either outputs_schema or custom user-specified schema""" assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor" outputs_schema = requested_backends[-1].outputs_schema if metadata.get("output_compression") is not None: assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list" output_compression = tuple(metadata["output_compression"]) assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers" assert len(output_compression) == 1, f"output_compression tuple should have 1 element" else: output_compression = tuple(tensor.compression for tensor in outputs_schema) return [ serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True) for result, proto, compression in zip([hidden_states], outputs_schema, output_compression) ] async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: async with timeout(self.request_timeout): # Parse requests and prepare backends flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors] requested_uids = self._check_uids(request.uid) self._log_request("rpc_backward", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_backward should have number of points as number or None, got {points}" grads = await _rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, ) return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata)) async def rpc_backward_stream( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext ) -> AsyncIterator[runtime_pb2.ExpertResponse]: async with timeout(self.request_timeout): uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context) requested_uids = self._check_uids(uids_header) self._log_request("rpc_backward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_backward_stream should have number of points as number or None, got {points}" grads = await _rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, ) # Split the serialized_grad_inputs for streaming and respond for tensor in self._serialize_grads(grads, requested_backends, metadata): for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): yield runtime_pb2.ExpertResponse(tensors=[part]) def _get_active_adapter(self, metadata: dict) -> str: active_adapter = metadata.get("active_adapter", "") if active_adapter and (active_adapter not in self.adapters): raise KeyError(f"adapter {active_adapter} not found") return active_adapter def _serialize_grads( self, grads: Sequence[torch.Tensor], requested_backends: Sequence[TransformerBackend], metadata: Dict[str, Any], ) -> Sequence[runtime_pb2.Tensor]: """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema""" # Modify grad_inputs_schema to support grad_prompts assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize flat_grads_schema = tuple( nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema)) ) # TODO generalize if metadata.get("output_compression") is not None: assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list" output_compression = tuple(metadata["output_compression"]) assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers" assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements" else: output_compression = tuple(tensor.compression for tensor in flat_grads_schema) return [ serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True) for result, proto, compression in zip(grads, flat_grads_schema, output_compression) ] def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]: """Check that the first request to rpc_inference is valid""" uids = (uids or "").split(CHAIN_DELIMITER) if not uids: raise RuntimeError("User did not provide any uids") for uid in uids: if uid not in self.module_backends: raise RuntimeError(f"Remote peer does not serve {uid}") return tuple(uids) @contextlib.asynccontextmanager async def _allocate_cache( self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int, timeout: Optional[float], ) -> Sequence[Sequence[Handle]]: """ Allocate memory cache for all transformer blocks, return cache handle :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend """ descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends] async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles: yield nested_pack(handles, descriptors) def _log_request( self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, debug: Optional[str] = None, warning: Optional[str] = None, ) -> None: if uids is not None: friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid] friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()] friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids else: friendly_uids = "n/a" friendly_remote_id = "..." + str(context.remote_id)[-6:] message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})" if warning is not None: logger.warning(f"{message}: {warning}") elif debug is not None: logger.debug(f"{message}: {debug}") else: logger.info(message) async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo: """Return metadata about stored block uids and current load""" backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values())) result = { "version": petals.__version__, "dht_client_mode": self.dht.client_mode, CACHE_TOKENS_AVAILABLE: backend.memory_cache.bytes_left // max(backend.cache_bytes_per_token.values()), } if request.uid: block_info = self.module_backends[request.uid].get_info() common_keys = set(result.keys()) & set(block_info.keys()) if common_keys: raise RuntimeError(f"The block's rpc_info has keys reserved for the server's rpc_info: {common_keys}") result.update(block_info) return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result)) async def _rpc_forward( *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, ) -> torch.Tensor: """ Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy) :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass :returns: hidden states after the last layer [batch_size, seq_length, hid_size] """ hidden_states, prompts = flat_tensors dtype = requested_backends[0].dtype # check parse input tensors and cast dtypes hidden_states = hidden_states.to(dtype) assert hidden_states.ndim == 3 if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] # Run a chain of requested backends for backend, prompt in zip(requested_backends, prompts): if not is_dummy(prompt): hidden_states[:, : prompt.shape[1]] += prompt assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( hidden_states, points=points / len(requested_backends), backend=backend, type="forward" ) (hidden_states,) = await backend.forward_pool.submit_task( hidden_states, active_adapter, priority=priority, ) assert isinstance(hidden_states, torch.Tensor) assert ( hidden_states.ndim == 3 ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" return hidden_states async def _rpc_backward( *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: inputs, grad_outputs, prompts = flat_tensors # Cast inputs & grad outputs to backend dtype inputs = inputs.to(requested_backends[0].dtype) grad_outputs = grad_outputs.to(requested_backends[-1].dtype) if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] # Run a forward chain to collect intermediate inputs # Note that we do not forward for the last module since we do not need its output inter_inputs = [] for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" if not is_dummy(prompt): inputs[:, : prompt.shape[1]] += prompt inter_inputs.append(inputs) assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" ) (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) assert isinstance(inputs, torch.Tensor) if not is_dummy(prompts[-1]): inputs[:, : prompts[-1].shape[1]] += prompts[-1] inter_inputs.append(inputs) assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" grad_prompts_reversed = [] # Run a chain of requested backends for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape