Return available cache size in rpc_info() (#191)

This PR makes servers report their free attention cache in rpc_info(), measured in tokens * layers so that the value is compression-agnostic.

The value is intended to be used when calling make_sequence(optimize="inference").
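For context, a client can read the advertised value through the sequence manager's rpc_info dict. The helper below is hypothetical and only for illustration; the only parts taken from this PR are sequence_manager.rpc_info and the CACHE_TOKENS_AVAILABLE key:

from petals.server.handler import CACHE_TOKENS_AVAILABLE

def free_cache_tokens(sequence_manager) -> int:
    """Hypothetical helper: how many token*layer cache slots does a server advertise?"""
    # sequence_manager.rpc_info is the deserialized dict returned by the server's rpc_info()
    return sequence_manager.rpc_info[CACHE_TOKENS_AVAILABLE]

make_sequence(optimize="inference") could then prefer servers with more free slots; the exact selection logic is outside this diff.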
Authored by justheuristic, commit 5f58f00649 (parent 37373a66c3)

@@ -1,6 +1,7 @@
"""Code for serving bloom blocks via hivemind-server"""
from __future__ import annotations

from collections import Counter
from itertools import chain
from typing import Any, Dict, Sequence, Tuple
@@ -64,6 +65,10 @@ class TransformerBackend(ModuleBackend):
            self.kwargs_schema,
        )

        self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
        for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8

    def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
        """Create tensor descriptors for attention cache tensors used during inference_step"""
        head_dim = self.config.hidden_size // self.config.n_head

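A worked example of the bytes-per-token bookkeeping above. The shapes, head counts, and dtype here are made-up illustrations, not values from this diff:

from collections import Counter
import torch

# Hypothetical cache descriptors for batch_size=1, max_length=1:
# one key tensor and one value tensor of shape (1, n_head=32, head_dim=128) in float16.
cache_bytes_per_token = Counter()
descriptors = [
    ((1, 32, 128), torch.float16, torch.device("cpu")),  # hypothetical key cache
    ((1, 32, 128), torch.float16, torch.device("cpu")),  # hypothetical value cache
]
for shape, dtype, device in descriptors:
    numel = 1
    for dim in shape:
        numel *= dim
    # same formula as above: elements * bits per element // 8 = bytes per token on this device
    cache_bytes_per_token[device] += numel * torch.finfo(dtype).bits // 8

print(dict(cache_bytes_per_token))  # {device(type='cpu'): 16384} -> ~16 KiB of cache per token per block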
@@ -33,6 +33,8 @@ from petals.utils.misc import DUMMY, is_dummy

logger = get_logger(__file__)

CACHE_TOKENS_AVAILABLE = "cache_tokens_available"


class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""
@@ -378,6 +380,22 @@ class TransformerConnectionHandler(ConnectionHandler):
        else:
            logger.warning(f"{message}: {warning}")

    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
        """Return metadata about stored block uids and current load"""
        rpc_info = {}
        if request.uid:
            backend = self.module_backends[request.uid]
            rpc_info.update(self.module_backends[request.uid].get_info())
        else:
            backend = next(iter(self.module_backends.values()))
            # not saving keys to rpc_info since user did not request any uid

        cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)
        if CACHE_TOKENS_AVAILABLE in rpc_info:
            raise RuntimeError(f"Block rpc_info dict has a reserved field {CACHE_TOKENS_AVAILABLE} : {rpc_info}")
        rpc_info[CACHE_TOKENS_AVAILABLE] = cache_bytes_left // max(backend.cache_bytes_per_token.values())

        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(rpc_info))

async def _rpc_forward(
    *flat_tensors: torch.Tensor,
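For intuition, the arithmetic in rpc_info() boils down to the following standalone sketch; all byte counts are made-up numbers, not values from a real server:

# Made-up numbers: a server with an 8 GiB attention-cache budget, 2 GiB already in use.
max_size_bytes = 8 * 1024**3
current_size_bytes = 2 * 1024**3
cache_bytes_per_token = {"cuda:0": 16384}  # per-device cost computed by TransformerBackend (see above)

cache_bytes_left = max(0, max_size_bytes - current_size_bytes)
# Divide by the most expensive device so the advertised count is safe for every device the backend uses.
cache_tokens_available = cache_bytes_left // max(cache_bytes_per_token.values())
print(cache_tokens_available)  # 393216 token*layer slots left in this example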
@@ -0,0 +1,39 @@
import time

import hivemind
import pytest
import torch
from test_utils import *

from petals.client import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
from petals.dht_utils import get_remote_sequence
from petals.server.handler import CACHE_TOKENS_AVAILABLE


@pytest.mark.forked
def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)

    blocks1 = get_remote_sequence(dht, block_from, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
    blocks2 = get_remote_sequence(dht, block_to - 1, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
    info_before = blocks1.sequence_manager.rpc_info

    with blocks1.inference_session(max_length=max_length) as sess:
        sess.step(torch.randn(1, 1, config.hidden_size))
        blocks1.sequence_manager._rpc_info = None  # invalidate cache
        info_inside = blocks1.sequence_manager.rpc_info

        with blocks2.inference_session(max_length=max_length2) as sess2:
            sess2.step(torch.randn(1, 1, config.hidden_size))
            blocks2.sequence_manager._rpc_info = None  # invalidate cache
            info_inside2 = blocks2.sequence_manager.rpc_info

    time.sleep(0.1)
    blocks1.sequence_manager._rpc_info = None  # invalidate cache
    info_after = blocks1.sequence_manager.rpc_info

    assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]
    assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1)
    assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2)
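As a sanity check of the assertions above, plain arithmetic under the assumption that each open inference session reserves max_length cache slots in every block it spans:

# blocks1 spans blocks 22..24 -> 2 blocks; blocks2 spans blocks 23..24 -> 1 block
len_blocks1, len_blocks2 = 24 - 22, 24 - 23
max_length, max_length2 = 100, 50

print(max_length * len_blocks1)   # 200 token*layer slots reserved by the outer session
print(max_length2 * len_blocks2)  # 50 more reserved while the nested session is open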