@@ -8,6 +8,7 @@ import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
 
+import hivemind
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -30,6 +31,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.ping import PingAggregator
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
@@ -64,7 +66,7 @@ class Server:
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
-        update_period: float = 150,
+        update_period: float = 60,
         expiration: Optional[float] = None,
         request_timeout: float = 3 * 60,
         session_timeout: float = 30 * 60,
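
Note on the new default: the announcer loop added later in this patch re-declares all hosted blocks once per update_period and stamps each DHT record with expiration_time = get_dht_time() + expiration. For the records to stay continuously visible, update_period therefore has to stay comfortably below expiration; lowering the default from 150 s to 60 s keeps routing info fresher at the cost of somewhat more frequent DHT writes. A minimal sanity check, assuming both values are already resolved to floats (names mirror the Server.__init__ arguments above; the guard itself is hypothetical, not part of the patch):

    assert update_period < expiration, "DHT records would expire between announcements"
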
@@ -220,7 +222,7 @@ class Server:
             throughput=throughput,
             adapters=tuple(adapters),
             version=petals.__version__,
-            torch_dtype=str(torch_dtype).lstrip("torch."),
+            torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
         )
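
Note on the torch_dtype fix above: str.lstrip does not strip a prefix string; it strips any leading characters that belong to the given character set, so "torch." is treated as the set {'t', 'o', 'r', 'c', 'h', '.'}. That happens to work for "torch.float32" but corrupts names that start with one of those characters; str.replace (or str.removeprefix on Python 3.9+) keeps the name intact. A quick illustration of the difference:

    # str(torch.complex64) == "torch.complex64"
    "torch.complex64".lstrip("torch.")        # -> "mplex64" (leading 'c' and 'o' are in the strip set)
    "torch.complex64".replace("torch.", "")   # -> "complex64"
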
@@ -413,8 +415,8 @@ class ModuleContainer(threading.Thread):
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
         memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
 
-        server_info.state = ServerState.JOINING
-        joining_announcer = ModuleAnnouncerThread(
+        assert server_info.state == ServerState.JOINING
+        dht_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
             server_info,
@@ -424,7 +426,7 @@ class ModuleContainer(threading.Thread):
             expiration=expiration,
             daemon=True,
         )
-        joining_announcer.start()
+        dht_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")
 
         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
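
Note: from this point on there is a single long-lived ModuleAnnouncerThread (dht_announcer) instead of separate joining/online announcers. It is created and started here while server_info.state is JOINING, and every later state change goes through its announce() method (defined near the end of this patch) rather than through ad-hoc declare_active_modules calls. The intended lifecycle, condensed from the hunks below (a sketch, not literal code from the patch):

    dht_announcer = ModuleAnnouncerThread(module_uids, dht, server_info, ...)
    dht_announcer.start()                          # keeps re-declaring state == JOINING
    ...
    dht_announcer.announce(ServerState.ONLINE)     # in ModuleContainer.__init__, once backends are ready
    ...
    dht_announcer.announce(ServerState.OFFLINE)    # on failure or shutdown; also joins the thread
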
@@ -476,6 +478,8 @@ class ModuleContainer(threading.Thread):
                     max_batch_size=max_batch_size,
                 )
 
+            merge_inference_pools_inplace(blocks)
+
             if should_validate_reachability:
                 validate_reachability(dht.peer_id)
         except:
@@ -483,29 +487,15 @@ class ModuleContainer(threading.Thread):
             for backend in blocks.values():
                 backend.shutdown()
 
-            joining_announcer.stop.set()
-            joining_announcer.join()
-            server_info.state = ServerState.OFFLINE
-            declare_active_modules(
-                dht,
-                module_uids,
-                server_info,
-                expiration_time=get_dht_time() + expiration,
-            )
+            dht_announcer.announce(ServerState.OFFLINE)
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
-        else:
-            joining_announcer.stop.set()
-            joining_announcer.join()
-
-        merge_inference_pools_inplace(blocks)
 
         return cls(
             dht,
             dht_prefix,
             blocks,
-            block_config=block_config,
-            memory_cache=memory_cache,
+            dht_announcer=dht_announcer,
             server_info=server_info,
             update_period=update_period,
             expiration=expiration,
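
Note: two things change in the failure path above. First, merge_inference_pools_inplace(blocks) moves into the try block (previous hunk), so a failure while merging pools or validating reachability also runs the backend cleanup. Second, the manual OFFLINE bookkeeping (stopping the joining announcer, flipping server_info.state, calling declare_active_modules by hand) collapses into a single dht_announcer.announce(ServerState.OFFLINE), which re-declares the blocks as offline and shuts the announcer thread down. The resulting structure, heavily condensed (a sketch; load_all_blocks is a hypothetical stand-in for the block-loading loop above):

    try:
        blocks = load_all_blocks()
        merge_inference_pools_inplace(blocks)
    except:
        for backend in blocks.values():
            backend.shutdown()
        dht_announcer.announce(ServerState.OFFLINE)   # declares OFFLINE and joins the announcer
        raise
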
@@ -518,10 +508,9 @@ class ModuleContainer(threading.Thread):
         dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
-        block_config: PretrainedConfig,
-        memory_cache: MemoryCache,
         inference_max_length: int,
         num_handlers: int,
+        dht_announcer: ModuleAnnouncerThread,
         server_info: ServerInfo,
         update_period: float,
         expiration: Optional[float] = None,
@@ -558,17 +547,8 @@ class ModuleContainer(threading.Thread):
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
 
-        self.server_info.state = ServerState.ONLINE
-        self.online_announcer = ModuleAnnouncerThread(
-            list(self.module_backends.keys()),
-            dht,
-            self.server_info,
-            block_config=block_config,
-            memory_cache=memory_cache,
-            update_period=update_period,
-            expiration=expiration,
-            daemon=True,
-        )
+        dht_announcer.announce(ServerState.ONLINE)
+        self.dht_announcer = dht_announcer
 
         if start:
             self.run_in_background(await_ready=True)
@@ -578,11 +558,6 @@
         Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)
 
-        self.online_announcer.start()
-
         for handler in self.conn_handlers:
             handler.run_in_background()
@@ -621,16 +596,7 @@
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
-        self.online_announcer.stop.set()
-        self.online_announcer.join()
-
-        self.server_info.state = ServerState.OFFLINE
-        declare_active_modules(
-            self.dht,
-            self.module_backends.keys(),
-            self.server_info,
-            expiration_time=get_dht_time() + self.expiration,
-        )
+        self.dht_announcer.announce(ServerState.OFFLINE)
         logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
         self.ready.clear()
@@ -666,8 +632,10 @@ class ModuleAnnouncerThread(threading.Thread):
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
-        update_period: float = 30,
+        update_period: float,
         expiration: float,
+        max_pinged: int = 5,
+        max_reported: int = 10,
         **kwargs,
     ):
         super().__init__(**kwargs)
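
Note: the hunk below computes bytes_per_token as hidden_size times the byte width of the serving dtype, and run() divides the cache's bytes_left by it to report cache_tokens_left. A worked example with made-up numbers (the hidden size and cache size are hypothetical, not taken from the patch):

    hidden_size = 8192                                   # hypothetical model dimension
    bits = 16                                            # e.g. float16 / bfloat16
    bytes_per_token = hidden_size * bits // 8            # 16384 bytes
    bytes_left = 8 * 1024**3                             # 8 GiB of free attention cache
    cache_tokens_left = bytes_left // bytes_per_token    # 524288 tokens
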
@@ -678,20 +646,58 @@ class ModuleAnnouncerThread(threading.Thread):
         self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
         self.update_period = update_period
         self.expiration = expiration
-        self.stop = threading.Event()
+        self.trigger = threading.Event()
+
+        self.max_pinged, self.max_reported = max_pinged, max_reported
+        last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1]))
+        dht_prefix, block_index = last_uid.split(UID_DELIMITER)
+        self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}"
+        self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
         while True:
+            start_time = time.perf_counter()
+
             self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
+            if self.server_info.state != ServerState.OFFLINE:
+                self._ping_next_servers()
+                self.server_info.next_pings = {
+                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items()
+                }
+            else:
+                self.server_info.next_pings = None  # No need to ping if we're disconnecting
+
             declare_active_modules(
                 self.dht,
                 self.module_uids,
                 self.server_info,
                 expiration_time=get_dht_time() + self.expiration,
             )
-            if self.stop.wait(self.update_period):
+            if self.server_info.state == ServerState.OFFLINE:
                 break
 
+            delay = self.update_period - (time.perf_counter() - start_time)
+            if delay < 0:
+                logger.warning("Declaring blocks to DHT takes more than --update_period, consider increasing it")
+            self.trigger.wait(max(delay, 0))
+            self.trigger.clear()
+
+    def announce(self, state: ServerState) -> None:
+        self.server_info.state = state
+        self.trigger.set()
+        if state == ServerState.OFFLINE:
+            self.join()
+
+    def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
+        [module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True)
+        if module_info is None:
+            return
+
+        next_servers = list(module_info.servers)
+        if len(next_servers) > self.max_pinged:
+            next_servers = random.sample(next_servers, self.max_pinged)
+        self.ping_aggregator.ping(next_servers)
+
 
 class RuntimeWithDeduplicatedPools(Runtime):
     """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""
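
Note on the announcer mechanics above: announce() runs on the caller's thread (Server or ModuleContainer), not inside the announcer loop. It updates server_info.state and sets the trigger event, which wakes run() out of trigger.wait() so the new state is re-declared immediately instead of waiting for the next update_period tick. The loop exits right after declaring OFFLINE, which is why announce(ServerState.OFFLINE) can safely self.join() afterwards. Between forced announcements the loop sleeps for update_period minus the time already spent declaring, keeping the cadence roughly constant even when declare_active_modules is slow. A minimal, self-contained reproduction of this waiting pattern, stripped of the DHT specifics (a sketch, not part of the patch):

    import threading, time

    trigger = threading.Event()
    offline = False

    def announce_offline():
        # stand-in for announce(ServerState.OFFLINE): flip state, wake the loop
        global offline
        offline = True
        trigger.set()

    def loop(update_period: float = 1.0):
        while True:
            start = time.perf_counter()
            print("declaring state, offline =", offline)   # stand-in for declare_active_modules
            if offline:
                break
            delay = update_period - (time.perf_counter() - start)
            trigger.wait(max(delay, 0))   # returns early when announce_offline() sets the trigger
            trigger.clear()

    t = threading.Thread(target=loop)
    t.start()
    time.sleep(2.5)
    announce_offline()
    t.join()
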