From 041ad2089108449db2201f6860c4d8f916525e1a Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 15 Dec 2022 05:04:09 +0400 Subject: [PATCH] Check reachability automatically and give advice how to fix it (#155) 1. If we connect to the **public swarm**, the server now **automatically checks its DHT's reachability** from the outside world using API at http://health.petals.ml This is important to disallow unreachable servers to proceed (they create issues for the clients, such as repetitive retries). If http://health.petals.ml is down, the server proceeds without the check (so we don't depend on it). However, if health.petals.ml is up and explicitly tells us that we are unrechable, the server shows the reason of that and how to solve it. The check may be disabled with the `--skip_reachability_check` option (though I can't imagine cases where someone needs to use it). 2. Added `--port` and `--public_ip` as **simplified convenience options** for users not familiar with `--host_maddrs` and `--announce_maddrs`. --- src/petals/cli/run_server.py | 47 +++++++++++++++++++++++++++++++++--- src/petals/server/server.py | 36 +++++++++++++++++++++------ 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index d9cbfc5..27cca30 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -28,9 +28,20 @@ def main(): parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve") parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default," "use the same name as in the converted model.") - parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0', '/ip6/::/tcp/0'], required=False, - help='Multiaddrs to listen for external connections from other peers. Default: all IPv4/IPv6 interfaces, a random free TCP port') - parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False, + + parser.add_argument('--port', type=int, required=False, + help='Port this server listens to. ' + 'This is a simplified way to set the --host_maddrs and --announce_maddrs options (see below) ' + 'that sets the port across all interfaces (IPv4, IPv6) and protocols (TCP, etc.) ' + 'to the same number. Default: a random free port is chosen for each interface and protocol') + parser.add_argument('--public_ip', type=str, required=False, + help='Your public IPv4 address, which is visible from the Internet. ' + 'This is a simplified way to set the --announce_maddrs option (see below).' + 'Default: server announces IPv4/IPv6 addresses of your network interfaces') + + parser.add_argument('--host_maddrs', nargs='+', required=False, + help='Multiaddrs to listen for external connections from other peers') + parser.add_argument('--announce_maddrs', nargs='+', required=False, help='Visible multiaddrs the host announces for external connections from other peers') parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication') @@ -122,12 +133,33 @@ def main(): help="Convert the loaded model into mixed-8bit quantized model. " "Default: True if GPU is available. Use `--load_in_8bit False` to disable this") + parser.add_argument("--skip_reachability_check", action='store_true', + help="Skip checking this server's reachability via health.petals.ml " + "when connecting to the public swarm. If you connect to a private swarm, " + "the check is skipped by default. Use this option only if you know what you are doing") + # fmt:on args = vars(parser.parse_args()) args.pop("config", None) args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"] + host_maddrs = args.pop("host_maddrs") + port = args.pop("port") + if port is not None: + assert host_maddrs is None, "You can't use --port and --host_maddrs at the same time" + else: + port = 0 + if host_maddrs is None: + host_maddrs = [f"/ip4/0.0.0.0/tcp/{port}", f"/ip6/::/tcp/{port}"] + + announce_maddrs = args.pop("announce_maddrs") + public_ip = args.pop("public_ip") + if public_ip is not None: + assert announce_maddrs is None, "You can't use --public_ip and --announce_maddrs at the same time" + assert port != 0, "Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)" + announce_maddrs = [f"/ip4/{public_ip}/tcp/{port}"] + if args.pop("increase_file_limit"): increase_file_limit() @@ -155,7 +187,14 @@ def main(): if load_in_8bit is not None: args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"] - server = Server(**args, compression=compression, max_disk_space=max_disk_space, attn_cache_size=attn_cache_size) + server = Server( + **args, + host_maddrs=host_maddrs, + announce_maddrs=announce_maddrs, + compression=compression, + max_disk_space=max_disk_space, + attn_cache_size=attn_cache_size, + ) try: server.run() except KeyboardInterrupt: diff --git a/src/petals/server/server.py b/src/petals/server/server.py index ee0b02e..f509c0b 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -1,7 +1,6 @@ from __future__ import annotations import gc -import itertools import math import multiprocessing as mp import random @@ -11,6 +10,7 @@ from typing import Dict, List, Optional, Union import numpy as np import psutil +import requests import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file @@ -77,6 +77,7 @@ class Server: mean_block_selection_delay: float = 2.5, use_auth_token: Optional[str] = None, load_in_8bit: Optional[bool] = None, + skip_reachability_check: bool = False, **kwargs, ): """Create a server with one or more bloom blocks. See run_server.py for documentation.""" @@ -120,6 +121,8 @@ class Server: visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()] if initial_peers == PUBLIC_INITIAL_PEERS: logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}") + if not skip_reachability_check: + self._check_reachability() else: logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}") @@ -183,6 +186,30 @@ class Server: self.stop = threading.Event() + def _check_reachability(self): + try: + r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10) + r.raise_for_status() + response = r.json() + except Exception as e: + logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}") + return + + if not response["success"]: + # This happens only if health.petals.ml is up and explicitly told us that we are unreachable + raise RuntimeError( + f"Server is not reachable from the Internet:\n\n" + f"{response['message']}\n\n" + f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n" + f" 1. Choose a specific port for the Petals server, for example, 31337.\n" + f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n" + f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n" + f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n" + f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n" + ) + + logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon") + def _choose_num_blocks(self) -> int: assert ( self.converted_model_name_or_path == "bigscience/bloom-petals" @@ -570,7 +597,7 @@ class ModuleAnnouncerThread(threading.Thread): self.stop = threading.Event() def run(self) -> None: - for iter_no in itertools.count(): + while True: declare_active_modules( self.dht, self.module_uids, @@ -578,10 +605,5 @@ class ModuleAnnouncerThread(threading.Thread): state=self.state, throughput=self.throughput, ) - if iter_no == 0 and self.state == ServerState.JOINING: - logger.info( - f"Please ensure that your server is reachable. " - f"For public swarm, open http://health.petals.ml and find peer_id = {self.dht.peer_id}" - ) if self.stop.wait(self.update_period): break