[Refactor] extract block forward, backward and inference into a separate file (#435)

This PR does not change any functionality; it only moves code around. A short call-site sketch follows the list of changes below.
List of changes:

handler.py/_rpc_forward became block_functions.py/run_rpc_forward
handler.py/_rpc_backward became block_functions.py/run_rpc_backward
the math bits of rpc_inference were extracted into block_functions.py/iterate_rpc_inference
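
For orientation, here is a minimal, hypothetical sketch of the new call pattern, condensed from the diff below. The wrapper names `handle_forward`/`handle_inference` and their bare arguments are placeholders for state the real `TransformerConnectionHandler` already holds; only the imported functions come from the new `petals.server.block_functions` module.

```python
# Hypothetical, simplified call sites, not the exact handler code.
from petals.server.block_functions import iterate_rpc_inference, run_rpc_forward

async def handle_forward(flat_inputs, requested_backends, prioritizer, points):
    # was: the inline _rpc_forward helper at the bottom of handler.py
    return await run_rpc_forward(
        *flat_inputs, requested_backends=requested_backends, prioritizer=prioritizer, points=points
    )

async def handle_inference(requested_uids, requested_backends, active_adapter,
                           input_iterator, cache_handles, max_length, prioritizer, points):
    # was: the per-step math inlined in rpc_inference; the handler now only
    # serializes and (optionally) pushes whatever iterate_rpc_inference yields
    async for output_tensors, can_push in iterate_rpc_inference(
        requested_uids=requested_uids,
        requested_backends=requested_backends,
        active_adapter=active_adapter,
        input_iterator=input_iterator,
        cache_handles=cache_handles,
        max_length=max_length,
        prioritizer=prioritizer,
        points=points,
    ):
        yield output_tensors, can_push
```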

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: artek0chumak <artek.chumak@gmail.com>
Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
justheuristic 9 months ago committed by GitHub
parent 593d980ad8
commit ac9b546706

@@ -0,0 +1,195 @@
"""
This module implements server-side computations on served blocks: forward, backward and inference; used by handler
"""
from __future__ import annotations
from typing import AsyncIterator, Optional, Sequence, Tuple, Union
import torch
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.proto import runtime_pb2
from hivemind.utils.nested import nested_flatten
from petals.data_structures import InferenceMetadata
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.misc import DUMMY, is_dummy
async def run_rpc_forward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, prompts = flat_tensors
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
active_adapter,
priority=priority,
)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
return hidden_states
async def run_rpc_backward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
inter_inputs = []
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
grad_prompts_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
async def iterate_rpc_inference(
requested_uids: Sequence[ExpertUID],
requested_backends: Sequence[TransformerBackend],
active_adapter: Optional[str],
input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
cache_handles: Sequence[Sequence[Handle]],
max_length: int,
prioritizer: TaskPrioritizerBase,
points: int,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
assert len(cache_handles) == len(requested_backends)
prefix_length = 0
point_per_piece = points / max_length if max_length > 0 else 0.0
async for request, step_metadata in input_iterator:
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
has_prompts = prompts is not None and not is_dummy(prompts)
if not has_prompts:
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
priority = prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
)
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
else:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
# serialize and send last layer outputs
output_tensors = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
can_push = not has_prompts
yield output_tensors, can_push
# prepare for next step
prefix_length += length_increment

@@ -6,7 +6,7 @@ import multiprocessing as mp
import sys
from enum import Enum
from itertools import chain
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple
import torch
from async_timeout import timeout
@@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
import petals
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.misc import DUMMY, is_dummy
logger = get_logger(__name__)
@@ -147,7 +146,6 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
if not requested_uids:
@@ -163,78 +161,28 @@
f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
)
point_per_piece = points / max_length if max_length > 0 else 0.0
batch_size = request.tensors[0].size[0] if request.tensors else 1
prefix_length = 0
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
first_request = request
background_tasks = set()
async for request, metadata in self._iterate_inference_steps(
first_request, requests, session_id, requested_uids, context
async for output_tensors, can_push in iterate_rpc_inference(
requested_uids=requested_uids,
requested_backends=requested_backends,
active_adapter=self._get_active_adapter(metadata),
input_iterator=self._iterate_inference_steps(
request, requests, session_id, requested_uids, context
),
cache_handles=cache_handles,
max_length=max_length,
prioritizer=self._prioritizer,
points=points,
):
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
has_prompts = prompts is not None and not is_dummy(prompts)
if not has_prompts:
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
priority = self._prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
)
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
else:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
(hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
# serialize and send last layer outputs
output_tensors = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
]
if not has_prompts:
if can_push:
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
background_tasks.add(task) # Keep reference until it is done to save it from GC
task.add_done_callback(background_tasks.discard)
yield runtime_pb2.ExpertResponse(tensors=output_tensors)
# prepare for next step
prefix_length += length_increment
finally:
self._log_request("rpc_inference.close", requested_uids, context)
@@ -408,7 +356,7 @@
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
hidden_states = await run_rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
@@ -435,7 +383,7 @@
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
hidden_states = await run_rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
@@ -486,7 +434,7 @@
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
grads = await _rpc_backward(
grads = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
@@ -511,7 +459,7 @@
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
grads = await _rpc_backward(
grads = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
@@ -621,105 +569,3 @@
result.update(block_info)
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))

async def _rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, prompts = flat_tensors
    dtype = requested_backends[0].dtype
    # check parse input tensors and cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, : prompt.shape[1]] += prompt

        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
            active_adapter,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

    return hidden_states


async def _rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    inputs, grad_outputs, prompts = flat_tensors
    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, : prompt.shape[1]] += prompt
        inter_inputs.append(inputs)

        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
        assert isinstance(inputs, torch.Tensor)

    if not is_dummy(prompts[-1]):
        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []
    # Run a chain of requested backends
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
