diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py
index d14c4a2..e4d36f6 100644
--- a/src/petals/client/inference_session.py
+++ b/src/petals/client/inference_session.py
@@ -7,22 +7,17 @@ import uuid
 from typing import AsyncIterator, List, Optional, Tuple
 
 import torch
-from hivemind import (
-    MSGPackSerializer,
-    anext,
-    deserialize_torch_tensor,
-    get_logger,
-    nested_flatten,
-    serialize_torch_tensor,
-)
+from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
-from petals.utils.misc import DUMMY, is_dummy
+from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
+from petals.utils.packaging import pack_args_kwargs
 
 logger = get_logger(__name__)
 
@@ -128,13 +123,13 @@ class _ServerInferenceSession:
             assert prompts.shape[3] == inputs.shape[2]
 
         if hypo_ids is None or is_dummy(hypo_ids):
-            hypo_ids = DUMMY
+            hypo_ids = DUMMY_INT64
         else:
             assert len(hypo_ids) == len(inputs)
             assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
-        input_tensors = (inputs, prompts, hypo_ids)
+        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
 
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
@@ -144,13 +139,25 @@ class _ServerInferenceSession:
             if next_servers:
                 request_metadata["next_servers"] = next_servers
 
+        request_metadata["args_structure"] = args_structure
+
+        # TODO: make possible to use different compression method for different tensors
+        server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
+        compression = server_side_inference_schema[0].compression
+        inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
+
+        # TODO: create more explicit way to check servers schema and client's structure
+        assert len(input_tensors) >= len(
+            server_side_inference_schema
+        ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
+
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
+                        for tensor, proto in zip(input_tensors, inference_schema)
                     ],
                     metadata=MSGPackSerializer.dumps(request_metadata),
                 )
diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py
index a116822..c7cb7c2 100644
--- a/src/petals/client/remote_forward_backward.py
+++ b/src/petals/client/remote_forward_backward.py
@@ -12,6 +12,7 @@ from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_U
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.data_structures import ModuleUID, RPCInfo
@@ -84,26 +85,20 @@ async def run_remote_forward(
     kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
 
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
-    forward_inputs = (inputs, kwargs)
-
-    # Modify forward_schema to support prompts
+    forward_inputs = tuple(nested_flatten((inputs, kwargs)))
     args_schema, kwargs_schema = rpc_info["forward_schema"]
-    # TODO: rm this assert when support arbitrary number of input tensors
-    assert len(args_schema) == 1 and len(inputs) == 2
-    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
-
-    if not nested_compare(forward_inputs, forward_schema_with_prompts):
-        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
-
-    forward_inputs = nested_flatten(forward_inputs)
+    compression = args_schema[0].compression
+    forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
+    # TODO: create more explicit way to check servers schema and client's structure
+    assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
+            for tensor, proto in zip(inputs, forward_schema)
         )
     )
 
@@ -119,9 +114,7 @@ async def run_remote_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    inputs: torch.Tensor,
-    grad_outputs: List[torch.Tensor],
-    *extra_tensors: torch.Tensor,
+    *inputs_and_grad_outputs: torch.Tensor,
     config: SequenceManagerConfig,
     metadata: Optional[bytes] = None,
     **kwargs,
@@ -131,16 +124,14 @@ async def run_remote_backward(
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
-
-    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
-
-    # Modify forward_schema to support prompts
     args_schema, kwargs_schema = rpc_info["forward_schema"]
-    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
-    # TODO generalize this
-    prompts_schema = next(iter(args_schema))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
+    outputs_schema = rpc_info["outputs_schema"]
+    compression = args_schema[0].compression
+    backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)
+    # TODO: create more explicit way to check servers schema and client's structure
+    assert (
+        len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1
+    ), "Inputs, grad_outputs and prompt tensors are necessary for a backward step"
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py
index bd15c2b..0c97bb2 100644
--- a/src/petals/client/routing/sequence_manager.py
+++ b/src/petals/client/routing/sequence_manager.py
@@ -487,14 +487,21 @@ class RemoteSequenceManager:
             return 0
         return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
 
-    def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
+    def get_request_metadata(
+        self, protocol: str, args_structure: Any = None, *args, **kwargs
+    ) -> Optional[Dict[str, Any]]:
         """
         :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
+        :param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging
         :param args: request-specific inputs, typically block uids and input tensors
        :param kwargs: additional request context, such as remote peer ID
         :returns: msgpack-serialized metadata dict that will be passed alongside a given request
         """
-        return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)
+        return dict(
+            points=self.policy.get_points(protocol, *args, **kwargs),
+            active_adapter=self.config.active_adapter,
+            args_structure=args_structure,
+        )
 
     def shutdown(self):
         self._thread.shutdown()
diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py
index ebc56b4..7996ff5 100644
--- a/src/petals/client/sequential_autograd.py
+++ b/src/petals/client/sequential_autograd.py
@@ -16,6 +16,7 @@ from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_
 from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
+from petals.utils.packaging import pack_args_kwargs
 
 logger = get_logger(__name__)
 
@@ -67,15 +68,17 @@ async def sequential_forward(
             span = sequences.popleft()
 
             stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
-            inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+            flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
 
             span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-            metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
+            metadata = sequence_manager.get_request_metadata(
+                "rpc_forward", args_structure, span_uids, *flat_tensors
+            )
             (outputs,) = await run_remote_forward(
                 span_uids,
                 stub,
                 sequence_manager.rpc_info,
-                *inputs_and_prompts,
+                *flat_tensors,
                 config=sequence_manager.config,
                 metadata=MSGPackSerializer.dumps(metadata),
             )
@@ -149,18 +152,21 @@ async def sequential_backward(
             inputs = intermediate_inputs.pop()
             span = forward_sequences.pop()
 
+            grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
+            flat_tensors, args_structure = pack_args_kwargs(
+                inputs, *grad_outputs_cpu, prompts[span.start : span.end]
+            )
+
             span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
             metadata = sequence_manager.get_request_metadata(
-                "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
+                "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
             )
             grad_outputs, *span_grad_prompts = await run_remote_backward(
                 span_uids,
                 stub,
                 sequence_manager.rpc_info,
-                inputs,
-                grad_outputs,
-                prompts[span.start : span.end],
+                *flat_tensors,
                 config=sequence_manager.config,
                 metadata=MSGPackSerializer.dumps(metadata),
             )
diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py
index c1f1d93..f3e512f 100644
--- a/src/petals/server/block_functions.py
+++ b/src/petals/server/block_functions.py
@@ -3,12 +3,13 @@ This module implements server-side computations on served blocks: forward, backw
 """
 from __future__ import annotations
 
-from typing import AsyncIterator, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.expert_uid import ExpertUID
 from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_flatten
 
 from petals.data_structures import InferenceMetadata
@@ -18,6 +19,7 @@ from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
+from petals.utils.packaging import unpack_args_kwargs
 
 # We prioritize short inference requests and make them use a *merged* inference pool,
 # so they are processed without interruptions and extra overheads
@@ -25,6 +27,8 @@ from petals.utils.misc import DUMMY, is_dummy
 MAX_SHORT_INFERENCE_TOKENS = 128
 MAX_NF4_SHORT_INFERENCE_TOKENS = 1
 
+logger = get_logger(__name__)
+
 
 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
@@ -32,6 +36,7 @@ async def run_rpc_forward(
     requested_backends: Sequence[TransformerBackend],
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
+    args_structure: Any = None,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -41,7 +46,11 @@ async def run_rpc_forward(
     :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, prompts = flat_tensors
+    if args_structure is not None:
+        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
+        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+    hidden_states, prompts, *_ = flat_tensors
+
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
@@ -79,8 +88,13 @@ async def run_rpc_backward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
+    args_structure: Any = None,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, prompts = flat_tensors
+    if args_structure is not None:
+        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
+        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+    inputs, grad_outputs, prompts, *_ = flat_tensors
+
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
@@ -139,6 +153,7 @@ async def iterate_rpc_inference(
     prioritizer: TaskPrioritizerBase,
     points: int,
     quant_type: QuantType,
+    args_structure: Any = None,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -146,7 +161,12 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0
 
     async for request, step_metadata in input_iterator:
-        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+        flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
+        if args_structure is not None:
+            # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
+            flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+
+        hidden_states, prompts, hypo_ids, *_ = flat_tensors
         batch_size, length_increment, _ = hidden_states.shape
 
         # Cast inputs to backend dtype
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 00df0d5..c4db8ef 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -151,6 +151,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             max_length = metadata.get("max_length")
             points = metadata.get("points", 0)
             session_id = metadata.get("session_id")
+            args_structure = metadata.get("args_structure")
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
             assert isinstance(
@@ -180,6 +181,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     prioritizer=self._prioritizer,
                     points=points,
                    quant_type=self.quant_type,
+                    args_structure=args_structure,
                ):
                    if can_push:
                        task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
@@ -356,6 +358,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         active_adapter = self._get_active_adapter(metadata)
         points = metadata.get("points", 0)
+        args_structure = metadata.get("args_structure")
         assert isinstance(
             points, (float, int)
         ), f"rpc_forward should have number of points as number or None, got {points}"
@@ -366,6 +369,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             prioritizer=self._prioritizer,
             active_adapter=active_adapter,
             points=points,
+            args_structure=args_structure,
         )
         return runtime_pb2.ExpertResponse(
             tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
         )
@@ -383,6 +387,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         active_adapter = self._get_active_adapter(metadata)
         points = metadata.get("points", 0)
+        args_structure = metadata.get("args_structure")
         assert isinstance(
             points, (float, int)
         ), f"rpc_forward_stream should have number of points as number or None, got {points}"
@@ -393,6 +398,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             prioritizer=self._prioritizer,
             active_adapter=active_adapter,
             points=points,
+            args_structure=args_structure,
         )
 
         # Split the serialized_output for streaming and respond to client
@@ -434,6 +440,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         active_adapter = self._get_active_adapter(metadata)
         points = metadata.get("points", 0)
+        args_structure = metadata.get("args_structure")
         assert isinstance(
             points, (float, int)
         ), f"rpc_backward should have number of points as number or None, got {points}"
@@ -444,6 +451,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             prioritizer=self._prioritizer,
             active_adapter=active_adapter,
             points=points,
+            args_structure=args_structure,
         )
 
         return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@@ -459,6 +467,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         active_adapter = self._get_active_adapter(metadata)
         points = metadata.get("points", 0)
+        args_structure = metadata.get("args_structure")
         assert isinstance(
             points, (float, int)
         ), f"rpc_backward_stream should have number of points as number or None, got {points}"
@@ -469,6 +478,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             prioritizer=self._prioritizer,
             active_adapter=active_adapter,
             points=points,
+            args_structure=args_structure,
         )
         # Split the serialized_grad_inputs for streaming and respond
         for tensor in self._serialize_grads(grads, requested_backends, metadata):
diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py
index 2f67202..d8068e1 100644
--- a/src/petals/utils/misc.py
+++ b/src/petals/utils/misc.py
@@ -2,6 +2,8 @@ import torch
 
 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
 
+DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
+
 
 def is_dummy(tensor: torch.Tensor):
     return tensor.numel() == 0
diff --git a/src/petals/utils/packaging.py b/src/petals/utils/packaging.py
new file mode 100644
index 0000000..c6d9faa
--- /dev/null
+++ b/src/petals/utils/packaging.py
@@ -0,0 +1,49 @@
+from typing import Any, Dict, List, Tuple
+
+import torch
+from hivemind import nested_flatten, nested_pack
+
+# TODO: Move functions to hivemind
+
+
+def _mark_masked_tensor(index: int) -> bytes:
+    return b"__T" + str(index).encode()
+
+
+def _is_masked_tensor(item: Any) -> bool:
+    return isinstance(item, bytes) and item.startswith(b"__T")
+
+
+def _get_tensor_index(item: bytes) -> int:
+    return int(item[3:])
+
+
+def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
+    """
+    Check the function's arguments and pack all tensors into different flattened lists.
+    :returns: a flattened list of tensors and args and kwargs, where tensors were masked
+    """
+    masked_flat_values, flat_tensors, tensor_to_index = [], [], {}
+    for value in nested_flatten((args, kwargs)):
+        if isinstance(value, torch.Tensor):
+            tensor_index = tensor_to_index.setdefault(value, len(flat_tensors))
+            if tensor_index == len(flat_tensors):
+                flat_tensors.append(value)
+            masked_flat_values.append(_mark_masked_tensor(tensor_index))
+        else:
+            masked_flat_values.append(value)
+    return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
+
+
+def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
+    """
+    Restore arguments after `pack_args_kwargs` function.
+    :returns: list of args and dict of kwargs
+    """
+    return nested_pack(
+        (
+            value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)]
+            for value in nested_flatten(args_structure)
+        ),
+        args_structure,
+    )
diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py
index f75281e..378d25e 100644
--- a/tests/test_aux_functions.py
+++ b/tests/test_aux_functions.py
@@ -3,10 +3,13 @@ import sys
 
 import pytest
 import torch
+from hivemind import nested_compare, nested_flatten
 
 from petals import AutoDistributedConfig
 from petals.server.throughput import measure_compute_rps
 from petals.utils.convert_block import QuantType
+from petals.utils.misc import DUMMY, is_dummy
+from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
 
 from test_utils import MODEL_NAME
 
@@ -44,3 +47,29 @@ def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: boo
         inference=inference,
     )
     assert isinstance(compute_rps, float) and compute_rps > 0
+
+
+@pytest.mark.forked
+def test_pack_inputs():
+    x = torch.ones(3)
+    y = torch.arange(5)
+    z = DUMMY
+
+    args = (x, z, None, (y, y), z)
+    kwargs = dict(foo=torch.zeros(1, 1), bar={"l": "i", "g": "h", "t": ("y", "e", "a", "r", torch.rand(1), x, y)})
+
+    flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
+
+    assert len(flat_tensors) == 5
+    assert all(isinstance(t, torch.Tensor) for t in flat_tensors)
+
+    restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+
+    assert len(restored_args) == len(args)
+    assert torch.all(restored_args[0] == x).item() and restored_args[2] is None
+    assert nested_compare((args, kwargs), (restored_args, restored_kwargs))
+    for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))):
+        if isinstance(original, torch.Tensor):
+            assert torch.all(original == restored)
+        else:
+            assert original == restored
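
A quick illustration of the new packaging helpers, not part of the patch itself: the minimal sketch below (assuming this patch is applied and petals is importable; tensor shapes are made up) mirrors what the client and server now do — the client flattens its arguments with pack_args_kwargs and ships args_structure inside the msgpack request metadata, while the server rebuilds the original argument layout with unpack_args_kwargs.

import torch

from petals.utils.misc import DUMMY
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

# Client side: flatten (args, kwargs) into a flat list of tensors plus a structure object.
# Only the tensors are serialized; the structure travels in request_metadata["args_structure"].
hidden_states = torch.randn(1, 8, 4096)  # hypothetical [batch_size, seq_length, hid_size]
prompts = DUMMY  # empty tensor stands in for "no prompts"
hypo_ids = torch.arange(1, dtype=torch.int64)

flat_tensors, args_structure = pack_args_kwargs(hidden_states, prompts, hypo_ids)
assert len(flat_tensors) == 3  # deduplicated tensors, in order of first appearance

# Server side: restore the original call layout from the flat tensor list and the structure.
(restored_hidden, restored_prompts, restored_hypo_ids), restored_kwargs = unpack_args_kwargs(
    flat_tensors, args_structure
)
assert torch.equal(restored_hidden, hidden_states) and restored_kwargs == {}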