diff --git a/setup.cfg b/setup.cfg
index 11513bd..3ba993e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,7 +37,7 @@ install_requires =
     huggingface-hub==0.11.1
     transformers==4.25.1
     speedtest-cli==2.1.3
-    hivemind==1.1.3
+    hivemind==1.1.5
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index e089937..fc0771d 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -38,6 +38,9 @@ def main():
                              'This is a simplified way to set the --announce_maddrs option (see below).'
                              'Default: server announces IPv4/IPv6 addresses of your network interfaces')
 
+    parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay",
+                        help="Do not look for libp2p relays to reach peers behind NATs/firewalls")
+
     parser.add_argument('--host_maddrs', nargs='+', required=False,
                         help='Multiaddrs to listen for external connections from other peers')
     parser.add_argument('--announce_maddrs', nargs='+', required=False,
diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py
index 3e52e40..5d22bfd 100644
--- a/src/petals/client/remote_model.py
+++ b/src/petals/client/remote_model.py
@@ -107,6 +107,8 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
                 num_workers=n_layer,
                 startup_timeout=config.daemon_startup_timeout,
                 start=True,
+                use_relay=True,
+                use_auto_relay=True,
             )
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py
new file mode 100644
index 0000000..d8b5fba
--- /dev/null
+++ b/src/petals/server/reachability.py
@@ -0,0 +1,39 @@
+import math
+import time
+
+import requests
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__file__)
+
+
+def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
+    for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
+        try:
+            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
+            r.raise_for_status()
+            response = r.json()
+
+            if response["success"]:
+                logger.info("Server is reachable from the Internet. It will appear at http://health.petals.ml soon")
+                return
+
+            if attempt_no == 0:
+                # Usually, libp2p manages to set up relays before we finish loading blocks.
+                # In other cases, we may need to wait for up to `wait_time` seconds before it's done.
+                logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")
+            time.sleep(retry_delay)
+        except Exception as e:
+            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
+            return
+
+    raise RuntimeError(
+        f"Server has not become reachable from the Internet:\n\n"
+        f"{response['message']}\n\n"
+        f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
+        f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
+        f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
+        f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
+        f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
+        f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
+    )
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index a8927aa..e1a2293 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -10,7 +10,6 @@ from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
 import psutil
-import requests
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -28,6 +27,7 @@ from petals.server.backend import TransformerBackend
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
+from petals.server.reachability import check_reachability
 from petals.server.throughput import get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -78,6 +78,8 @@ class Server:
         load_in_8bit: Optional[bool] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        use_relay: bool = True,
+        use_auto_relay: bool = True,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -117,14 +119,20 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
-        self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
+        self.dht = DHT(
+            initial_peers=initial_peers,
+            start=True,
+            num_workers=self.block_config.n_layer,
+            use_relay=use_relay,
+            use_auto_relay=use_auto_relay,
+            **kwargs,
+        )
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
-            if not skip_reachability_check:
-                self._check_reachability()
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -196,35 +204,14 @@ class Server:
 
         self.stop = threading.Event()
 
-    def _check_reachability(self):
-        try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
-            r.raise_for_status()
-            response = r.json()
-        except Exception as e:
-            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
-            return
-
-        if not response["success"]:
-            # This happens only if health.petals.ml is up and explicitly told us that we are unreachable
-            raise RuntimeError(
-                f"Server is not reachable from the Internet:\n\n"
-                f"{response['message']}\n\n"
-                f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
-                f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
-                f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
-                f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
-                f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
-                f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
-            )
-
-        logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
-
     def _choose_num_blocks(self) -> int:
         assert (
             self.converted_model_name_or_path == "bigscience/bloom-petals"
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
-        assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+        assert self.device.type == "cuda", (
+            "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
+            "CPU-only servers in the public swarm are discouraged since they are much slower"
+        )
 
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
         if num_devices > 1:
@@ -287,6 +274,7 @@ class Server:
             use_auth_token=self.use_auth_token,
             load_in_8bit=self.load_in_8bit,
             tensor_parallel_devices=self.tensor_parallel_devices,
+            need_reachability_check=self.need_reachability_check,
             start=True,
         )
         try:
@@ -380,6 +368,7 @@ class ModuleContainer(threading.Thread):
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
+        need_reachability_check: bool,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -433,6 +422,9 @@ class ModuleContainer(threading.Thread):
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                 )
+
+            if need_reachability_check:
+                check_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():
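
For reference, a minimal sketch of how the server-side pieces added above fit together: start a DHT node with relay support enabled (mirroring `Server.__init__`), then run the new reachability probe against its peer ID. The `PUBLIC_INITIAL_PEERS` import path is an assumption not shown in this diff; flag values just echo the new defaults.

```python
# Sketch only: mirrors the server-side flow introduced in this diff.
import hivemind

from petals.constants import PUBLIC_INITIAL_PEERS  # assumed location of the public bootstrap peers
from petals.server.reachability import check_reachability

# A DHT node with libp2p circuit relays enabled, as Server.__init__ now does by default
dht = hivemind.DHT(
    initial_peers=PUBLIC_INITIAL_PEERS,
    start=True,
    use_relay=True,       # enable libp2p circuit relay support
    use_auto_relay=True,  # look for relays to stay reachable behind a NAT/firewall (cf. --no_auto_relay)
)

# Polls health.petals.ml for up to `wait_time` (7 minutes by default); raises RuntimeError with
# port-forwarding instructions if the peer never becomes reachable, and returns early if the
# health service itself is down.
check_reachability(dht.peer_id)
```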