diff --git a/src/petals/cli/run_dht.py b/src/petals/cli/run_dht.py
new file mode 100644
index 0000000..2f30516
--- /dev/null
+++ b/src/petals/cli/run_dht.py
@@ -0,0 +1,104 @@
+"""
+A copy of run_dht.py from hivemind with the ReachabilityProtocol added:
+https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py
+
+This script may be used for launching lightweight CPU machines that serve as bootstrap nodes for a Petals swarm.
+
+This script may eventually be merged into the hivemind upstream.
+"""
+
+import time
+from argparse import ArgumentParser
+from secrets import token_hex
+
+from hivemind.dht import DHT, DHTNode
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import log_visible_maddrs
+
+from petals.server.reachability import ReachabilityProtocol
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+async def report_status(dht: DHT, node: DHTNode):
+    logger.info(
+        f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) "
+        f"are in the local routing table"
+    )
+    logger.debug(f"Routing table contents: {node.protocol.routing_table}")
+    logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
+    logger.debug(f"Local storage contents: {node.protocol.storage}")
+
+    # Contact peers and keep the routing table healthy (remove stale PeerIDs)
+    await node.get(f"heartbeat_{token_hex(16)}", latest=True)
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--initial_peers",
+        nargs="*",
+        help="Multiaddrs of the peers that will welcome you into the existing DHT. "
+        "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
+    )
+    parser.add_argument(
+        "--host_maddrs",
+        nargs="*",
+        default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"],
+        help="Multiaddrs to listen for external connections from other DHT instances. "
+        "Defaults to all IPv4/IPv6 interfaces over TCP: /ip4/0.0.0.0/tcp/0 /ip6/::/tcp/0",
+    )
+    parser.add_argument(
+        "--announce_maddrs",
+        nargs="*",
+        help="Visible multiaddrs the host announces for external connections from other DHT instances",
+    )
+    parser.add_argument(
+        "--use_ipfs",
+        action="store_true",
+        help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
+        "part of the multiaddrs for the initial_peers "
+        "(no need to specify a particular IPv4/IPv6 host and port)",
+    )
+    parser.add_argument(
+        "--identity_path",
+        help="Path to a private key file. If defined, makes the peer ID deterministic. "
+        "If the file does not exist, writes a new private key to this file.",
+    )
+    parser.add_argument(
+        "--no_relay",
+        action="store_false",
+        dest="use_relay",
+        help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
+    )
+    parser.add_argument(
+        "--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls"
+    )
+    parser.add_argument(
+        "--refresh_period", type=int, default=30, help="Period (in seconds) between DHT status reports"
+    )
+
+    args = parser.parse_args()
+
+    dht = DHT(
+        start=True,
+        initial_peers=args.initial_peers,
+        host_maddrs=args.host_maddrs,
+        announce_maddrs=args.announce_maddrs,
+        use_ipfs=args.use_ipfs,
+        identity_path=args.identity_path,
+        use_relay=args.use_relay,
+        use_auto_relay=args.use_auto_relay,
+    )
+    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)
+
+    reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)
+
+    while True:
+        dht.run_coroutine(report_status, return_future=False)
+        time.sleep(args.refresh_period)
+
+
+if __name__ == "__main__":
+    main()
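
A bootstrap node started by this script keeps the swarm discoverable even when no GPU servers are online. A minimal sketch of joining a swarm through such a node (the multiaddr is a placeholder for the one that log_visible_maddrs prints above):

    # Join an existing swarm through a bootstrap node started by run_dht.py.
    # The multiaddr below is a placeholder; substitute the address printed by the script.
    from hivemind.dht import DHT

    dht = DHT(
        initial_peers=["/ip4/203.0.113.5/tcp/31337/p2p/QmBootstrapPeerPlaceholder"],
        client_mode=True,  # join as a lightweight client that stores no DHT keys itself
        start=True,
    )
    print("Joined the DHT, peer_id =", dht.peer_id)
    dht.shutdown()
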
" + "If the file does not exist, writes a new private key to this file.", + ) + parser.add_argument( + "--no_relay", + action="store_false", + dest="use_relay", + help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)", + ) + parser.add_argument( + "--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls" + ) + parser.add_argument( + "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT" + ) + + args = parser.parse_args() + + dht = DHT( + start=True, + initial_peers=args.initial_peers, + host_maddrs=args.host_maddrs, + announce_maddrs=args.announce_maddrs, + use_ipfs=args.use_ipfs, + identity_path=args.identity_path, + use_relay=args.use_relay, + use_auto_relay=args.use_auto_relay, + ) + log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs) + + reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True) + + while True: + dht.run_coroutine(report_status, return_future=False) + time.sleep(args.refresh_period) + + +if __name__ == "__main__": + main() diff --git a/src/petals/constants.py b/src/petals/constants.py index a4620c3..da047f1 100644 --- a/src/petals/constants.py +++ b/src/petals/constants.py @@ -4,3 +4,6 @@ PUBLIC_INITIAL_PEERS = [ "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", ] + +# The reachability API is currently used only when connecting to the public swarm +REACHABILITY_API_URL = "http://health.petals.ml" diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py index d8b5fba..7ead055 100644 --- a/src/petals/server/reachability.py +++ b/src/petals/server/reachability.py @@ -1,16 +1,30 @@ +import asyncio import math +import threading import time +from concurrent.futures import Future +from contextlib import asynccontextmanager +from functools import partial +from secrets import token_hex +from typing import Optional import requests -from hivemind.utils.logging import get_logger +from hivemind.dht import DHT, DHTNode +from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker +from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase +from hivemind.proto import dht_pb2 +from hivemind.utils import get_logger -logger = get_logger(__file__) +from petals.constants import REACHABILITY_API_URL +logger = get_logger(__name__) -def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None: + +def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None: + """verify that your peer is reachable from a (centralized) validator, whether directly or through a relay""" for attempt_no in range(math.floor(wait_time / retry_delay) + 1): try: - r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10) + r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10) r.raise_for_status() response = r.json() @@ -37,3 +51,116 @@ def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n" f" 4. 
diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py
index d8b5fba..7ead055 100644
--- a/src/petals/server/reachability.py
+++ b/src/petals/server/reachability.py
@@ -1,16 +1,30 @@
+import asyncio
 import math
+import threading
 import time
+from concurrent.futures import Future
+from contextlib import asynccontextmanager
+from functools import partial
+from secrets import token_hex
+from typing import Optional
 
 import requests
-from hivemind.utils.logging import get_logger
+from hivemind.dht import DHT, DHTNode
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
+from hivemind.proto import dht_pb2
+from hivemind.utils import get_logger
 
-logger = get_logger(__file__)
+from petals.constants import REACHABILITY_API_URL
 
+logger = get_logger(__name__)
 
-def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
+
+def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
+    """Verify that your peer is reachable from a (centralized) validator, either directly or through a relay"""
     for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
         try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
+            r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
             r.raise_for_status()
             response = r.json()
@@ -37,3 +51,116 @@ def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float =
             f"    python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
             f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
         )
+
+
+def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:
+    """Test whether your peer is accessible by other peers in the swarm, given the network options in **kwargs"""
+
+    async def _check_direct_reachability():
+        target_dht = await DHTNode.create(client_mode=True, **kwargs)
+        try:
+            protocol = ReachabilityProtocol(probe=target_dht.protocol.p2p)
+            async with protocol.serve(target_dht.protocol.p2p):
+                successes = requests = 0
+                for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()):
+                    probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id)
+                    if probe_available is None:
+                        continue  # remote peer failed to check the probe
+                    successes += probe_available
+                    requests += 1
+                    if requests >= max_peers:
+                        break
+
+            logger.info(f"Direct reachability: {successes}/{requests}")
+            return (successes / requests) >= threshold if requests > 0 else None
+        finally:
+            await target_dht.shutdown()
+
+    return RemoteExpertWorker.run_coroutine(_check_direct_reachability())
+
+
+STRIPPED_PROBE_ARGS = dict(
+    dht_mode="client", use_relay=False, auto_nat=False, nat_port_map=False, no_listen=True, startup_timeout=60
+)
+
+
+class ReachabilityProtocol(ServicerBase):
+    """Mini-protocol to test whether a locally running peer is accessible by other devices in the swarm"""
+
+    def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):
+        self.probe = probe
+        self.wait_timeout = wait_timeout
+        self._event_loop = self._stop = None
+
+    async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:
+        """Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond"""
+        try:
+            request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))
+            timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2
+            response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)
+            logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}")
+            return response.available
+        except Exception:
+            logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True)
+            return None
+
+    async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
+        """Help another peer check its reachability"""
+        response = dht_pb2.PingResponse(available=True)
+        check_peer = PeerID(request.peer.node_id)
+        if check_peer != context.local_id:  # remote peer wants us to check someone other than ourselves
+            response.available = await self.call_check(check_peer, check_peer=check_peer) is True
+        logger.info(
+            f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, "
+            f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}"
+        )
+        return response
+
+    @asynccontextmanager
+    async def serve(self, p2p: P2P):
+        try:
+            await self.add_p2p_handlers(p2p)
+            yield self
+        finally:
+            await self.remove_p2p_handlers(p2p)
+
+    @classmethod
+    def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]:
+        protocol = cls(**kwargs)
+        ready = Future()
+
+        async def _serve_with_probe():
+            try:
+                common_p2p = await dht.replicate_p2p()
+                protocol._event_loop = asyncio.get_event_loop()
+                protocol._stop = asyncio.Event()
+
+                initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)]
+                for info in await common_p2p.list_peers():
+                    initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs)
+                protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
+
+                ready.set_result(True)
+                logger.info("Reachability service started")
+
+                async with protocol.serve(common_p2p):
+                    await protocol._stop.wait()
+            except Exception as e:
+                logger.warning(f"Reachability service failed: {repr(e)}")
+                logger.debug("See detailed traceback below:", exc_info=True)
+
+                if not ready.done():
+                    ready.set_exception(e)
+            finally:
+                if protocol is not None and protocol.probe is not None:
+                    await protocol.probe.shutdown()
+                logger.debug("Reachability service shut down")
+
+        threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()
+        if await_ready:
+            ready.result()  # Propagates startup exceptions, if any
+        return protocol
+
+    def shutdown(self):
+        if self._event_loop is not None and self._stop is not None:
+            self._event_loop.call_soon_threadsafe(self._stop.set)
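
check_direct_reachability() returns a tri-state result: True (enough peers dialed the probe directly), False (they could not), or None (no peer responded, so reachability is unknown). A usage sketch mirroring the decision Server.__init__ makes below:

    # Decide the DHT mode the same way Server.__init__ does below (a sketch).
    from petals.constants import PUBLIC_INITIAL_PEERS
    from petals.server.reachability import check_direct_reachability

    is_reachable = check_direct_reachability(initial_peers=PUBLIC_INITIAL_PEERS, use_relay=False)
    dht_client_mode = is_reachable is False  # unknown (None) falls back to a full peer
    print("DHT mode:", "client" if dht_client_mode else "full peer")
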
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 57d743e..7e76080 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -26,7 +26,7 @@ from petals.server.backend import TransformerBackend
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
-from petals.server.reachability import check_reachability
+from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -77,6 +77,7 @@ class Server:
         load_in_8bit: Optional[bool] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
         **kwargs,
@@ -118,20 +119,27 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
+        if dht_client_mode is None:
+            is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
+            dht_client_mode = is_reachable is False  # if reachability is unknown (None), run a full peer
+            logger.info(f"This server will run DHT in {'client' if dht_client_mode else 'full peer'} mode")
         self.dht = DHT(
             initial_peers=initial_peers,
             start=True,
             num_workers=self.block_config.n_layer,
             use_relay=use_relay,
             use_auto_relay=use_auto_relay,
+            client_mode=dht_client_mode,
             **kwargs,
         )
+        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
+
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
-        self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
+        self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -277,7 +285,7 @@
             use_auth_token=self.use_auth_token,
             load_in_8bit=self.load_in_8bit,
             tensor_parallel_devices=self.tensor_parallel_devices,
-            need_reachability_check=self.need_reachability_check,
+            should_validate_reachability=self.should_validate_reachability,
             start=True,
         )
         try:
@@ -335,6 +343,8 @@ class Server:
 
     def shutdown(self):
         self.stop.set()
+        if self.reachability_protocol is not None:
+            self.reachability_protocol.shutdown()
 
         self.dht.shutdown()
         self.dht.join()
@@ -367,7 +377,7 @@ class ModuleContainer(threading.Thread):
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
-        need_reachability_check: bool,
+        should_validate_reachability: bool,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -422,8 +432,8 @@ class ModuleContainer(threading.Thread):
                 max_batch_size=max_batch_size,
             )
 
-            if need_reachability_check:
-                check_reachability(dht.peer_id)
+            if should_validate_reachability:
+                validate_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():
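
Taken together, the lifecycle these changes introduce is: attach the protocol to a running DHT, serve rpc_check from a background thread, and signal shutdown through the stored event loop. A minimal sketch of that lifecycle (a standalone DHT with no initial_peers, which is enough to illustrate the API):

    # Sketch: the attach/shutdown lifecycle that both run_dht.py and Server now follow.
    from hivemind.dht import DHT

    from petals.server.reachability import ReachabilityProtocol

    dht = DHT(start=True)
    protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)  # serves rpc_check in a daemon thread
    try:
        ...  # do useful work while other peers probe us via rpc_check
    finally:
        protocol.shutdown()  # signals the background event loop to stop serving
        dht.shutdown()
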