From ac9b5467067735a885f7a4dfad689a2a2dc7f594 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Mon, 7 Aug 2023 14:32:51 +0300
Subject: [PATCH] [Refactor] extract block forward, backward and inference into a separate file (#435)

This PR does not change any functionality. It merely moves stuff around.

List of changes:

handler.py/_rpc_forward became block_functions.py/run_rpc_forward
handler.py/_rpc_backward became block_functions.py/run_rpc_backward
the math bits of rpc_inference were extracted into block_functions.py/iterate_rpc_inference

---------

Co-authored-by: Your Name
Co-authored-by: artek0chumak
Co-authored-by: Aleksandr Borzunov
---
 src/petals/server/block_functions.py | 195 +++++++++++++++++++++++++++
 src/petals/server/handler.py         | 192 +++----------------------
 2 files changed, 214 insertions(+), 173 deletions(-)
 create mode 100644 src/petals/server/block_functions.py

diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py
new file mode 100644
index 0000000..9208deb
--- /dev/null
+++ b/src/petals/server/block_functions.py
@@ -0,0 +1,195 @@
+"""
+This module implements server-side computations on served blocks: forward, backward and inference; used by handler
+"""
+from __future__ import annotations
+
+from typing import AsyncIterator, Optional, Sequence, Tuple, Union
+
+import torch
+from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.moe.expert_uid import ExpertUID
+from hivemind.proto import runtime_pb2
+from hivemind.utils.nested import nested_flatten
+
+from petals.data_structures import InferenceMetadata
+from petals.server.backend import TransformerBackend
+from petals.server.memory_cache import Handle
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.server.task_prioritizer import TaskPrioritizerBase
+from petals.utils.misc import DUMMY, is_dummy
+
+
+async def run_rpc_forward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check parse input tensors and cast dtypes
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    if prompts is None or is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, : prompt.shape[1]] += prompt
+
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states,
+            active_adapter,
+            priority=priority,
+        )
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+
+    return hidden_states
+
+
+async def run_rpc_backward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+    inputs, grad_outputs, prompts = flat_tensors
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+
+    if prompts is None or is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = []
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        if not is_dummy(prompt):
+            inputs[:, : prompt.shape[1]] += prompt
+        inter_inputs.append(inputs)
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
+
+        assert isinstance(inputs, torch.Tensor)
+
+    if not is_dummy(prompts[-1]):
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
+    # Run a chain of requested backends
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+        )
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
+
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
+
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
+
+
+async def iterate_rpc_inference(
+    requested_uids: Sequence[ExpertUID],
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: Optional[str],
+    input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
+    cache_handles: Sequence[Sequence[Handle]],
+    max_length: int,
+    prioritizer: TaskPrioritizerBase,
+    points: int,
+) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+    assert len(cache_handles) == len(requested_backends)
+
+    prefix_length = 0
+    point_per_piece = points / max_length if max_length > 0 else 0.0
+
+    async for request, step_metadata in input_iterator:
+        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+
+        # Cast inputs to backend dtype
+        hidden_states = hidden_states.to(requested_backends[0].dtype)
+        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
+        # parse deep prompts (optional argument)
+        has_prompts = prompts is not None and not is_dummy(prompts)
+        if not has_prompts:
+            prompts = [None] * len(requested_backends)
+        else:
+            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
+
+        if not (len(requested_backends) == len(prompts)):
+            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
+        if prefix_length + length_increment > max_length:
+            raise ValueError(
+                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                f" exceeds pre-allocated maximum {max_length}"
+            )
+
+        priority = prioritizer.prioritize(
+            hidden_states,
+            hypo_ids,
+            points=point_per_piece,
+            requested_uids=requested_uids,
+            type="inference",
+        )
+
+        inference_infos = tuple(
+            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
+            for uid, handles in zip(requested_uids, cache_handles)
+        )
+
+        if hidden_states.numel() == 0:
+            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
+            # when user wants to pre-allocate cache or check that server *can* allocate that cache
+        else:
+            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
+            (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
+                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+            )
+
+        # serialize and send last layer outputs
+        output_tensors = [
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+        ]
+        can_push = not has_prompts
+        yield output_tensors, can_push
+
+        # prepare for next step
+        prefix_length += length_increment

diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index d3776de..b9be294 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -6,7 +6,7 @@ import multiprocessing as mp
 import sys
 from enum import Enum
 from itertools import chain
-from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple
 
 import torch
 from async_timeout import timeout
@@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 import petals
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
 from petals.server.backend import TransformerBackend
+from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
 from petals.server.memory_cache import Handle
-from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
-from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__name__)
 
@@ -147,7 +146,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                 metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
-                active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
                 if not requested_uids:
@@ -163,78 +161,28 @@ class TransformerConnectionHandler(ConnectionHandler):
                         f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
                     )
 
-                point_per_piece = points / max_length if max_length > 0 else 0.0
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
-                prefix_length = 0
 
                 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
-                    assert len(cache_handles) == len(requested_backends)
-
-                    first_request = request
                     background_tasks = set()
-                    async for request, metadata in self._iterate_inference_steps(
-                        first_request, requests, session_id, requested_uids, context
+                    async for output_tensors, can_push in iterate_rpc_inference(
+                        requested_uids=requested_uids,
+                        requested_backends=requested_backends,
+                        active_adapter=self._get_active_adapter(metadata),
+                        input_iterator=self._iterate_inference_steps(
+                            request, requests, session_id, requested_uids, context
+                        ),
+                        cache_handles=cache_handles,
+                        max_length=max_length,
+                        prioritizer=self._prioritizer,
+                        points=points,
                     ):
-                        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
-
-                        # Cast inputs to backend dtype
-                        hidden_states = hidden_states.to(requested_backends[0].dtype)
-                        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
-
-                        # parse deep prompts (optional argument)
-                        has_prompts = prompts is not None and not is_dummy(prompts)
-                        if not has_prompts:
-                            prompts = [None] * len(requested_backends)
-                        else:
-                            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-                            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
-
-                        if not (len(requested_backends) == len(prompts)):
-                            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
-
-                        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
-                        if prefix_length + length_increment > max_length:
-                            raise ValueError(
-                                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
-                                f" exceeds pre-allocated maximum {max_length}"
-                            )
-
-                        priority = self._prioritizer.prioritize(
-                            hidden_states,
-                            hypo_ids,
-                            points=point_per_piece,
-                            requested_uids=requested_uids,
-                            type="inference",
-                        )
-
-                        inference_infos = tuple(
-                            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
-                            for uid, handles in zip(requested_uids, cache_handles)
-                        )
-
-                        if hidden_states.numel() == 0:
-                            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-                            # when user wants to pre-allocate cache or check that server *can* allocate that cache
-                        else:
-                            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
-                            (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
-                                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
-                            )
-
-                        # serialize and send last layer outputs
-                        output_tensors = [
-                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                            for result, proto in zip(
-                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
-                            )
-                        ]
-                        if not has_prompts:
+                        if can_push:
                             task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
                             background_tasks.add(task)  # Keep reference until it is done to save it from GC
                             task.add_done_callback(background_tasks.discard)
                         yield runtime_pb2.ExpertResponse(tensors=output_tensors)
 
-                        # prepare for next step
-                        prefix_length += length_increment
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
@@ -408,7 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -435,7 +383,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
            ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -486,7 +434,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward should have number of points as number or None, got {points}"
 
-            grads = await _rpc_backward(
+            grads = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -511,7 +459,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
-            grads = await _rpc_backward(
+            grads = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -621,105 +569,3 @@ class TransformerConnectionHandler(ConnectionHandler):
             result.update(block_info)
 
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
-
-
-async def _rpc_forward(
-    *flat_tensors: torch.Tensor,
-    requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = "",
-    prioritizer: TaskPrioritizerBase,
-    points: int = 0,
-) -> torch.Tensor:
-    """
-    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
-
-    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
-    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
-    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
-    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
-    """
-    hidden_states, prompts = flat_tensors
-    dtype = requested_backends[0].dtype
-    # check parse input tensors and cast dtypes
-    hidden_states = hidden_states.to(dtype)
-    assert hidden_states.ndim == 3
-    if prompts is None or is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-    # Run a chain of requested backends
-    for backend, prompt in zip(requested_backends, prompts):
-        if not is_dummy(prompt):
-            hidden_states[:, : prompt.shape[1]] += prompt
-
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
-        )
-        (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states,
-            active_adapter,
-            priority=priority,
-        )
-        assert isinstance(hidden_states, torch.Tensor)
-        assert (
-            hidden_states.ndim == 3
-        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-    return hidden_states
-
-
-async def _rpc_backward(
-    *flat_tensors: torch.Tensor,
-    requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = "",
-    prioritizer: TaskPrioritizerBase,
-    points: int = 0,
-) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, prompts = flat_tensors
-    # Cast inputs & grad outputs to backend dtype
-    inputs = inputs.to(requested_backends[0].dtype)
-    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-
-    if prompts is None or is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-    # Run a forward chain to collect intermediate inputs
-    # Note that we do not forward for the last module since we do not need its output
-    inter_inputs = []
-    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
-        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-        if not is_dummy(prompt):
-            inputs[:, : prompt.shape[1]] += prompt
-        inter_inputs.append(inputs)
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
-        )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
-
-        assert isinstance(inputs, torch.Tensor)
-
-    if not is_dummy(prompts[-1]):
-        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
-    inter_inputs.append(inputs)
-
-    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
-    grad_prompts_reversed = []
-    # Run a chain of requested backends
-    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
-        )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
-
-        assert isinstance(grad_outputs, torch.Tensor)
-        if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
-
-    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
-    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
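
Note on the extracted helpers: run_rpc_forward, run_rpc_backward and iterate_rpc_inference all delegate scheduling to the prioritizer.prioritize(...) calls shown in block_functions.py above. The sketch below illustrates that contract; the class name and the FIFO policy are invented for this example and are not part of the patch. Only TaskPrioritizerBase and the call pattern come from the code above.

import torch

from petals.server.task_prioritizer import TaskPrioritizerBase


class FIFOTaskPrioritizer(TaskPrioritizerBase):
    """Toy policy: hand out monotonically increasing priorities (lower value = served earlier)."""

    def __init__(self) -> None:
        self._counter = 0

    def prioritize(self, *input_tensors: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        # block_functions.py also passes backend=..., type="forward"/"backward"/"inference",
        # requested_uids=... and, for inference, hypo_ids; this toy policy accepts them via
        # *input_tensors / **kwargs and simply ignores them.
        self._counter += 1
        return float(self._counter)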
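
Because the forward path now lives in a standalone module, it can be exercised without starting a server. The following is a hypothetical smoke test, assuming a petals version that includes this patch; the fake backend is a mock invented for the example, not a real TransformerBackend.

import asyncio
from unittest.mock import AsyncMock, MagicMock

import torch

from petals.server.block_functions import run_rpc_forward
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import DummyTaskPrioritizer
from petals.utils.misc import DUMMY


def make_fake_backend() -> MagicMock:
    backend = MagicMock()
    backend.dtype = torch.float32
    # run_rpc_forward asserts that inference_pool is a PrioritizedTaskPool ...
    backend.inference_pool = MagicMock(spec=PrioritizedTaskPool)
    # ... and awaits forward_pool.submit_task(...), expecting a 1-tuple of hidden states back
    backend.forward_pool.submit_task = AsyncMock(
        side_effect=lambda hidden_states, *args, **kwargs: (hidden_states + 1,)
    )
    return backend


async def main() -> None:
    backends = [make_fake_backend(), make_fake_backend()]
    hidden_states = torch.zeros(1, 4, 16)  # [batch_size, seq_length, hid_size]
    outputs = await run_rpc_forward(
        hidden_states,
        DUMMY,  # no deep prompts
        requested_backends=backends,
        prioritizer=DummyTaskPrioritizer(),
        points=0,
    )
    assert outputs.shape == hidden_states.shape
    print(outputs[0, 0, :4])  # each fake block added 1, so this should print tensor([2., 2., 2., 2.])


if __name__ == "__main__":
    asyncio.run(main())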