Create model index in DHT (#491)

This PR creates an index of the models hosted in the swarm. It makes it possible to see which custom models users run and to display them at https://health.petals.dev as "not officially supported" models.
Alexander Borzunov committed 6ef6bf5fa2 (parent 6bb3f54e39)
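For reference, here is how a monitor like https://health.petals.dev could read the index back. This is a minimal sketch, assuming hivemind's dictionary-key convention for `DHT.get()` (each server's entry lives under its `dht_prefix` subkey, wrapped in a `ValueWithExpiration`):

    from hivemind import DHT
    from petals.constants import PUBLIC_INITIAL_PEERS

    # Join the public swarm as a lightweight client-only DHT node
    dht = DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)

    # "_petals.models" is a dictionary key: every server stores its ModelInfo
    # under its dht_prefix as a subkey (see the announcer loop in the diff below)
    entry = dht.get("_petals.models", latest=True)
    if entry is not None:
        for dht_prefix, wrapped in entry.value.items():
            info = wrapped.value  # the dict produced by ModelInfo.to_dict()
            print(dht_prefix, info["num_blocks"], info.get("repository"))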

@@ -343,7 +343,7 @@ class InferenceSession:
         n_prev_spans = len(self._server_sessions)
         update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
         if attempt_no >= 1:
-            logger.info(
+            logger.debug(
                 f"Due to a server failure, remote attention caches "
                 f"from block {block_idx} to {update_end} will be regenerated"
             )

@@ -20,6 +20,19 @@ class ServerState(Enum):
 RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
 
 
+@pydantic.dataclasses.dataclass
+class ModelInfo:
+    num_blocks: int
+    repository: Optional[str] = None
+
+    def to_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
+    @classmethod
+    def from_dict(cls, source: dict):
+        return cls(**source)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState

@@ -3,6 +3,7 @@ from __future__ import annotations
 import gc
 import math
 import multiprocessing as mp
+import os
 import random
 import threading
 import time
@@ -21,7 +22,7 @@ from transformers import PretrainedConfig
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -259,6 +260,9 @@ class Server:
             using_relay=reachable_via_relay,
             **throughput_info,
         )
+        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
+        if not os.path.isdir(converted_model_name_or_path):
+            self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
@@ -330,6 +334,7 @@ class Server:
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
                 server_info=self.server_info,
+                model_info=self.model_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
@@ -436,6 +441,7 @@ class ModuleContainer(threading.Thread):
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -463,6 +469,7 @@ class ModuleContainer(threading.Thread):
             module_uids,
             dht,
             server_info,
+            model_info,
             block_config=block_config,
             memory_cache=memory_cache,
             update_period=update_period,
@@ -671,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         dht: DHT,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
@@ -683,6 +691,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.module_uids = module_uids
         self.dht = dht
         self.server_info = server_info
+        self.model_info = model_info
         self.memory_cache = memory_cache
         self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
@@ -693,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.trigger = threading.Event()
         self.max_pinged = max_pinged
-        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
         block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
         start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -720,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread):
             )
             if self.server_info.state == ServerState.OFFLINE:
                 break
+            if not self.dht_prefix.startswith("_"):  # Not private
+                self.dht.store(
+                    key="_petals.models",
+                    subkey=self.dht_prefix,
+                    value=self.model_info.to_dict(),
+                    expiration_time=get_dht_time() + self.expiration,
+                )
 
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0:
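Design note on the announcer change above: `dht.store()` with a `subkey` writes into a shared dictionary record, so all servers publish into the single `_petals.models` key while each model's `dht_prefix` remains an independent subkey. Because every entry is stored with `expiration_time = get_dht_time() + self.expiration` and refreshed on each announcement cycle, a model drops out of the index automatically once its last server stops announcing. Conceptually, the shared record converges to something like this (model names and sizes are illustrative, not taken from the PR):

    # Conceptual contents of the "_petals.models" dictionary record
    # (subkey -> ModelInfo.to_dict(); names and sizes are illustrative)
    models_index = {
        "StableBeluga2-hf": {"num_blocks": 80, "repository": "https://huggingface.co/petals-team/StableBeluga2"},
        "bloom-560m": {"num_blocks": 24, "repository": "https://huggingface.co/bigscience/bloom-560m"},
    }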
