Support libp2p relays for NAT traversal (#186)

- Added relay options to servers
- Enabled relay options by default
- Changed hivemind version to 1.1.5
- Moved the reachability check so that it runs after blocks are loaded

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Egiazarian Vage, committed via GitHub · commit 93bed7da5a (parent 16b69d6050)

@@ -37,7 +37,7 @@ install_requires =
     huggingface-hub==0.11.1
     transformers==4.25.1
     speedtest-cli==2.1.3
-    hivemind==1.1.3
+    hivemind==1.1.5
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
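
The relay flags below rely on this bumped dependency. A quick sanity check that an environment actually picked up the new pin (a sketch, assuming hivemind exposes `__version__` as usual):

    import hivemind

    # setup.cfg pins hivemind exactly, so an equality check is appropriate here
    assert hivemind.__version__ == "1.1.5", f"expected hivemind 1.1.5, found {hivemind.__version__}"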

@@ -38,6 +38,9 @@ def main():
                              'This is a simplified way to set the --announce_maddrs option (see below).'
                              'Default: server announces IPv4/IPv6 addresses of your network interfaces')
+
+    parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay",
+                        help="Do not look for libp2p relays to reach peers behind NATs/firewalls")
     parser.add_argument('--host_maddrs', nargs='+', required=False,
                         help='Multiaddrs to listen for external connections from other peers')
     parser.add_argument('--announce_maddrs', nargs='+', required=False,
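
For clarity, here is the flag's semantics in isolation: `action="store_false"` with `dest="use_auto_relay"` means auto relay defaults to on, and passing the flag turns it off. A minimal standalone sketch (not Petals code):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay",
                        help="Do not look for libp2p relays to reach peers behind NATs/firewalls")

    print(parser.parse_args([]).use_auto_relay)                   # True: relays are used by default
    print(parser.parse_args(["--no_auto_relay"]).use_auto_relay)  # False: opted out explicitly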

@@ -107,6 +107,8 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
                 num_workers=n_layer,
                 startup_timeout=config.daemon_startup_timeout,
                 start=True,
+                use_relay=True,
+                use_auto_relay=True,
             )
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
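
On the client side, the two new kwargs are simply forwarded by hivemind.DHT to its embedded libp2p daemon. A hedged sketch of an equivalent standalone client node, assuming PUBLIC_INITIAL_PEERS lives in petals.constants (it stands in for config.initial_peers here):

    import hivemind
    from petals.constants import PUBLIC_INITIAL_PEERS  # assumed location of the bootstrap list

    dht = hivemind.DHT(
        initial_peers=PUBLIC_INITIAL_PEERS,
        client_mode=True,      # clients do not accept inbound connections themselves
        use_relay=True,        # allow traffic to flow through relay nodes
        use_auto_relay=True,   # discover relays automatically when behind a NAT/firewall
        start=True,
    )
    assert isinstance(dht, hivemind.DHT) and dht.is_alive()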

@@ -0,0 +1,39 @@
+import math
+import time
+
+import requests
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__file__)
+
+
+def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
+    for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
+        try:
+            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
+            r.raise_for_status()
+            response = r.json()
+
+            if response["success"]:
+                logger.info("Server is reachable from the Internet. It will appear at http://health.petals.ml soon")
+                return
+
+            if attempt_no == 0:
+                # Usually, libp2p manages to set up relays before we finish loading blocks.
+                # In other cases, we may need to wait for up to `wait_time` seconds before it's done.
+                logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")
+            time.sleep(retry_delay)
+        except Exception as e:
+            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
+            return
+
+    raise RuntimeError(
+        f"Server has not become reachable from the Internet:\n\n"
+        f"{response['message']}\n\n"
+        f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
+        f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
+        f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
+        f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
+        f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
+        f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
+    )
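
With the default arguments, the loop above budgets floor(wait_time / retry_delay) + 1 attempts, i.e. 29 polls spread over roughly seven minutes:

    import math

    wait_time, retry_delay = 7 * 60, 15
    print(math.floor(wait_time / retry_delay) + 1)  # 29 attempts, one every 15 s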

@@ -10,7 +10,6 @@ from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
 import psutil
-import requests
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -28,6 +27,7 @@ from petals.server.backend import TransformerBackend
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
+from petals.server.reachability import check_reachability
 from petals.server.throughput import get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -78,6 +78,8 @@ class Server:
         load_in_8bit: Optional[bool] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        use_relay: bool = True,
+        use_auto_relay: bool = True,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -117,14 +119,20 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
-        self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
+        self.dht = DHT(
+            initial_peers=initial_peers,
+            start=True,
+            num_workers=self.block_config.n_layer,
+            use_relay=use_relay,
+            use_auto_relay=use_auto_relay,
+            **kwargs,
+        )
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
-            if not skip_reachability_check:
-                self._check_reachability()
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -196,35 +204,14 @@ class Server:
         self.stop = threading.Event()
 
-    def _check_reachability(self):
-        try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
-            r.raise_for_status()
-            response = r.json()
-        except Exception as e:
-            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
-            return
-
-        if not response["success"]:
-            # This happens only if health.petals.ml is up and explicitly told us that we are unreachable
-            raise RuntimeError(
-                f"Server is not reachable from the Internet:\n\n"
-                f"{response['message']}\n\n"
-                f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
-                f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
-                f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
-                f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
-                f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
-                f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
-            )
-
-        logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
-
     def _choose_num_blocks(self) -> int:
         assert (
             self.converted_model_name_or_path == "bigscience/bloom-petals"
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
-        assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+        assert self.device.type == "cuda", (
+            "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
+            "CPU-only servers in the public swarm are discouraged since they are much slower"
+        )
 
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
         if num_devices > 1:
@@ -287,6 +274,7 @@ class Server:
             use_auth_token=self.use_auth_token,
             load_in_8bit=self.load_in_8bit,
             tensor_parallel_devices=self.tensor_parallel_devices,
+            need_reachability_check=self.need_reachability_check,
             start=True,
         )
         try:
@@ -380,6 +368,7 @@ class ModuleContainer(threading.Thread):
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
+        need_reachability_check: bool,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -433,6 +422,9 @@ class ModuleContainer(threading.Thread):
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                 )
+
+            if need_reachability_check:
+                check_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():
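
The net effect of the last two hunks: the server joins the DHT first, spends minutes loading blocks (during which libp2p negotiates relays in the background), and only then verifies reachability. A hedged sketch of calling the helper directly, using an illustrative throwaway node (the real code reuses the server's DHT and its peer_id):

    import hivemind
    from petals.server.reachability import check_reachability

    dht = hivemind.DHT(start=True)   # illustrative local node with a fresh peer_id
    check_reachability(dht.peer_id)  # returns quietly if reachable; raises RuntimeError with fix-it steps otherwise
    dht.shutdown()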
