From 5f58f006495ca5fe1f96f74dd7e4de6315ff52d7 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Thu, 12 Jan 2023 06:49:41 +0300
Subject: [PATCH] Return available cache size in rpc_info() (#191)

This PR makes servers return their free cache (in tokens * layers to make it compression-agnostic)

To be used when calling make_sequence(optimize="inference")
---
 src/petals/server/backend.py |  5 +++++
 src/petals/server/handler.py | 18 +++++++++++++++++
 tests/test_server_stats.py   | 39 ++++++++++++++++++++++++++++++++++++
 3 files changed, 62 insertions(+)
 create mode 100644 tests/test_server_stats.py

diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 67b03c0..4f9a3bb 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -1,6 +1,7 @@
 """Code for serving bloom blocks via hivemind-server"""
 from __future__ import annotations
 
+from collections import Counter
 from itertools import chain
 from typing import Any, Dict, Sequence, Tuple
 
@@ -64,6 +65,10 @@ class TransformerBackend(ModuleBackend):
             self.kwargs_schema,
         )
 
+        self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
+        for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
+            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
+
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""
         head_dim = self.config.hidden_size // self.config.n_head
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 387431a..6ddfb55 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -33,6 +33,8 @@ from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__file__)
 
+CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
+
 
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
@@ -378,6 +380,22 @@ class TransformerConnectionHandler(ConnectionHandler):
         else:
             logger.warning(f"{message}: {warning}")
 
+    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
+        """Return metadata about stored block uids and current load"""
+        rpc_info = {}
+        if request.uid:
+            backend = self.module_backends[request.uid]
+            rpc_info.update(self.module_backends[request.uid].get_info())
+        else:
+            backend = next(iter(self.module_backends.values()))
+            # not saving keys to rpc_info since user did not request any uid
+
+        cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)
+        if CACHE_TOKENS_AVAILABLE in rpc_info:
+            raise RuntimeError(f"Block rpc_info dict has a reserved field {CACHE_TOKENS_AVAILABLE} : {rpc_info}")
+        rpc_info[CACHE_TOKENS_AVAILABLE] = cache_bytes_left // max(backend.cache_bytes_per_token.values())
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(rpc_info))
+
 
 async def _rpc_forward(
     *flat_tensors: torch.Tensor,
diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py
new file mode 100644
index 0000000..0f2b3f0
--- /dev/null
+++ b/tests/test_server_stats.py
@@ -0,0 +1,39 @@
+import time
+
+import hivemind
+import pytest
+import torch
+from test_utils import *
+
+from petals.client import DistributedBloomConfig
+from petals.data_structures import UID_DELIMITER
+from petals.dht_utils import get_remote_sequence
+from petals.server.handler import CACHE_TOKENS_AVAILABLE
+
+
+@pytest.mark.forked
+def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+
+    blocks1 = get_remote_sequence(dht, block_from, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
+    blocks2 = get_remote_sequence(dht, block_to - 1, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
+    info_before = blocks1.sequence_manager.rpc_info
+
+    with blocks1.inference_session(max_length=max_length) as sess:
+        sess.step(torch.randn(1, 1, config.hidden_size))
+        blocks1.sequence_manager._rpc_info = None  # invalidate cache
+        info_inside = blocks1.sequence_manager.rpc_info
+
+        with blocks2.inference_session(max_length=max_length2) as sess2:
+            sess2.step(torch.randn(1, 1, config.hidden_size))
+            blocks2.sequence_manager._rpc_info = None  # invalidate cache
+            info_inside2 = blocks2.sequence_manager.rpc_info
+
+    time.sleep(0.1)
+    blocks1.sequence_manager._rpc_info = None  # invalidate cache
+    info_after = blocks1.sequence_manager.rpc_info
+
+    assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]
+    assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1)
+    assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2)
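
Usage sketch (illustrative, not part of the patch): a minimal client-side snippet showing how the new
cache_tokens_available field could be read from a server's rpc_info. INITIAL_PEERS and MODEL_NAME below
are placeholders for a running swarm's bootstrap peers and model name; the call pattern mirrors
tests/test_server_stats.py above.

import hivemind

from petals.client import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
from petals.dht_utils import get_remote_sequence
from petals.server.handler import CACHE_TOKENS_AVAILABLE

# Placeholder swarm coordinates: replace with your own bootstrap peers and model name.
INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/QmExamplePeerId"]
MODEL_NAME = "bigscience/bloom-petals"

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)

# Ask the servers holding blocks [22, 24) for their metadata; with this patch,
# rpc_info also reports free attention-cache space measured in tokens * layers.
blocks = get_remote_sequence(dht, 22, 24, config, f"{MODEL_NAME}{UID_DELIMITER}")
print("cache tokens available:", blocks.sequence_manager.rpc_info[CACHE_TOKENS_AVAILABLE])

Reporting the value in tokens * layers rather than bytes keeps it comparable across servers that store
their attention cache in different dtypes or compression formats.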