Fix dtypes in backend schemas (#99)

Currently, the schemas use `torch.float32`, so all inputs and outputs are converted to float32 before sending and after receiving on both servers and clients. This creates a huge slowdown for the system.

* This PR makes the schemas use the server's `--torch_dtype` argument (the default is `torch.bfloat16` for BLOOM-176B)
* It adds an option for the client to request a specific output compression (see the sketch after this list). Use case 1: the client sends quantized inputs and expects quantized outputs in return. Use case 2: the client uses quantization for gradients w.r.t. activations but keeps grads w.r.t. __prompts__ as is for greater precision.
* It adds a comment explaining the purpose of NoSpendingPolicy, since we likely won't have it for the workshop
* It adds a test with custom compression (a janky implementation, for testing purposes)
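
As a rough client-side sketch of the new output compression option (it mirrors the test added in this PR; the `CompressionType` values and the msgpack metadata format come from hivemind, and the class name here is just an illustration), a client can subclass `RemoteSequenceManager` and override `get_request_metadata`:

```python
from hivemind import MSGPackSerializer
from hivemind.proto import runtime_pb2

from petals.client import RemoteSequenceManager


class CompressedOutputsSequenceManager(RemoteSequenceManager):
    """Hypothetical example: ask servers to compress the tensors they send back."""

    def get_request_metadata(self, protocol: str, *args, **kwargs):
        if protocol == "rpc_forward":
            # one compression entry per output tensor (forward returns a single hidden_states tensor)
            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
        if protocol == "rpc_backward":
            # one compression entry per returned gradient tensor
            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,)))
        # rpc_inference and anything else fall back to the default (spending-policy) metadata
        return super().get_request_metadata(protocol, *args, **kwargs)
```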

Co-authored-by: justheuristic <justheuristic@gmail.com>

@@ -2,7 +2,7 @@
Utility functions that call RPC forward or backward on a single remote server
"""
import asyncio
from typing import Iterable, List, Sequence, Tuple
from typing import Iterable, List, Optional, Sequence, Tuple
import torch
from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
@@ -63,7 +63,13 @@ async def _backward_stream(
async def run_remote_forward(
uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
*inputs: torch.Tensor,
timeout: float,
metadata: Optional[bytes] = None,
**kwargs,
) -> Tuple[torch.Tensor, ...]:
"""
Serializes input tensors and calls "rpc_forward" on a remote server.
@@ -102,11 +108,8 @@ async def run_remote_forward(
# call RPC on remote server
size = sum(t.element_size() * t.nelement() for t in inputs)
if size > MAX_UNARY_PAYLOAD_SIZE:
deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
else:
deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
@@ -118,6 +121,7 @@ async def run_remote_backward(
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
timeout: float,
metadata: Optional[bytes] = None,
**kwargs,
) -> Sequence[torch.Tensor]:
"""
@@ -146,9 +150,6 @@
)
size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
if size > MAX_UNARY_PAYLOAD_SIZE:
deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
else:
deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary
deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
return deserialized_grad_inputs

@@ -10,6 +10,7 @@ from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
import petals.dht_utils
from petals.client.spending_policy import NoSpendingPolicy
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
@@ -43,6 +44,7 @@ class RemoteSequenceManager:
self.timeout, self.min_backoff = timeout, min_backoff
self._rpc_info = None
self.lock_changes = threading.Lock()
self.policy = NoSpendingPolicy()
self.update_()
for uid, info in zip(self.block_uids, self.block_infos):
@@ -166,3 +168,12 @@ class RemoteSequenceManager:
if attempt_no == 0:
return 0
return self.min_backoff * 2 ** (attempt_no - 1)
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[bytes]:
"""
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args: request-specific inputs, typically block uids and input tensors
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
"""
return MSGPackSerializer.dumps(dict(points=self.policy.get_points(protocol, *args, **kwargs)))

@@ -72,8 +72,14 @@ async def sequential_forward(
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
(outputs,) = await run_remote_forward(
span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
span_uids,
stub,
sequence_manager.rpc_info,
*inputs_and_prompts,
timeout=sequence_manager.timeout,
metadata=metadata,
)
assert isinstance(outputs, torch.Tensor)
@@ -146,6 +152,9 @@ async def sequential_backward(
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
metadata = sequence_manager.get_request_metadata(
"rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
)
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids,
stub,
@@ -154,6 +163,7 @@
grad_outputs,
prompts[span.start : span.end],
timeout=sequence_manager.timeout,
metadata=metadata,
)
grad_outputs = [grad_outputs]
grad_prompts_reversed.extend(span_grad_prompts)

@@ -1,14 +1,17 @@
"""
An interface for exchanging internal "BLOOM points" for higher priority compute requests. NOT IMPLEMENTED.
The intent is to let Petals participants earn points by helping others while idle (e.g. at night), then use these
points to run their own compute experiments faster. See Section 4 of https://arxiv.org/abs/2209.01188 for discussion.
"""
from abc import ABC, abstractmethod
from hivemind.proto.runtime_pb2 import ExpertRequest
class SpendingPolicyBase(ABC):
@abstractmethod
def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
def get_points(self, protocol: str, *args, **kwargs) -> float:
pass
class NoSpendingPolicy(SpendingPolicyBase):
def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
def get_points(self, protocol: str, *args, **kwargs) -> float:
return 0.0

@@ -18,7 +18,7 @@ logger = get_logger(__file__)
class TransformerBackend(ModuleBackend):
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
super().__init__(*args, **kwargs)
assert isinstance(self.module, BloomBlock)
self.memory_cache = memory_cache
@@ -37,7 +37,9 @@ class TransformerBackend(ModuleBackend):
self.backward_pool = PrioritizedTaskPool(
self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
)
self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
assert backend_dtype is not None
self.dtype = backend_dtype
self.inference_schema = (
(
*self.args_schema,

@@ -1,6 +1,6 @@
import asyncio
import contextlib
from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
import torch
from async_timeout import timeout
@@ -202,14 +202,8 @@ class TransformerConnectionHandler(ConnectionHandler):
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
# Serialize output and respond to client
return runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
)
async def rpc_forward_stream(
@@ -230,22 +224,34 @@
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
assert (
isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
), "hidden_states must be a 3d tensor"
# Serialize the overall output
serialized_output = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
# Split the serialized_output for streaming and respond to client
output_split = [
part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
]
async for part in as_aiter(*output_split):
yield runtime_pb2.ExpertResponse(tensors=[part])
for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
yield runtime_pb2.ExpertResponse(tensors=[part])
def _serialize_outputs(
self,
hidden_states: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
metadata: Dict[str, Any],
) -> Sequence[runtime_pb2.Tensor]:
"""Serialize forward outputs using either outputs_schema or custom user-specified schema"""
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
outputs_schema = requested_backends[-1].outputs_schema
if metadata.get("output_compression") is not None:
assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
output_compression = tuple(metadata["output_compression"])
assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
assert len(output_compression) == 1, f"output_compression tuple should have 1 element"
else:
output_compression = tuple(tensor.compression for tensor in outputs_schema)
return [
serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
for result, proto, compression in zip([hidden_states], outputs_schema, output_compression)
]
async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
async with timeout(self.request_timeout):
@@ -265,21 +271,7 @@
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
grad_inputs_schema_with_prompts = (
requested_backends[0].args_schema * len(grads),
requested_backends[0].kwargs_schema,
) # TODO generalize
# Serialize the overall grad_input and respond
return runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
]
)
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
async def rpc_backward_stream(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -298,28 +290,38 @@
grads = await _rpc_backward(
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
grad_inputs_schema_with_prompts = (
requested_backends[0].args_schema * len(grads),
requested_backends[0].kwargs_schema,
) # TODO generalize
# Serialize the overall grad_inputs
serialized_grad_inputs = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
]
# Split the serialized_grad_inputs for streaming and respond
output_split = [
part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
]
for tensor in self._serialize_grads(grads, requested_backends, metadata):
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
yield runtime_pb2.ExpertResponse(tensors=[part])
async for part in as_aiter(*output_split):
yield runtime_pb2.ExpertResponse(tensors=[part])
def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
def _serialize_grads(
self,
grads: Sequence[torch.Tensor],
requested_backends: Sequence[TransformerBackend],
metadata: Dict[str, Any],
) -> Sequence[runtime_pb2.Tensor]:
"""Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
flat_grads_schema = tuple(
nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
) # TODO generalize
if metadata.get("output_compression") is not None:
assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
output_compression = tuple(metadata["output_compression"])
assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements"
else:
output_compression = tuple(tensor.compression for tensor in flat_grads_schema)
return [
serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
for result, proto, compression in zip(grads, flat_grads_schema, output_compression)
]
def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:
"""Check that the first request to rpc_inference is valid"""
uids = (uids or "").split(CHAIN_DELIMITER)
if not uids:
@@ -360,7 +362,7 @@ class TransformerConnectionHandler(ConnectionHandler):
yield handles
def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None:
def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids

@@ -286,27 +286,27 @@ class ModuleContainer(threading.Thread):
)
if load_in_8bit:
dtype = block.input_layernorm.weight.dtype
block = replace_8bit_linear(block)
block = block.to(device)
for param in block.parameters():
param.requires_grad = False
backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
blocks[module_uid] = TransformerBackend(
module_uid,
block,
memory_cache=memory_cache,
backend_dtype=None if torch_dtype == "auto" else torch_dtype,
backend_dtype=backend_dtype,
args_schema=(
BatchTensorDescriptor(
1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
),
),
kwargs_schema={},
outputs_schema=(
BatchTensorDescriptor(
1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
),
),
min_batch_size=min_batch_size,

@@ -1,11 +1,13 @@
import pytest
import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from hivemind import DHT, BatchTensorDescriptor, MSGPackSerializer, get_logger, use_hivemind_log_handler
from hivemind.proto import runtime_pb2
from test_utils import *
from petals.bloom.from_pretrained import load_pretrained_block
from petals.client import RemoteSequential
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.client.remote_model import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -43,6 +45,42 @@ def test_remote_sequential():
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad)
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
lossy_sequential = RemoteSequential(
config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
)
test_inputs.grad = None
approx_outputs = lossy_sequential(test_inputs)
(approx_outputs * grad_proj).sum().backward()
assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
assert abs(approx_outputs - full_outputs).mean() < 0.01
assert abs(test_inputs.grad - full_grad).mean() < 0.3
class DummyCustomSequenceManager(RemoteSequenceManager):
"""A sequence manager that compresses inputs/outputs during forward and backward pass."""
@property
def rpc_info(self):
rpc_info = super().rpc_info
dims = (2048, 1024)
compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16)
rpc_info["forward_schema"] = (compressed_input_schema,), dict() # (args, kwargs)
return rpc_info
def get_request_metadata(self, protocol: str, *args, **kwargs):
if protocol == "rpc_forward":
return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
elif protocol == "rpc_backward":
return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,)))
else:
assert protocol == "rpc_inference"
return super().get_request_metadata(protocol, *args, **kwargs)
@pytest.mark.forked
def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
