You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
165 lines
7.5 KiB
Python
165 lines
7.5 KiB
Python
import asyncio
|
|
import math
|
|
import threading
|
|
import time
|
|
from concurrent.futures import Future
|
|
from contextlib import asynccontextmanager
|
|
from functools import partial
|
|
from typing import Optional
|
|
|
|
import requests
|
|
from hivemind.dht import DHT, DHTNode
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
|
|
from hivemind.proto import dht_pb2
|
|
from hivemind.utils import get_logger
|
|
|
|
from petals.constants import REACHABILITY_API_URL
|
|
|
|
# Module-level logger used by the reachability checks defined in this file
logger = get_logger(__name__)
|
|
|
|
|
|
def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
    """
    Verify that your peer is reachable from a (centralized) validator, whether directly or through a relay.

    Repeatedly queries ``{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}`` until the validator
    reports success or the attempts are exhausted.

    :param peer_id: identifier of the local peer; it is interpolated into the validator URL
    :param wait_time: approximate total time (in seconds) to keep retrying before giving up
    :param retry_delay: pause (in seconds) between two consecutive attempts
    :raises RuntimeError: if the validator answered on every attempt but never reported success
    :note: if querying the validator raises any exception, the check is skipped with a warning (best effort)
    """
    # Always make at least one attempt, so that `response` is bound when we build the final error message
    num_attempts = max(1, math.floor(wait_time / retry_delay) + 1)

    for attempt_no in range(num_attempts):
        try:
            r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
            r.raise_for_status()
            response = r.json()

            if response["success"]:
                logger.info("Server is reachable from the Internet. It will appear at https://health.petals.dev soon")
                return

            if attempt_no == 0:
                # Usually, libp2p manages to set up relays before we finish loading blocks.
                # In other cases, we may need to wait for up to `wait_time` seconds before it's done.
                logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")

            # No point in sleeping after the last attempt: nothing will be retried afterwards
            if attempt_no + 1 < num_attempts:
                time.sleep(retry_delay)
        except Exception as e:
            # Deliberately broad: a broken validator must not prevent the server from starting
            logger.warning(f"Skipping reachability check because health.petals.dev is down: {repr(e)}")
            return

    raise RuntimeError(
        f"Server has not become reachable from the Internet:\n\n"
        f"{response['message']}\n\n"
        f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
        f" 1. Choose a specific port for the Petals server, for example, 31337.\n"
        f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
        f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
        f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
        f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
    )
|
|
|
|
|
|
def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:
    """
    Test if your peer is accessible by others in the swarm with the specified network options in **kwargs.

    A temporary client-mode DHT node is created with ``**kwargs``; peers from its routing table are asked
    (via ReachabilityProtocol) whether they can reach that node, until ``max_peers`` of them have answered.

    :returns: True if the share of positive answers is at least ``threshold``, False if it is lower,
      None if no peer gave an answer
    """

    async def _probe_swarm():
        node = await DHTNode.create(client_mode=True, **kwargs)
        try:
            checker = ReachabilityProtocol(probe=node.protocol.p2p)
            async with checker.serve(node.protocol.p2p):
                num_reachable = num_answered = 0
                # Iterate over a snapshot of the routing table keys taken before the first await
                candidates = list(node.protocol.routing_table.peer_id_to_uid.keys())
                for candidate in candidates:
                    verdict = await checker.call_check(remote_peer=candidate, check_peer=node.peer_id)
                    if verdict is not None:  # None means that the remote peer failed to check the probe
                        num_reachable += verdict
                        num_answered += 1
                        if num_answered >= max_peers:
                            break

            logger.debug(f"Direct reachability: {num_reachable}/{num_answered}")
            if num_answered > 0:
                return (num_reachable / num_answered) >= threshold
            return None
        finally:
            await node.shutdown()

    return RemoteExpertWorker.run_coroutine(_probe_swarm())
