Add service checking direct reachability from peers (#195)

Servers joining from behind NATs/firewalls usually take several minutes to connect to a libp2p relay before they become accessible from the outside Internet. Moreover, requests to such servers are slower and more likely to fail (e.g., if the server switches relays at that moment). If such servers host certain DHT keys, the swarm may occasionally lose read/write access to these keys, which results in:

- Clients being unable to find any servers hosting a certain block.
- All servers starting rebalancing to the same place to close the alleged "gap" in the swarm.

This PR modifies servers so that DHT keys are only hosted on **directly reachable** servers (the ones that aren't behind a NAT/firewall). This way, the DHT becomes more stable and works faster. Of course, the servers behind NATs/firewalls still accept requests for running inference/forward/backward on the blocks they hold (it's more acceptable for this kind of request to be slower or to fail).
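To illustrate, here is a condensed, non-verbatim sketch of the logic this PR adds to `Server.__init__` (see the `server.py` hunk below); variable names mirror the ones used there:

```python
# Sketch of the DHT-mode selection added to petals.server.server.Server.__init__
from hivemind import DHT

from petals.constants import PUBLIC_INITIAL_PEERS
from petals.server.reachability import ReachabilityProtocol, check_direct_reachability

initial_peers = PUBLIC_INITIAL_PEERS

# Ask a few existing peers whether they can open a direct connection back to us.
# Returns True/False, or None if no peer could be reached to perform the check.
is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False)
dht_client_mode = is_reachable is False  # inconclusive (None) -> still run a full peer

dht = DHT(
    initial_peers=initial_peers,
    start=True,
    use_relay=True,
    use_auto_relay=True,
    client_mode=dht_client_mode,  # client-mode peers do not host DHT keys
)

# Only full (directly reachable) peers serve the reachability-check protocol for others
reachability_protocol = ReachabilityProtocol.attach_to_dht(dht) if not dht_client_mode else None
```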

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic 1 year ago committed by GitHub
parent 5f58f00649
commit 771ca590e7

@@ -0,0 +1,104 @@
"""
A copy of run_dht.py from hivemind with the ReachabilityProtocol added:
https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py
This script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm.
This may be eventually merged to the hivemind upstream.
"""
import time
from argparse import ArgumentParser
from secrets import token_hex

from hivemind.dht import DHT, DHTNode
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs

from petals.server.reachability import ReachabilityProtocol

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


async def report_status(dht: DHT, node: DHTNode):
    logger.info(
        f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) "
        f"are in the local routing table "
    )
    logger.debug(f"Routing table contents: {node.protocol.routing_table}")
    logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
    logger.debug(f"Local storage contents: {node.protocol.storage}")

    # Contact peers and keep the routing table healthy (remove stale PeerIDs)
    await node.get(f"heartbeat_{token_hex(16)}", latest=True)


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--initial_peers",
        nargs="*",
        help="Multiaddrs of the peers that will welcome you into the existing DHT. "
        "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
    )
    parser.add_argument(
        "--host_maddrs",
        nargs="*",
        default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"],
        help="Multiaddrs to listen for external connections from other DHT instances. "
        "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0",
    )
    parser.add_argument(
        "--announce_maddrs",
        nargs="*",
        help="Visible multiaddrs the host announces for external connections from other DHT instances",
    )
    parser.add_argument(
        "--use_ipfs",
        action="store_true",
        help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
        "part of the multiaddrs for the initial_peers "
        "(no need to specify a particular IPv4/IPv6 host and port)",
    )
    parser.add_argument(
        "--identity_path",
        help="Path to a private key file. If defined, makes the peer ID deterministic. "
        "If the file does not exist, writes a new private key to this file.",
    )
    parser.add_argument(
        "--no_relay",
        action="store_false",
        dest="use_relay",
        help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
    )
    parser.add_argument(
        "--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls"
    )
    parser.add_argument(
        "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
    )
    args = parser.parse_args()

    dht = DHT(
        start=True,
        initial_peers=args.initial_peers,
        host_maddrs=args.host_maddrs,
        announce_maddrs=args.announce_maddrs,
        use_ipfs=args.use_ipfs,
        identity_path=args.identity_path,
        use_relay=args.use_relay,
        use_auto_relay=args.use_auto_relay,
    )
    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)

    reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)

    while True:
        dht.run_coroutine(report_status, return_future=False)
        time.sleep(args.refresh_period)


if __name__ == "__main__":
    main()
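Once a bootstrap node launched with this script prints its visible multiaddrs, other peers can join the swarm through it. A hedged usage sketch (the multiaddr below is a placeholder in the same format as the script's `--initial_peers` help text, not a real peer):

```python
# Hypothetical example of joining a swarm through the bootstrap node above.
from hivemind import DHT

# Replace with the multiaddr(s) printed by log_visible_maddrs() at bootstrap startup
bootstrap_maddrs = ["/ip4/203.0.113.1/tcp/31337/p2p/XXXX"]

dht = DHT(initial_peers=bootstrap_maddrs, client_mode=True, start=True)
print("Joined the DHT, peer id:", dht.peer_id)
dht.shutdown()
```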

@@ -4,3 +4,6 @@ PUBLIC_INITIAL_PEERS = [
"/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
"/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
]
# The reachability API is currently used only when connecting to the public swarm
REACHABILITY_API_URL = "http://health.petals.ml"
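For reference, a minimal sketch of querying this API (the same endpoint that `validate_reachability()` in `petals/server/reachability.py` below uses; nothing about the JSON response structure beyond that code's usage is assumed):

```python
import requests

from petals.constants import REACHABILITY_API_URL


def query_reachability(peer_id) -> dict:
    # Same endpoint as used by validate_reachability() in petals/server/reachability.py
    r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
    r.raise_for_status()
    return r.json()
```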

@@ -1,16 +1,30 @@
import asyncio
import math
import threading
import time
from concurrent.futures import Future
from contextlib import asynccontextmanager
from functools import partial
from secrets import token_hex
from typing import Optional
import requests
from hivemind.utils.logging import get_logger
from hivemind.dht import DHT, DHTNode
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
from hivemind.proto import dht_pb2
from hivemind.utils import get_logger
logger = get_logger(__file__)
from petals.constants import REACHABILITY_API_URL
logger = get_logger(__name__)
def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
"""verify that your peer is reachable from a (centralized) validator, whether directly or through a relay"""
for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
try:
r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
r.raise_for_status()
response = r.json()
@@ -37,3 +51,116 @@ def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float =
f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
)
def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:
"""test if your peer is accessible by others in the swarm with the specified network options in **kwargs"""
async def _check_direct_reachability():
target_dht = await DHTNode.create(client_mode=True, **kwargs)
try:
protocol = ReachabilityProtocol(probe=target_dht.protocol.p2p)
async with protocol.serve(target_dht.protocol.p2p):
successes = requests = 0
for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()):
probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id)
if probe_available is None:
continue # remote peer failed to check probe
successes += probe_available
requests += 1
if requests >= max_peers:
break
logger.info(f"Direct reachability: {successes}/{requests}")
return (successes / requests) >= threshold if requests > 0 else None
finally:
await target_dht.shutdown()
return RemoteExpertWorker.run_coroutine(_check_direct_reachability())
STRIPPED_PROBE_ARGS = dict(
dht_mode="client", use_relay=False, auto_nat=False, nat_port_map=False, no_listen=True, startup_timeout=60
)
class ReachabilityProtocol(ServicerBase):
"""Mini protocol to test if a locally running peer is accessible by other devices in the swarm"""
def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):
self.probe = probe
self.wait_timeout = wait_timeout
self._event_loop = self._stop = None
async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:
"""Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond"""
try:
request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))
timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2
response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)
logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}")
return response.available
except Exception as e:
logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True)
return None
async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
"""Help another peer to check its reachability"""
response = dht_pb2.PingResponse(available=True)
check_peer = PeerID(request.peer.node_id)
if check_peer != context.local_id: # remote peer wants us to check someone other than ourselves
response.available = await self.call_check(check_peer, check_peer=check_peer) is True
logger.info(
f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, "
f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}"
)
return response
@asynccontextmanager
async def serve(self, p2p: P2P):
try:
await self.add_p2p_handlers(p2p)
yield self
finally:
await self.remove_p2p_handlers(p2p)
@classmethod
def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]:
protocol = cls(**kwargs)
ready = Future()
async def _serve_with_probe():
try:
common_p2p = await dht.replicate_p2p()
protocol._event_loop = asyncio.get_event_loop()
protocol._stop = asyncio.Event()
initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)]
for info in await common_p2p.list_peers():
initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs)
protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
ready.set_result(True)
logger.info("Reachability service started")
async with protocol.serve(common_p2p):
await protocol._stop.wait()
except Exception as e:
logger.warning(f"Reachability service failed: {repr(e)}")
logger.debug("See detailed traceback below:", exc_info=True)
if not ready.done():
ready.set_exception(e)
finally:
if protocol is not None and protocol.probe is not None:
await protocol.probe.shutdown()
logger.debug("Reachability service shut down")
threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()
if await_ready:
ready.result() # Propagates startup exceptions, if any
return protocol
def shutdown(self):
if self._event_loop is not None and self._stop is not None:
self._event_loop.call_soon_threadsafe(self._stop.set)

@@ -26,7 +26,7 @@ from petals.server.backend import TransformerBackend
from petals.server.block_utils import get_block_size
from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
from petals.server.reachability import check_reachability
from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
from petals.server.throughput import get_dtype_name, get_host_throughput
from petals.utils.convert_block import check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -77,6 +77,7 @@ class Server:
load_in_8bit: Optional[bool] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
skip_reachability_check: bool = False,
dht_client_mode: Optional[bool] = None,
use_relay: bool = True,
use_auto_relay: bool = True,
**kwargs,
@@ -118,20 +119,27 @@ class Server:
)
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
if dht_client_mode is None:
is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
dht_client_mode = is_reachable is False # if could not check reachability (returns None), run a full peer
logger.info(f"This server will run DHT in {'client' if dht_client_mode else 'full peer'} mode")
self.dht = DHT(
initial_peers=initial_peers,
start=True,
num_workers=self.block_config.n_layer,
use_relay=use_relay,
use_auto_relay=use_auto_relay,
client_mode=dht_client_mode,
**kwargs,
)
self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
else:
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -277,7 +285,7 @@ class Server:
use_auth_token=self.use_auth_token,
load_in_8bit=self.load_in_8bit,
tensor_parallel_devices=self.tensor_parallel_devices,
need_reachability_check=self.need_reachability_check,
should_validate_reachability=self.should_validate_reachability,
start=True,
)
try:
@@ -335,6 +343,8 @@ class Server:
def shutdown(self):
self.stop.set()
if self.reachability_protocol is not None:
self.reachability_protocol.shutdown()
self.dht.shutdown()
self.dht.join()
@@ -367,7 +377,7 @@ class ModuleContainer(threading.Thread):
use_auth_token: Optional[str],
load_in_8bit: bool,
tensor_parallel_devices: Sequence[torch.device],
need_reachability_check: bool,
should_validate_reachability: bool,
**kwargs,
) -> ModuleContainer:
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -422,8 +432,8 @@ class ModuleContainer(threading.Thread):
max_batch_size=max_batch_size,
)
if need_reachability_check:
check_reachability(dht.peer_id)
if should_validate_reachability:
validate_reachability(dht.peer_id)
except:
logger.debug("Shutting down backends")
for backend in blocks.values():
