diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index c9d0a97..28d3632 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -343,7 +343,7 @@ class InferenceSession: n_prev_spans = len(self._server_sessions) update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks if attempt_no >= 1: - logger.info( + logger.debug( f"Due to a server failure, remote attention caches " f"from block {block_idx} to {update_end} will be regenerated" ) diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index c1e31b4..7c86f14 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -20,6 +20,19 @@ class ServerState(Enum): RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) +@pydantic.dataclasses.dataclass +class ModelInfo: + num_blocks: int + repository: Optional[str] = None + + def to_dict(self) -> dict: + return dataclasses.asdict(self) + + @classmethod + def from_dict(cls, source: dict): + return cls(**source) + + @pydantic.dataclasses.dataclass class ServerInfo: state: ServerState diff --git a/src/petals/server/server.py b/src/petals/server/server.py index d084f06..ab646a5 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -3,6 +3,7 @@ from __future__ import annotations import gc import math import multiprocessing as mp +import os import random import threading import time @@ -21,7 +22,7 @@ from transformers import PretrainedConfig import petals from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS -from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState +from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState from petals.server import block_selection from petals.server.backend import TransformerBackend, merge_inference_pools_inplace from petals.server.block_utils import get_block_size, resolve_block_dtype @@ -259,6 +260,9 @@ class Server: using_relay=reachable_via_relay, **throughput_info, ) + self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers) + if not os.path.isdir(converted_model_name_or_path): + self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path self.balance_quality = balance_quality self.mean_balance_check_period = mean_balance_check_period @@ -330,6 +334,7 @@ class Server: block_config=self.block_config, attn_cache_bytes=self.attn_cache_bytes, server_info=self.server_info, + model_info=self.model_info, block_indices=block_indices, num_handlers=self.num_handlers, min_batch_size=self.min_batch_size, @@ -436,6 +441,7 @@ class ModuleContainer(threading.Thread): block_config: PretrainedConfig, attn_cache_bytes: int, server_info: ServerInfo, + model_info: ModelInfo, block_indices: List[int], min_batch_size: int, max_batch_size: int, @@ -463,6 +469,7 @@ class ModuleContainer(threading.Thread): module_uids, dht, server_info, + model_info, block_config=block_config, memory_cache=memory_cache, update_period=update_period, @@ -671,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread): module_uids: List[str], dht: DHT, server_info: ServerInfo, + model_info: ModelInfo, *, block_config: PretrainedConfig, memory_cache: MemoryCache, @@ -683,6 +691,7 @@ class ModuleAnnouncerThread(threading.Thread): self.module_uids = module_uids self.dht = dht self.server_info = server_info + self.model_info = model_info self.memory_cache = memory_cache self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype]) @@ -693,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread): self.trigger = threading.Event() self.max_pinged = max_pinged - dht_prefix = module_uids[0].split(UID_DELIMITER)[0] + self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0] block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids] start_block, end_block = min(block_indices), max(block_indices) + 1 - self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)] + self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)] self.ping_aggregator = PingAggregator(self.dht) def run(self) -> None: @@ -720,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread): ) if self.server_info.state == ServerState.OFFLINE: break + if not self.dht_prefix.startswith("_"): # Not private + self.dht.store( + key="_petals.models", + subkey=self.dht_prefix, + value=self.model_info.to_dict(), + expiration_time=get_dht_time() + self.expiration, + ) delay = self.update_period - (time.perf_counter() - start_time) if delay < 0: