Support libp2p relays for NAT traversal (#186)

- Added relay options to servers (see the DHT sketch below)
- Enabled relay options by default
- Bumped the hivemind dependency to 1.1.5
- Moved the reachability check so it runs after blocks are loaded
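For context, a minimal sketch of what these defaults mean at the hivemind level; `use_relay` and `use_auto_relay` are the exact keyword arguments added in this diff, while the surrounding setup is illustrative:

    import hivemind

    # Both relay flags now default to True, so a server behind a NAT or firewall
    # can still be dialed through public libp2p relay nodes.
    dht = hivemind.DHT(
        initial_peers=[],     # illustrative; a real server passes the swarm's bootstrap peers
        start=True,
        use_relay=True,       # accept and serve relayed libp2p connections
        use_auto_relay=True,  # search for relays if this node is not publicly reachable
    )
    print(dht.get_visible_maddrs())
    dht.shutdown()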

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Author: Egiazarian Vage (committed by GitHub)
commit 93bed7da5a
parent 16b69d6050

@@ -37,7 +37,7 @@ install_requires =
     huggingface-hub==0.11.1
     transformers==4.25.1
     speedtest-cli==2.1.3
-    hivemind==1.1.3
+    hivemind==1.1.5
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2

@@ -38,6 +38,9 @@ def main():
                              'This is a simplified way to set the --announce_maddrs option (see below).'
                              'Default: server announces IPv4/IPv6 addresses of your network interfaces')
+    parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay",
+                        help="Do not look for libp2p relays to reach peers behind NATs/firewalls")
     parser.add_argument('--host_maddrs', nargs='+', required=False,
                         help='Multiaddrs to listen for external connections from other peers')
     parser.add_argument('--announce_maddrs', nargs='+', required=False,
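A note on the flag above: `action="store_false"` with `dest="use_auto_relay"` means auto relay stays enabled unless the operator explicitly passes `--no_auto_relay`. A self-contained sketch of that argparse pattern (argument names taken from the diff, the parser itself is illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay",
                        help="Do not look for libp2p relays to reach peers behind NATs/firewalls")

    print(parser.parse_args([]).use_auto_relay)                   # True: auto relay on by default
    print(parser.parse_args(["--no_auto_relay"]).use_auto_relay)  # False: operator opted out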

@@ -107,6 +107,8 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
                 num_workers=n_layer,
                 startup_timeout=config.daemon_startup_timeout,
                 start=True,
+                use_relay=True,
+                use_auto_relay=True,
             )
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"

@@ -0,0 +1,39 @@
+import math
+import time
+
+import requests
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__file__)
+
+
+def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
+    for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
+        try:
+            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
+            r.raise_for_status()
+            response = r.json()
+
+            if response["success"]:
+                logger.info("Server is reachable from the Internet. It will appear at http://health.petals.ml soon")
+                return
+
+            if attempt_no == 0:
+                # Usually, libp2p manages to set up relays before we finish loading blocks.
+                # In other cases, we may need to wait for up to `wait_time` seconds before it's done.
+                logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")
+            time.sleep(retry_delay)
+        except Exception as e:
+            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
+            return
+
+    raise RuntimeError(
+        f"Server has not become reachable from the Internet:\n\n"
+        f"{response['message']}\n\n"
+        f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
+        f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
+        f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
+        f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
+        f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
+        f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
+    )
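A quick sanity check on the defaults above: with `wait_time = 7 * 60` and `retry_delay = 15`, the loop polls the health server for roughly seven minutes before giving up:

    import math

    # Mirrors the attempt count computed in check_reachability above.
    wait_time, retry_delay = 7 * 60, 15
    attempts = math.floor(wait_time / retry_delay) + 1
    print(attempts)  # 29 polls, one every 15 seconds, before RuntimeError is raised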

@@ -10,7 +10,6 @@ from typing import Dict, List, Optional, Sequence, Union
 import numpy as np
 import psutil
-import requests
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -28,6 +27,7 @@ from petals.server.backend import TransformerBackend
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
+from petals.server.reachability import check_reachability
 from petals.server.throughput import get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -78,6 +78,8 @@ class Server:
         load_in_8bit: Optional[bool] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        use_relay: bool = True,
+        use_auto_relay: bool = True,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -117,14 +119,20 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]

-        self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
+        self.dht = DHT(
+            initial_peers=initial_peers,
+            start=True,
+            num_workers=self.block_config.n_layer,
+            use_relay=use_relay,
+            use_auto_relay=use_auto_relay,
+            **kwargs,
+        )
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
-            if not skip_reachability_check:
-                self._check_reachability()
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS

         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -196,35 +204,14 @@ class Server:
         self.stop = threading.Event()

-    def _check_reachability(self):
-        try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
-            r.raise_for_status()
-            response = r.json()
-        except Exception as e:
-            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
-            return
-
-        if not response["success"]:
-            # This happens only if health.petals.ml is up and explicitly told us that we are unreachable
-            raise RuntimeError(
-                f"Server is not reachable from the Internet:\n\n"
-                f"{response['message']}\n\n"
-                f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
-                f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
-                f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
-                f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
-                f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
-                f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
-            )
-
-        logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
-
     def _choose_num_blocks(self) -> int:
         assert (
             self.converted_model_name_or_path == "bigscience/bloom-petals"
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
-        assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+        assert self.device.type == "cuda", (
+            "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
+            "CPU-only servers in the public swarm are discouraged since they are much slower"
+        )

         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
         if num_devices > 1:
@@ -287,6 +274,7 @@ class Server:
             use_auth_token=self.use_auth_token,
             load_in_8bit=self.load_in_8bit,
             tensor_parallel_devices=self.tensor_parallel_devices,
+            need_reachability_check=self.need_reachability_check,
             start=True,
         )
         try:
@@ -380,6 +368,7 @@ class ModuleContainer(threading.Thread):
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
+        need_reachability_check: bool,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -433,6 +422,9 @@ class ModuleContainer(threading.Thread):
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                 )
+
+            if need_reachability_check:
+                check_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():
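Operators can reproduce the new check by hand against the same endpoint. Judging from the fields the code reads, the JSON response carries at least `success`, `message`, and `your_ip`; a sketch with a hypothetical peer ID:

    import requests

    peer_id = "QmExamplePeerID"  # hypothetical; use the peer_id printed in your server logs
    r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
    r.raise_for_status()
    response = r.json()
    print(response["success"])      # True if the server is dialable from the Internet
    if not response["success"]:
        print(response["message"])  # human-readable explanation from the health server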
