|
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|
|
|
|
import gc
|
|
|
|
|
import math
|
|
|
|
|
import multiprocessing as mp
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
@ -21,7 +22,7 @@ from transformers import PretrainedConfig
|
|
|
|
|
|
|
|
|
|
import petals
|
|
|
|
|
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
|
|
|
|
|
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
|
|
|
|
|
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
|
|
|
|
|
from petals.server import block_selection
|
|
|
|
|
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
|
|
|
|
|
from petals.server.block_utils import get_block_size, resolve_block_dtype
|
|
|
|
@ -259,6 +260,9 @@ class Server:
|
|
|
|
|
using_relay=reachable_via_relay,
|
|
|
|
|
**throughput_info,
|
|
|
|
|
)
|
|
|
|
|
self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
|
|
|
|
|
if not os.path.isdir(converted_model_name_or_path):
|
|
|
|
|
self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
|
|
|
|
|
|
|
|
|
|
self.balance_quality = balance_quality
|
|
|
|
|
self.mean_balance_check_period = mean_balance_check_period
|
|
|
|
@ -330,6 +334,7 @@ class Server:
|
|
|
|
|
block_config=self.block_config,
|
|
|
|
|
attn_cache_bytes=self.attn_cache_bytes,
|
|
|
|
|
server_info=self.server_info,
|
|
|
|
|
model_info=self.model_info,
|
|
|
|
|
block_indices=block_indices,
|
|
|
|
|
num_handlers=self.num_handlers,
|
|
|
|
|
min_batch_size=self.min_batch_size,
|
|
|
|
@ -436,6 +441,7 @@ class ModuleContainer(threading.Thread):
|
|
|
|
|
block_config: PretrainedConfig,
|
|
|
|
|
attn_cache_bytes: int,
|
|
|
|
|
server_info: ServerInfo,
|
|
|
|
|
model_info: ModelInfo,
|
|
|
|
|
block_indices: List[int],
|
|
|
|
|
min_batch_size: int,
|
|
|
|
|
max_batch_size: int,
|
|
|
|
@ -463,6 +469,7 @@ class ModuleContainer(threading.Thread):
|
|
|
|
|
module_uids,
|
|
|
|
|
dht,
|
|
|
|
|
server_info,
|
|
|
|
|
model_info,
|
|
|
|
|
block_config=block_config,
|
|
|
|
|
memory_cache=memory_cache,
|
|
|
|
|
update_period=update_period,
|
|
|
|
@ -671,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
|
|
module_uids: List[str],
|
|
|
|
|
dht: DHT,
|
|
|
|
|
server_info: ServerInfo,
|
|
|
|
|
model_info: ModelInfo,
|
|
|
|
|
*,
|
|
|
|
|
block_config: PretrainedConfig,
|
|
|
|
|
memory_cache: MemoryCache,
|
|
|
|
@ -683,6 +691,7 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
|
|
self.module_uids = module_uids
|
|
|
|
|
self.dht = dht
|
|
|
|
|
self.server_info = server_info
|
|
|
|
|
self.model_info = model_info
|
|
|
|
|
self.memory_cache = memory_cache
|
|
|
|
|
|
|
|
|
|
self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
|
|
|
|
@ -693,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
|
|
self.trigger = threading.Event()
|
|
|
|
|
|
|
|
|
|
self.max_pinged = max_pinged
|
|
|
|
|
dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
|
|
|
|
|
self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
|
|
|
|
|
block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
|
|
|
|
|
start_block, end_block = min(block_indices), max(block_indices) + 1
|
|
|
|
|
self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
|
|
|
|
|
self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
|
|
|
|
|
self.ping_aggregator = PingAggregator(self.dht)
|
|
|
|
|
|
|
|
|
|
def run(self) -> None:
|
|
|
|
@ -720,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
|
|
)
|
|
|
|
|
if self.server_info.state == ServerState.OFFLINE:
|
|
|
|
|
break
|
|
|
|
|
if not self.dht_prefix.startswith("_"): # Not private
|
|
|
|
|
self.dht.store(
|
|
|
|
|
key="_petals.models",
|
|
|
|
|
subkey=self.dht_prefix,
|
|
|
|
|
value=self.model_info.to_dict(),
|
|
|
|
|
expiration_time=get_dht_time() + self.expiration,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
delay = self.update_period - (time.perf_counter() - start_time)
|
|
|
|
|
if delay < 0:
|
|
|
|
|