From 43ac6016accd2c397465880fbeab5f5d83ae4429 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 30 Nov 2022 18:40:43 +0400 Subject: [PATCH] Fix dtypes in backend schemas (#99) Currently, the schemas use `torch.float32`, so all inputs and outputs are converted to float32 before sending and after receiving on both servers and clients. This creates a huge slowdown for the system. This PR: * makes the schemas use the server's `--torch_dtype` argument (the default is `torch.bfloat16` for BLOOM-176B) * adds an option for the client to request a specific output compression. Use case 1: the client sends quantized inputs and expects quantized outputs in return. Use case 2: the client uses quantization for gradients w.r.t. activations, but keeps grads w.r.t. __prompts__ as is for greater precision. * adds a comment explaining the purpose of NoSpendingPolicy, since we likely won't have it for the workshop * adds a test with custom compression (a janky implementation for testing purposes) Co-authored-by: justheuristic --- src/petals/client/remote_forward_backward.py | 25 ++-- src/petals/client/sequence_manager.py | 11 ++ src/petals/client/sequential_autograd.py | 12 +- src/petals/client/spending_policy.py | 11 +- src/petals/server/backend.py | 6 +- src/petals/server/handler.py | 118 ++++++++++--------- src/petals/server/server.py | 8 +- tests/test_remote_sequential.py | 42 ++++++- 8 files changed, 150 insertions(+), 83 deletions(-) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index 24ef895..542ad9c 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -2,7 +2,7 @@ Utility functions that call RPC forward or backward on a single remote server """ import asyncio -from typing import Iterable, List, Sequence, Tuple +from typing import Iterable, List, Optional, Sequence, Tuple import torch from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor @@ -63,7 +63,13 @@ async def _backward_stream( async def run_remote_forward( - uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs + uid: ModuleUID, + stub: StubBase, + rpc_info: RPCInfo, + *inputs: torch.Tensor, + timeout: float, + metadata: Optional[bytes] = None, + **kwargs, ) -> Tuple[torch.Tensor, ...]: """ Serializes input tensors and calls "rpc_forward" on a remote server.
@@ -102,11 +108,8 @@ async def run_remote_forward( # call RPC on remote server size = sum(t.element_size() * t.nelement() for t in inputs) - if size > MAX_UNARY_PAYLOAD_SIZE: - deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs) - else: - deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs) - + forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary + deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) @@ -118,6 +121,7 @@ async def run_remote_backward( grad_outputs: List[torch.Tensor], *extra_tensors: torch.Tensor, timeout: float, + metadata: Optional[bytes] = None, **kwargs, ) -> Sequence[torch.Tensor]: """ @@ -146,9 +150,6 @@ async def run_remote_backward( ) size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) - if size > MAX_UNARY_PAYLOAD_SIZE: - deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs) - else: - deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs) - + backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary + deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) return deserialized_grad_inputs diff --git a/src/petals/client/sequence_manager.py b/src/petals/client/sequence_manager.py index e36a557..7522995 100644 --- a/src/petals/client/sequence_manager.py +++ b/src/petals/client/sequence_manager.py @@ -10,6 +10,7 @@ from hivemind.proto import runtime_pb2 from hivemind.utils.logging import get_logger, use_hivemind_log_handler import petals.dht_utils +from petals.client.spending_policy import NoSpendingPolicy from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState from petals.server.handler import TransformerConnectionHandler @@ -43,6 +44,7 @@ class RemoteSequenceManager: self.timeout, self.min_backoff = timeout, min_backoff self._rpc_info = None self.lock_changes = threading.Lock() + self.policy = NoSpendingPolicy() self.update_() for uid, info in zip(self.block_uids, self.block_infos): @@ -166,3 +168,12 @@ class RemoteSequenceManager: if attempt_no == 0: return 0 return self.min_backoff * 2 ** (attempt_no - 1) + + def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[bytes]: + """ + :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" + :param args: request-specific inputs, typically block uids and input tensors + :param kwargs: additional request context, such as remote peer ID + :returns: msgpack-serialized metadata dict that will be passed alongside a given request + """ + return MSGPackSerializer.dumps(dict(points=self.policy.get_points(protocol, *args, **kwargs))) diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 8fcbffd..7dc7116 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -72,8 +72,14 @@ async def sequential_forward( inputs_and_prompts = [inputs, prompts[span.start : span.end]] span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) + metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts) (outputs,) = await run_remote_forward( - span_uids, stub, sequence_manager.rpc_info,
*inputs_and_prompts, timeout=sequence_manager.timeout + span_uids, + stub, + sequence_manager.rpc_info, + *inputs_and_prompts, + timeout=sequence_manager.timeout, + metadata=metadata, ) assert isinstance(outputs, torch.Tensor) @@ -146,6 +152,9 @@ async def sequential_backward( span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id) + metadata = sequence_manager.get_request_metadata( + "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id + ) grad_outputs, *span_grad_prompts = await run_remote_backward( span_uids, stub, @@ -154,6 +163,7 @@ async def sequential_backward( grad_outputs, prompts[span.start : span.end], timeout=sequence_manager.timeout, + metadata=metadata, ) grad_outputs = [grad_outputs] grad_prompts_reversed.extend(span_grad_prompts) diff --git a/src/petals/client/spending_policy.py b/src/petals/client/spending_policy.py index 770d25a..0af3db7 100644 --- a/src/petals/client/spending_policy.py +++ b/src/petals/client/spending_policy.py @@ -1,14 +1,17 @@ +""" +An interface for exchanging internal "BLOOM points" for higher priority compute requests. NOT IMPLEMENTED. +The intent is to let Petals participants earn points by helping others while idle (e.g. at night), then use these + points to run their own compute experiments faster. See Section 4 of https://arxiv.org/abs/2209.01188 for discussion. +""" from abc import ABC, abstractmethod -from hivemind.proto.runtime_pb2 import ExpertRequest - class SpendingPolicyBase(ABC): @abstractmethod - def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float: + def get_points(self, protocol: str, *args, **kwargs) -> float: pass class NoSpendingPolicy(SpendingPolicyBase): - def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float: + def get_points(self, protocol: str, *args, **kwargs) -> float: return 0.0 diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 7b62301..2f7ace9 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -18,7 +18,7 @@ logger = get_logger(__file__) class TransformerBackend(ModuleBackend): """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward""" - def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs): + def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs): super().__init__(*args, **kwargs) assert isinstance(self.module, BloomBlock) self.memory_cache = memory_cache @@ -37,7 +37,9 @@ class TransformerBackend(ModuleBackend): self.backward_pool = PrioritizedTaskPool( self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward" ) - self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype + + assert backend_dtype is not None + self.dtype = backend_dtype self.inference_schema = ( ( *self.args_schema, diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 0d34e7f..3c57fff 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -1,6 +1,6 @@ import asyncio import contextlib -from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union +from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union import torch from async_timeout import timeout @@ -202,14 +202,8 @@ class 
TransformerConnectionHandler(ConnectionHandler): hidden_states = await _rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points ) - assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3 - - # Serialize output and respond to client return runtime_pb2.ExpertResponse( - tensors=[ - serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) - for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)) - ] + tensors=self._serialize_outputs(hidden_states, requested_backends, metadata) ) async def rpc_forward_stream( @@ -230,22 +224,34 @@ class TransformerConnectionHandler(ConnectionHandler): hidden_states = await _rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points ) - assert ( - isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3 - ), "hidden_states must be a 3d tensor" - - # Serialize the overall output - serialized_output = [ - serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) - for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)) - ] # Split the serialized_output for streaming and respond to client - output_split = [ - part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) - ] - async for part in as_aiter(*output_split): - yield runtime_pb2.ExpertResponse(tensors=[part]) + for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata): + for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): + yield runtime_pb2.ExpertResponse(tensors=[part]) + + def _serialize_outputs( + self, + hidden_states: torch.Tensor, + requested_backends: Sequence[TransformerBackend], + metadata: Dict[str, Any], + ) -> Sequence[runtime_pb2.Tensor]: + """Serialize forward outputs using either outputs_schema or custom user-specified schema""" + assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor" + outputs_schema = requested_backends[-1].outputs_schema + + if metadata.get("output_compression") is not None: + assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list" + output_compression = tuple(metadata["output_compression"]) + assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers" + assert len(output_compression) == 1, f"output_compression tuple should have 1 element" + else: + output_compression = tuple(tensor.compression for tensor in outputs_schema) + + return [ + serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True) + for result, proto, compression in zip([hidden_states], outputs_schema, output_compression) + ] async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: async with timeout(self.request_timeout): @@ -265,21 +271,7 @@ class TransformerConnectionHandler(ConnectionHandler): *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points ) - # Modify grad_inputs_schema to support grad_prompts - assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize - - grad_inputs_schema_with_prompts = ( - requested_backends[0].args_schema * len(grads), - requested_backends[0].kwargs_schema, - ) # TODO generalize - - # Serialize the overall grad_input 
and respond - return runtime_pb2.ExpertResponse( - tensors=[ - serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) - for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts)) - ] - ) + return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata)) async def rpc_backward_stream( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext @@ -298,28 +290,38 @@ class TransformerConnectionHandler(ConnectionHandler): grads = await _rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points ) - - # Modify grad_inputs_schema to support grad_prompts - assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize - grad_inputs_schema_with_prompts = ( - requested_backends[0].args_schema * len(grads), - requested_backends[0].kwargs_schema, - ) # TODO generalize - - # Serialize the overall grad_inputs - serialized_grad_inputs = [ - serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) - for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts)) - ] # Split the serialized_grad_inputs for streaming and respond - output_split = [ - part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) - ] + for tensor in self._serialize_grads(grads, requested_backends, metadata): + for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): + yield runtime_pb2.ExpertResponse(tensors=[part]) - async for part in as_aiter(*output_split): - yield runtime_pb2.ExpertResponse(tensors=[part]) - - def _check_uids(self, uids: str) -> Sequence[ModuleUID]: + def _serialize_grads( + self, + grads: Sequence[torch.Tensor], + requested_backends: Sequence[TransformerBackend], + metadata: Dict[str, Any], + ) -> Sequence[runtime_pb2.Tensor]: + """Serialize backward gradients w.r.t. 
inputs using either default schema or custom user-specified schema""" + # Modify grad_inputs_schema to support grad_prompts + assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize + flat_grads_schema = tuple( + nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema)) + ) # TODO generalize + + if metadata.get("output_compression") is not None: + assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list" + output_compression = tuple(metadata["output_compression"]) + assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers" + assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements" + else: + output_compression = tuple(tensor.compression for tensor in flat_grads_schema) + + return [ + serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True) + for result, proto, compression in zip(grads, flat_grads_schema, output_compression) + ] + + def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]: """Check that the first request to rpc_inference is valid""" uids = (uids or "").split(CHAIN_DELIMITER) if not uids: @@ -360,7 +362,7 @@ class TransformerConnectionHandler(ConnectionHandler): yield handles - def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None: + def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None: friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid] friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()] friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 11dd581..f3db2e6 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -286,27 +286,27 @@ class ModuleContainer(threading.Thread): ) if load_in_8bit: - dtype = block.input_layernorm.weight.dtype block = replace_8bit_linear(block) block = block.to(device) for param in block.parameters(): param.requires_grad = False + backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype blocks[module_uid] = TransformerBackend( module_uid, block, memory_cache=memory_cache, - backend_dtype=None if torch_dtype == "auto" else torch_dtype, + backend_dtype=backend_dtype, args_schema=( BatchTensorDescriptor( - 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression + 1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression ), ), kwargs_schema={}, outputs_schema=( BatchTensorDescriptor( - 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression + 1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression ), ), min_batch_size=min_batch_size, diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 9c92f20..4237188 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -1,11 +1,13 @@ import pytest import torch -from hivemind import DHT, get_logger, use_hivemind_log_handler +from hivemind import DHT, BatchTensorDescriptor, MSGPackSerializer, get_logger, use_hivemind_log_handler +from hivemind.proto import runtime_pb2 from test_utils import * from petals.bloom.from_pretrained import load_pretrained_block -from petals.client import RemoteSequential +from petals.client import RemoteSequenceManager, 
RemoteSequential from petals.client.remote_model import DistributedBloomConfig +from petals.data_structures import UID_DELIMITER use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) @@ -43,6 +45,42 @@ def test_remote_sequential(): (second_half_outputs * grad_proj).sum().backward() assert torch.allclose(test_inputs.grad, full_grad) + # test RemoteSequential with lossy compression + block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] + lossy_sequential = RemoteSequential( + config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p) + ) + + test_inputs.grad = None + approx_outputs = lossy_sequential(test_inputs) + (approx_outputs * grad_proj).sum().backward() + + assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used" + assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used" + assert abs(approx_outputs - full_outputs).mean() < 0.01 + assert abs(test_inputs.grad - full_grad).mean() < 0.3 + + +class DummyCustomSequenceManager(RemoteSequenceManager): + """A sequence manager that compresses inputs/outputs during forward and backward pass.""" + + @property + def rpc_info(self): + rpc_info = super().rpc_info + dims = (2048, 1024) + compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16) + rpc_info["forward_schema"] = (compressed_input_schema,), dict() # (args, kwargs) + return rpc_info + + def get_request_metadata(self, protocol: str, *args, **kwargs): + if protocol == "rpc_forward": + return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,))) + elif protocol == "rpc_backward": + return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,))) + else: + assert protocol == "rpc_inference" + return super().get_request_metadata(protocol, *args, **kwargs) + @pytest.mark.forked def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
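
Usage note (not part of the patch): the sketch below shows how a client could opt into the new `output_compression` metadata introduced by this PR, following the same pattern as `DummyCustomSequenceManager` in tests/test_remote_sequential.py. It is a minimal sketch assuming the petals and hivemind packages from this repo are installed; the subclass name `Float16SequenceManager` and the specific compression codecs are illustrative choices, not part of the Petals API.

# A minimal sketch: request float16-compressed activations from rpc_forward and
# 8-bit-compressed gradients from rpc_backward via the metadata mechanism above.
from hivemind import MSGPackSerializer
from hivemind.proto import runtime_pb2

from petals.client import RemoteSequenceManager


class Float16SequenceManager(RemoteSequenceManager):  # hypothetical name, for illustration only
    def get_request_metadata(self, protocol: str, *args, **kwargs):
        if protocol == "rpc_forward":
            # rpc_forward returns a single hidden_states tensor, so exactly one compression entry is expected
            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
        if protocol == "rpc_backward":
            # the server checks len(output_compression) == len(grads); with no trainable prompts
            # there is a single grad tensor, so one entry suffices here
            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,)))
        # fall back to the default metadata (spending-policy points) for rpc_inference
        return super().get_request_metadata(protocol, *args, **kwargs)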