From 6b12b0d050f73826f6f66481d40146370e2bebbb Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Fri, 13 Jan 2023 07:46:10 +0400
Subject: [PATCH] Report server version and dht.client_mode in rpc_info(), check for updates on startup (#209)

This PR:

1. Shows the current Petals version and checks for updates on startup.
2. Reports the current version and DHT mode in `rpc_info()`, so it can be shown on http://health.petals.ml or used on clients for efficient routing.
---
Review note: validate_version() now passes a timeout to requests.get(), so an
unresponsive PyPI cannot block server startup, and it has a docstring describing
its best-effort behavior. The line structure of this patch was also restored.

 setup.cfg                    |  1 +
 src/petals/cli/run_server.py |  3 +++
 src/petals/server/handler.py | 27 ++++++++++++++++-----------
 src/petals/server/server.py  |  2 +-
 src/petals/utils/version.py  | 32 ++++++++++++++++++++++++++++++++
 5 files changed, 53 insertions(+), 12 deletions(-)
 create mode 100644 src/petals/utils/version.py

diff --git a/setup.cfg b/setup.cfg
index ad197d1..73bb117 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -42,6 +42,7 @@ install_requires =
     humanfriendly
     async-timeout>=4.0.2
     cpufeature>=0.2.0
+    packaging>=23.0
 
 [options.extras_require]
 dev =
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index ff68966..135720d 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -8,6 +8,7 @@ from humanfriendly import parse_size
 
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.server.server import Server
+from petals.utils.version import validate_version
 
 logger = get_logger(__file__)
 
@@ -193,6 +194,8 @@ def main():
     if load_in_8bit is not None:
         args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
 
+    validate_version()
+
     server = Server(
         **args,
         host_maddrs=host_maddrs,
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 6ddfb55..3c889f6 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -24,6 +24,7 @@ from hivemind.utils.asyncio import amap_in_executor, anext
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
+import petals
 from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
 from petals.server.backend import TransformerBackend
 from petals.server.memory_cache import Handle
@@ -382,19 +383,23 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         """Return metadata about stored block uids and current load"""
-        rpc_info = {}
-        if request.uid:
-            backend = self.module_backends[request.uid]
-            rpc_info.update(self.module_backends[request.uid].get_info())
-        else:
-            backend = next(iter(self.module_backends.values()))
-            # not saving keys to rpc_info since user did not request any uid
 
+        backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values()))
         cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)
-        if CACHE_TOKENS_AVAILABLE in rpc_info:
-            raise RuntimeError(f"Block rpc_info dict has a reserved field {CACHE_TOKENS_AVAILABLE} : {rpc_info}")
-        rpc_info[CACHE_TOKENS_AVAILABLE] = cache_bytes_left // max(backend.cache_bytes_per_token.values())
-        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(rpc_info))
+        result = {
+            "version": petals.__version__,
+            "dht_client_mode": self.dht.client_mode,
+            CACHE_TOKENS_AVAILABLE: cache_bytes_left // max(backend.cache_bytes_per_token.values()),
+        }
+
+        if request.uid:
+            block_info = self.module_backends[request.uid].get_info()
+            common_keys = set(result.keys()) & set(block_info.keys())
+            if common_keys:
+                raise RuntimeError(f"The block's rpc_info has keys reserved for the server's rpc_info: {common_keys}")
+            result.update(block_info)
+
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
 
 
 async def _rpc_forward(
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 7e76080..dca2ccd 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -102,7 +102,7 @@ class Server:
                 f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                 f"Please specify --prefix manually when starting a server"
             )
-            logger.info(f"Automatic dht prefix: {prefix}")
+            logger.debug(f"Automatic dht prefix: {prefix}")
         self.prefix = prefix
 
         if expiration is None:
diff --git a/src/petals/utils/version.py b/src/petals/utils/version.py
new file mode 100644
index 0000000..b992c27
--- /dev/null
+++ b/src/petals/utils/version.py
@@ -0,0 +1,32 @@
+import requests
+from hivemind.utils.logging import TextStyle, get_logger
+from packaging.version import parse
+
+import petals
+
+logger = get_logger(__file__)
+
+
+def validate_version():
+    """Log the running Petals version and suggest an upgrade if PyPI lists a newer non-prerelease version.
+
+    The check is best-effort: if PyPI cannot be reached or its response cannot be parsed,
+    a warning is logged and the exception is not propagated to the caller.
+    """
+    logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}")
+    try:
+        # Without a timeout, an unresponsive PyPI would make this call (and server startup) hang
+        r = requests.get("https://pypi.python.org/pypi/petals/json", timeout=10)
+        r.raise_for_status()
+        response = r.json()
+
+        versions = [parse(ver) for ver in response.get("releases")]
+        latest = max(ver for ver in versions if not ver.is_prerelease)
+
+        if parse(petals.__version__) < latest:
+            logger.info(
+                f"A newer version {latest} is available. Please upgrade with: "
+                f"{TextStyle.BOLD}pip install --upgrade petals{TextStyle.RESET}"
+            )
+    except Exception:
+        logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True)