|
|
|
|
|
|
# Keyword arguments for the auxiliary "probe" P2P instance created in ReachabilityProtocol.attach_to_dht:
# relays, AutoNAT, NAT port mapping and listening are all switched off.
STRIPPED_PROBE_ARGS = {
    "dht_mode": "client",
    "use_relay": False,
    "auto_nat": False,
    "nat_port_map": False,
    "no_listen": True,
    "startup_timeout": 60,
}
|
|
|
|
|
|
class ReachabilityProtocol(ServicerBase):
    """Mini protocol to test if a locally running peer is accessible by other devices in the swarm"""

    def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):
        """
        :param probe: P2P instance used to send outgoing rpc_check requests (see call_check)
        :param wait_timeout: base timeout (in seconds) for a single rpc_check request
        """
        self.probe = probe
        self.wait_timeout = wait_timeout
        # Both are set by the background task started in attach_to_dht; they stay None otherwise
        self._event_loop = self._stop = None

    async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:
        """Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond"""
        try:
            request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))
            # A peer that checks a third party issues a nested request of its own (see rpc_check),
            # so it gets twice the base timeout; a peer that checks itself answers right away
            timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2
            response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)
            logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}")
            return response.available
        except Exception:
            # Any failure counts as "no answer"; the traceback is kept in the debug log
            logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True)
            return None

    async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
        """
        Help another peer to check its reachability

        If the request names this very peer, the fact that it arrived is the answer (available=True).
        Otherwise, try to reach the named peer ourselves and report whether that succeeded.
        """
        response = dht_pb2.PingResponse(available=True)
        check_peer = PeerID(request.peer.node_id)
        if check_peer != context.local_id:  # remote peer wants us to check someone other than ourselves
            # `is True` maps both False and None (no response) to "not available"
            response.available = await self.call_check(check_peer, check_peer=check_peer) is True
            logger.info(
                f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, "
                f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}"
            )
        return response

    @asynccontextmanager
    async def serve(self, p2p: P2P):
        """Register this protocol's handlers on p2p for the duration of the `async with` block"""
        try:
            await self.add_p2p_handlers(p2p)
            yield self
        finally:
            await self.remove_p2p_handlers(p2p)

    @classmethod
    def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]:
        """
        Create a protocol instance and serve it from a background daemon thread with its own asyncio loop.

        The background task replicates the P2P instance of `dht`, creates a separate probe P2P instance
        (configured with STRIPPED_PROBE_ARGS) for outgoing checks, and keeps the handlers registered
        on the replicated P2P until shutdown() is called.

        :param dht: DHT whose P2P instance will serve incoming rpc_check requests
        :param await_ready: if True, block until the probe is created and re-raise startup exceptions, if any
        :param kwargs: forwarded to the ReachabilityProtocol constructor
        :returns: the created protocol instance
        """
        protocol = cls(**kwargs)
        ready = Future()

        async def _serve_with_probe():
            try:
                common_p2p = await dht.replicate_p2p()
                # We are inside a coroutine, so the running loop is the one that shutdown() must signal
                protocol._event_loop = asyncio.get_running_loop()
                protocol._stop = asyncio.Event()

                # The probe starts from our own visible addresses plus the addresses of all known peers
                initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)]
                for info in await common_p2p.list_peers():
                    initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs)
                protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)

                ready.set_result(True)
                logger.debug("Reachability service started")

                async with protocol.serve(common_p2p):
                    await protocol._stop.wait()
            except Exception as e:
                logger.debug("Reachability service failed:", exc_info=True)

                # Only startup failures are reported through `ready`; later ones are just logged
                if not ready.done():
                    ready.set_exception(e)
            finally:
                if protocol.probe is not None:
                    await protocol.probe.shutdown()
                logger.debug("Reachability service shut down")

        threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()
        if await_ready:
            ready.result()  # Propagates startup exceptions, if any
        return protocol

    def shutdown(self):
        """Ask the background task started by attach_to_dht to stop; does nothing if it has not set up its loop"""
        if self._event_loop is not None and self._stop is not None:
            # asyncio.Event is not thread-safe, so the event is set from within its own loop
            self._event_loop.call_soon_threadsafe(self._stop.set)
|