Make a server ping next servers (#356)

This PR makes a server ping potential next servers in a chain and report the measured RTTs to the DHT. These measurements will be used for shortest-path routing.
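The reported RTTs are published as part of each server's ServerInfo (see the new next_pings field below). As a minimal sketch of how a client-side router might consume them when choosing the next hop, assuming a hypothetical helper pick_next_server that is not part of this PR:

import math
from typing import Dict, List, Optional

def pick_next_server(next_pings: Optional[Dict[str, float]], candidates: List[str]) -> str:
    """Prefer the candidate next server with the lowest RTT reported by the current server."""
    if not next_pings:
        return candidates[0]  # nothing published yet, fall back to any candidate
    return min(candidates, key=lambda peer_id: next_pings.get(peer_id, math.inf))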
Alexander Borzunov 10 months ago committed by GitHub
parent 2c8959e713
commit 81c4a45ca2

@@ -98,7 +98,7 @@ def main():
'If set to "auto" (default), the script evaluates network and compute throughput '
'on the first run and uses these estimates for future runs. '
'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
parser.add_argument('--update_period', type=float, required=False, default=150,
parser.add_argument('--update_period', type=float, required=False, default=60,
help='Server will report blocks to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
help='DHT entries will expire after this many seconds')

@@ -30,6 +30,7 @@ class ServerInfo:
quant_type: Optional[str] = None
using_relay: Optional[bool] = None
cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None
def to_tuple(self) -> Tuple[int, float, dict]:
extra_info = dataclasses.asdict(self)

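For illustration, a published next_pings value maps a candidate next server's base58 peer ID to its smoothed RTT in seconds (the peer IDs and numbers below are made up):

next_pings = {
    "QmS1kPbYkDrMQsMVZ1Cu1HuYbuvRCGtSDTqhc8mKstKNpB": 0.042,  # ~42 ms to this candidate
    "QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN": 0.137,  # ~137 ms
}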
@@ -8,6 +8,7 @@ import threading
import time
from typing import Dict, List, Optional, Sequence, Union
import hivemind
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
@@ -30,6 +31,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.ping import PingAggregator
from petals.utils.version import get_compatible_model_repo
logger = get_logger(__name__)
@@ -64,7 +66,7 @@ class Server:
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
custom_module_path=None,
update_period: float = 150,
update_period: float = 60,
expiration: Optional[float] = None,
request_timeout: float = 3 * 60,
session_timeout: float = 30 * 60,
@@ -220,7 +222,7 @@ class Server:
throughput=throughput,
adapters=tuple(adapters),
version=petals.__version__,
torch_dtype=str(torch_dtype).lstrip("torch."),
torch_dtype=str(torch_dtype).replace("torch.", ""),
quant_type=quant_type.name.lower(),
using_relay=self.dht.client_mode,
)
@@ -413,8 +415,8 @@ class ModuleContainer(threading.Thread):
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
server_info.state = ServerState.JOINING
joining_announcer = ModuleAnnouncerThread(
assert server_info.state == ServerState.JOINING
dht_announcer = ModuleAnnouncerThread(
module_uids,
dht,
server_info,
@@ -424,7 +426,7 @@
expiration=expiration,
daemon=True,
)
joining_announcer.start()
dht_announcer.start()
logger.info(f"Announced that blocks {block_indices} are joining")
assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
@@ -476,6 +478,8 @@
max_batch_size=max_batch_size,
)
merge_inference_pools_inplace(blocks)
if should_validate_reachability:
validate_reachability(dht.peer_id)
except:
@@ -483,29 +487,15 @@
for backend in blocks.values():
backend.shutdown()
joining_announcer.stop.set()
joining_announcer.join()
server_info.state = ServerState.OFFLINE
declare_active_modules(
dht,
module_uids,
server_info,
expiration_time=get_dht_time() + expiration,
)
dht_announcer.announce(ServerState.OFFLINE)
logger.info(f"Announced that blocks {module_uids} are offline")
raise
else:
joining_announcer.stop.set()
joining_announcer.join()
merge_inference_pools_inplace(blocks)
return cls(
dht,
dht_prefix,
blocks,
block_config=block_config,
memory_cache=memory_cache,
dht_announcer=dht_announcer,
server_info=server_info,
update_period=update_period,
expiration=expiration,
@@ -518,10 +508,9 @@ class ModuleContainer(threading.Thread):
dht_prefix: str,
module_backends: Dict[str, TransformerBackend],
*,
block_config: PretrainedConfig,
memory_cache: MemoryCache,
inference_max_length: int,
num_handlers: int,
dht_announcer: ModuleAnnouncerThread,
server_info: ServerInfo,
update_period: float,
expiration: Optional[float] = None,
@@ -558,17 +547,8 @@ class ModuleContainer(threading.Thread):
self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
# note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
self.server_info.state = ServerState.ONLINE
self.online_announcer = ModuleAnnouncerThread(
list(self.module_backends.keys()),
dht,
self.server_info,
block_config=block_config,
memory_cache=memory_cache,
update_period=update_period,
expiration=expiration,
daemon=True,
)
dht_announcer.announce(ServerState.ONLINE)
self.dht_announcer = dht_announcer
if start:
self.run_in_background(await_ready=True)
@@ -578,11 +558,6 @@
Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
runs Runtime (self.runtime) to process incoming requests.
"""
if not self.dht.is_alive():
self.dht.run_in_background(await_ready=True)
self.online_announcer.start()
for handler in self.conn_handlers:
handler.run_in_background()
@@ -621,16 +596,7 @@
Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
"""
self.online_announcer.stop.set()
self.online_announcer.join()
self.server_info.state = ServerState.OFFLINE
declare_active_modules(
self.dht,
self.module_backends.keys(),
self.server_info,
expiration_time=get_dht_time() + self.expiration,
)
self.dht_announcer.announce(ServerState.OFFLINE)
logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
self.ready.clear()
@@ -666,8 +632,10 @@ class ModuleAnnouncerThread(threading.Thread):
*,
block_config: PretrainedConfig,
memory_cache: MemoryCache,
update_period: float = 30,
update_period: float,
expiration: float,
max_pinged: int = 5,
max_reported: int = 10,
**kwargs,
):
super().__init__(**kwargs)
@@ -678,20 +646,58 @@
self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
self.update_period = update_period
self.expiration = expiration
self.stop = threading.Event()
self.trigger = threading.Event()
self.max_pinged, self.max_reported = max_pinged, max_reported
last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1]))
dht_prefix, block_index = last_uid.split(UID_DELIMITER)
self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}"
self.ping_aggregator = PingAggregator(self.dht)
def run(self) -> None:
while True:
start_time = time.perf_counter()
self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
if self.server_info.state != ServerState.OFFLINE:
self._ping_next_servers()
self.server_info.next_pings = {
peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items()
}
else:
self.server_info.next_pings = None # No need to ping if we're disconnecting
declare_active_modules(
self.dht,
self.module_uids,
self.server_info,
expiration_time=get_dht_time() + self.expiration,
)
if self.stop.wait(self.update_period):
if self.server_info.state == ServerState.OFFLINE:
break
delay = self.update_period - (time.perf_counter() - start_time)
if delay < 0:
logger.warning("Declaring blocks to DHT takes more than --update_period, consider increasing it")
self.trigger.wait(max(delay, 0))
self.trigger.clear()
def announce(self, state: ServerState) -> None:
self.server_info.state = state
self.trigger.set()
if state == ServerState.OFFLINE:
self.join()
def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
[module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True)
if module_info is None:
return
next_servers = list(module_info.servers)
if len(next_servers) > self.max_pinged:
next_servers = random.sample(next_servers, self.max_pinged)
self.ping_aggregator.ping(next_servers)
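To make the ping targets concrete: the announcer takes the highest block index it serves, derives the UID of the block that comes right after it, and pings up to max_pinged randomly sampled servers holding that block. A small worked example, assuming UID_DELIMITER is "." and an illustrative "bloom" prefix:

module_uids = [f"bloom.{i}" for i in range(5, 13)]  # this server holds bloom.5 .. bloom.12
last_uid = max(module_uids, key=lambda uid: int(uid.split(".")[-1]))  # "bloom.12"
dht_prefix, block_index = last_uid.split(".")
next_uid = f"{dht_prefix}.{int(block_index) + 1}"  # "bloom.13": servers holding it get pinged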
class RuntimeWithDeduplicatedPools(Runtime):
"""A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""

@@ -0,0 +1,60 @@
import asyncio
import math
import time
from functools import partial
from typing import Dict, Sequence
import hivemind
from hivemind.proto import dht_pb2
from hivemind.utils.logging import get_logger
logger = get_logger(__name__)
async def ping(
peer_id: hivemind.PeerID,
_dht: hivemind.DHT,
node: hivemind.dht.DHTNode,
*,
wait_timeout: float = 1,
) -> float:
try:
ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)
start_time = time.perf_counter()
await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)
return time.perf_counter() - start_time
except Exception:
logger.debug(f"Failed to ping {peer_id}:", exc_info=True)
return math.inf
async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]:
rpc_infos = await asyncio.gather(*[ping(peer_id, *args, **kwargs) for peer_id in peer_ids])
return dict(zip(peer_ids, rpc_infos))
class PingAggregator:
def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 3600):
self.dht = dht
self.ema_alpha = ema_alpha
self.expiration = expiration
self.ping_emas = hivemind.TimedStorage()
def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs):
current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
logger.debug(f"Current RTTs: {current_rtts}")
expiration = hivemind.get_dht_time() + self.expiration
for peer_id, rtt in current_rtts.items():
prev_rtt = self.ping_emas.get(peer_id)
if prev_rtt is not None and prev_rtt.value != math.inf:
rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing
self.ping_emas.store(peer_id, rtt, expiration)
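# Worked example of the smoothing above (illustrative numbers): with ema_alpha = 0.2,
# a previous EMA of 0.100 s and a new measurement of 0.050 s, the stored value becomes
# 0.2 * 0.050 + 0.8 * 0.100 = 0.090 s, so a single outlier ping only nudges the estimate.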
def fastest(self, n_peers: int) -> Dict[hivemind.PeerID, float]:
with self.ping_emas.freeze():
smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
logger.debug(f"Smothed RTTs: {smoothed_rtts}")
fastest_rtts = sorted(smoothed_rtts.items(), key=lambda item: item[1])[:n_peers]
return dict(fastest_rtts)
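A minimal usage sketch of PingAggregator, mirroring how ModuleAnnouncerThread uses it above (the list of candidate peer ID strings is a placeholder):

dht = hivemind.DHT(start=True)
aggregator = PingAggregator(dht)

candidates = [hivemind.PeerID.from_base58(pid) for pid in candidate_peer_id_strings]  # placeholder input
aggregator.ping(candidates)            # measure RTTs once and fold them into the per-peer EMAs
fastest_rtts = aggregator.fastest(10)  # {PeerID: smoothed RTT in seconds}, lowest first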