Create model index in DHT (#491)

This PR creates an index of models hosted in the swarm - it is useful to know which custom models users run and display them at https://health.petals.dev as "not officially supported" models.
pull/493/head
Alexander Borzunov 8 months ago committed by GitHub
parent 6bb3f54e39
commit 6ef6bf5fa2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -343,7 +343,7 @@ class InferenceSession:
n_prev_spans = len(self._server_sessions)
update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
if attempt_no >= 1:
logger.info(
logger.debug(
f"Due to a server failure, remote attention caches "
f"from block {block_idx} to {update_end} will be regenerated"
)

@ -20,6 +20,19 @@ class ServerState(Enum):
RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
@pydantic.dataclasses.dataclass
class ModelInfo:
num_blocks: int
repository: Optional[str] = None
def to_dict(self) -> dict:
return dataclasses.asdict(self)
@classmethod
def from_dict(cls, source: dict):
return cls(**source)
@pydantic.dataclasses.dataclass
class ServerInfo:
state: ServerState

@ -3,6 +3,7 @@ from __future__ import annotations
import gc
import math
import multiprocessing as mp
import os
import random
import threading
import time
@ -21,7 +22,7 @@ from transformers import PretrainedConfig
import petals
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size, resolve_block_dtype
@ -259,6 +260,9 @@ class Server:
using_relay=reachable_via_relay,
**throughput_info,
)
self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
if not os.path.isdir(converted_model_name_or_path):
self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
self.balance_quality = balance_quality
self.mean_balance_check_period = mean_balance_check_period
@ -330,6 +334,7 @@ class Server:
block_config=self.block_config,
attn_cache_bytes=self.attn_cache_bytes,
server_info=self.server_info,
model_info=self.model_info,
block_indices=block_indices,
num_handlers=self.num_handlers,
min_batch_size=self.min_batch_size,
@ -436,6 +441,7 @@ class ModuleContainer(threading.Thread):
block_config: PretrainedConfig,
attn_cache_bytes: int,
server_info: ServerInfo,
model_info: ModelInfo,
block_indices: List[int],
min_batch_size: int,
max_batch_size: int,
@ -463,6 +469,7 @@ class ModuleContainer(threading.Thread):
module_uids,
dht,
server_info,
model_info,
block_config=block_config,
memory_cache=memory_cache,
update_period=update_period,
@ -671,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread):
module_uids: List[str],
dht: DHT,
server_info: ServerInfo,
model_info: ModelInfo,
*,
block_config: PretrainedConfig,
memory_cache: MemoryCache,
@ -683,6 +691,7 @@ class ModuleAnnouncerThread(threading.Thread):
self.module_uids = module_uids
self.dht = dht
self.server_info = server_info
self.model_info = model_info
self.memory_cache = memory_cache
self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
@ -693,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread):
self.trigger = threading.Event()
self.max_pinged = max_pinged
dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
start_block, end_block = min(block_indices), max(block_indices) + 1
self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
self.ping_aggregator = PingAggregator(self.dht)
def run(self) -> None:
@ -720,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread):
)
if self.server_info.state == ServerState.OFFLINE:
break
if not self.dht_prefix.startswith("_"): # Not private
self.dht.store(
key="_petals.models",
subkey=self.dht_prefix,
value=self.model_info.to_dict(),
expiration_time=get_dht_time() + self.expiration,
)
delay = self.update_period - (time.perf_counter() - start_time)
if delay < 0:

Loading…
Cancel
Save