Check reachability automatically and give advice how to fix it (#155)

1. If we connect to the **public swarm**, the server now **automatically checks its DHT's reachability** from the outside world using API at http://health.petals.ml This is important to disallow unreachable servers to proceed (they create issues for the clients, such as repetitive retries).

    If http://health.petals.ml is down, the server proceeds without the check (so we don't depend on it). However, if health.petals.ml is up and explicitly tells us that we are unrechable, the server shows the reason of that and how to solve it.

    The check may be disabled with the `--skip_reachability_check` option (though I can't imagine cases where someone needs to use it).

2. Added `--port` and `--public_ip` as **simplified convenience options** for users not familiar with `--host_maddrs` and `--announce_maddrs`.
pull/148/head
Alexander Borzunov 1 year ago committed by GitHub
parent 73df69a117
commit 041ad20891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -28,9 +28,20 @@ def main():
parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
"use the same name as in the converted model.")
parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0', '/ip6/::/tcp/0'], required=False,
help='Multiaddrs to listen for external connections from other peers. Default: all IPv4/IPv6 interfaces, a random free TCP port')
parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
parser.add_argument('--port', type=int, required=False,
help='Port this server listens to. '
'This is a simplified way to set the --host_maddrs and --announce_maddrs options (see below) '
'that sets the port across all interfaces (IPv4, IPv6) and protocols (TCP, etc.) '
'to the same number. Default: a random free port is chosen for each interface and protocol')
parser.add_argument('--public_ip', type=str, required=False,
help='Your public IPv4 address, which is visible from the Internet. '
'This is a simplified way to set the --announce_maddrs option (see below).'
'Default: server announces IPv4/IPv6 addresses of your network interfaces')
parser.add_argument('--host_maddrs', nargs='+', required=False,
help='Multiaddrs to listen for external connections from other peers')
parser.add_argument('--announce_maddrs', nargs='+', required=False,
help='Visible multiaddrs the host announces for external connections from other peers')
parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
@ -122,12 +133,33 @@ def main():
help="Convert the loaded model into mixed-8bit quantized model. "
"Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
parser.add_argument("--skip_reachability_check", action='store_true',
help="Skip checking this server's reachability via health.petals.ml "
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")
# fmt:on
args = vars(parser.parse_args())
args.pop("config", None)
args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"]
host_maddrs = args.pop("host_maddrs")
port = args.pop("port")
if port is not None:
assert host_maddrs is None, "You can't use --port and --host_maddrs at the same time"
else:
port = 0
if host_maddrs is None:
host_maddrs = [f"/ip4/0.0.0.0/tcp/{port}", f"/ip6/::/tcp/{port}"]
announce_maddrs = args.pop("announce_maddrs")
public_ip = args.pop("public_ip")
if public_ip is not None:
assert announce_maddrs is None, "You can't use --public_ip and --announce_maddrs at the same time"
assert port != 0, "Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)"
announce_maddrs = [f"/ip4/{public_ip}/tcp/{port}"]
if args.pop("increase_file_limit"):
increase_file_limit()
@ -155,7 +187,14 @@ def main():
if load_in_8bit is not None:
args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
server = Server(**args, compression=compression, max_disk_space=max_disk_space, attn_cache_size=attn_cache_size)
server = Server(
**args,
host_maddrs=host_maddrs,
announce_maddrs=announce_maddrs,
compression=compression,
max_disk_space=max_disk_space,
attn_cache_size=attn_cache_size,
)
try:
server.run()
except KeyboardInterrupt:

@ -1,7 +1,6 @@
from __future__ import annotations
import gc
import itertools
import math
import multiprocessing as mp
import random
@ -11,6 +10,7 @@ from typing import Dict, List, Optional, Union
import numpy as np
import psutil
import requests
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
@ -77,6 +77,7 @@ class Server:
mean_block_selection_delay: float = 2.5,
use_auth_token: Optional[str] = None,
load_in_8bit: Optional[bool] = None,
skip_reachability_check: bool = False,
**kwargs,
):
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
@ -120,6 +121,8 @@ class Server:
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
if not skip_reachability_check:
self._check_reachability()
else:
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
@ -183,6 +186,30 @@ class Server:
self.stop = threading.Event()
def _check_reachability(self):
try:
r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
r.raise_for_status()
response = r.json()
except Exception as e:
logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
return
if not response["success"]:
# This happens only if health.petals.ml is up and explicitly told us that we are unreachable
raise RuntimeError(
f"Server is not reachable from the Internet:\n\n"
f"{response['message']}\n\n"
f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
f" 1. Choose a specific port for the Petals server, for example, 31337.\n"
f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
)
logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
def _choose_num_blocks(self) -> int:
assert (
self.converted_model_name_or_path == "bigscience/bloom-petals"
@ -570,7 +597,7 @@ class ModuleAnnouncerThread(threading.Thread):
self.stop = threading.Event()
def run(self) -> None:
for iter_no in itertools.count():
while True:
declare_active_modules(
self.dht,
self.module_uids,
@ -578,10 +605,5 @@ class ModuleAnnouncerThread(threading.Thread):
state=self.state,
throughput=self.throughput,
)
if iter_no == 0 and self.state == ServerState.JOINING:
logger.info(
f"Please ensure that your server is reachable. "
f"For public swarm, open http://health.petals.ml and find peer_id = {self.dht.peer_id}"
)
if self.stop.wait(self.update_period):
break

Loading…
Cancel
Save