diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index d9cbfc5..27cca30 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -28,9 +28,20 @@ def main(): parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve") parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default," "use the same name as in the converted model.") - parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0', '/ip6/::/tcp/0'], required=False, - help='Multiaddrs to listen for external connections from other peers. Default: all IPv4/IPv6 interfaces, a random free TCP port') - parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False, + + parser.add_argument('--port', type=int, required=False, + help='Port this server listens to. ' + 'This is a simplified way to set the --host_maddrs and --announce_maddrs options (see below) ' + 'that sets the port across all interfaces (IPv4, IPv6) and protocols (TCP, etc.) ' + 'to the same number. Default: a random free port is chosen for each interface and protocol') + parser.add_argument('--public_ip', type=str, required=False, + help='Your public IPv4 address, which is visible from the Internet. ' + 'This is a simplified way to set the --announce_maddrs option (see below).' + 'Default: server announces IPv4/IPv6 addresses of your network interfaces') + + parser.add_argument('--host_maddrs', nargs='+', required=False, + help='Multiaddrs to listen for external connections from other peers') + parser.add_argument('--announce_maddrs', nargs='+', required=False, help='Visible multiaddrs the host announces for external connections from other peers') parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication') @@ -122,12 +133,33 @@ def main(): help="Convert the loaded model into mixed-8bit quantized model. " "Default: True if GPU is available. Use `--load_in_8bit False` to disable this") + parser.add_argument("--skip_reachability_check", action='store_true', + help="Skip checking this server's reachability via health.petals.ml " + "when connecting to the public swarm. If you connect to a private swarm, " + "the check is skipped by default. Use this option only if you know what you are doing") + # fmt:on args = vars(parser.parse_args()) args.pop("config", None) args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"] + host_maddrs = args.pop("host_maddrs") + port = args.pop("port") + if port is not None: + assert host_maddrs is None, "You can't use --port and --host_maddrs at the same time" + else: + port = 0 + if host_maddrs is None: + host_maddrs = [f"/ip4/0.0.0.0/tcp/{port}", f"/ip6/::/tcp/{port}"] + + announce_maddrs = args.pop("announce_maddrs") + public_ip = args.pop("public_ip") + if public_ip is not None: + assert announce_maddrs is None, "You can't use --public_ip and --announce_maddrs at the same time" + assert port != 0, "Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)" + announce_maddrs = [f"/ip4/{public_ip}/tcp/{port}"] + if args.pop("increase_file_limit"): increase_file_limit() @@ -155,7 +187,14 @@ def main(): if load_in_8bit is not None: args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"] - server = Server(**args, compression=compression, max_disk_space=max_disk_space, attn_cache_size=attn_cache_size) + server = Server( + **args, + host_maddrs=host_maddrs, + announce_maddrs=announce_maddrs, + compression=compression, + max_disk_space=max_disk_space, + attn_cache_size=attn_cache_size, + ) try: server.run() except KeyboardInterrupt: diff --git a/src/petals/server/server.py b/src/petals/server/server.py index ee0b02e..f509c0b 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -1,7 +1,6 @@ from __future__ import annotations import gc -import itertools import math import multiprocessing as mp import random @@ -11,6 +10,7 @@ from typing import Dict, List, Optional, Union import numpy as np import psutil +import requests import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file @@ -77,6 +77,7 @@ class Server: mean_block_selection_delay: float = 2.5, use_auth_token: Optional[str] = None, load_in_8bit: Optional[bool] = None, + skip_reachability_check: bool = False, **kwargs, ): """Create a server with one or more bloom blocks. See run_server.py for documentation.""" @@ -120,6 +121,8 @@ class Server: visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()] if initial_peers == PUBLIC_INITIAL_PEERS: logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}") + if not skip_reachability_check: + self._check_reachability() else: logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}") @@ -183,6 +186,30 @@ class Server: self.stop = threading.Event() + def _check_reachability(self): + try: + r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10) + r.raise_for_status() + response = r.json() + except Exception as e: + logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}") + return + + if not response["success"]: + # This happens only if health.petals.ml is up and explicitly told us that we are unreachable + raise RuntimeError( + f"Server is not reachable from the Internet:\n\n" + f"{response['message']}\n\n" + f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n" + f" 1. Choose a specific port for the Petals server, for example, 31337.\n" + f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n" + f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n" + f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n" + f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n" + ) + + logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon") + def _choose_num_blocks(self) -> int: assert ( self.converted_model_name_or_path == "bigscience/bloom-petals" @@ -570,7 +597,7 @@ class ModuleAnnouncerThread(threading.Thread): self.stop = threading.Event() def run(self) -> None: - for iter_no in itertools.count(): + while True: declare_active_modules( self.dht, self.module_uids, @@ -578,10 +605,5 @@ class ModuleAnnouncerThread(threading.Thread): state=self.state, throughput=self.throughput, ) - if iter_no == 0 and self.state == ServerState.JOINING: - logger.info( - f"Please ensure that your server is reachable. " - f"For public swarm, open http://health.petals.ml and find peer_id = {self.dht.peer_id}" - ) if self.stop.wait(self.update_period): break