@@ -3,6 +3,7 @@ from __future__ import annotations
 import gc
 import math
 import multiprocessing as mp
+import os
 import random
 import threading
 import time
@@ -21,7 +22,7 @@ from transformers import PretrainedConfig
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -259,6 +260,9 @@ class Server:
             using_relay=reachable_via_relay,
             **throughput_info,
         )
+        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
+        if not os.path.isdir(converted_model_name_or_path):
+            self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
@@ -330,6 +334,7 @@ class Server:
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
             server_info=self.server_info,
+            model_info=self.model_info,
             block_indices=block_indices,
             num_handlers=self.num_handlers,
             min_batch_size=self.min_batch_size,
@@ -436,6 +441,7 @@ class ModuleContainer(threading.Thread):
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -463,6 +469,7 @@ class ModuleContainer(threading.Thread):
             module_uids,
             dht,
             server_info,
+            model_info,
             block_config=block_config,
             memory_cache=memory_cache,
             update_period=update_period,
@@ -671,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         dht: DHT,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
@@ -683,6 +691,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.module_uids = module_uids
         self.dht = dht
         self.server_info = server_info
+        self.model_info = model_info
         self.memory_cache = memory_cache
 
         self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
@ -693,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread):
self . trigger = threading . Event ( )
self . max_pinged = max_pinged
dht_prefix = module_uids [ 0 ] . split ( UID_DELIMITER ) [ 0 ]
self . dht_prefix = module_uids [ 0 ] . split ( UID_DELIMITER ) [ 0 ]
block_indices = [ int ( uid . split ( UID_DELIMITER ) [ - 1 ] ) for uid in module_uids ]
start_block , end_block = min ( block_indices ) , max ( block_indices ) + 1
self . next_uids = [ f " { dht_prefix } { UID_DELIMITER } { i } " for i in range ( start_block + 1 , end_block + 1 ) ]
self . next_uids = [ f " { self . dht_prefix } { UID_DELIMITER } { i } " for i in range ( start_block + 1 , end_block + 1 ) ]
self . ping_aggregator = PingAggregator ( self . dht )
def run ( self ) - > None :
@@ -720,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread):
             )
             if self.server_info.state == ServerState.OFFLINE:
                 break
+            if not self.dht_prefix.startswith("_"):  # Not private
+                self.dht.store(
+                    key="_petals.models",
+                    subkey=self.dht_prefix,
+                    value=self.model_info.to_dict(),
+                    expiration_time=get_dht_time() + self.expiration,
+                )
 
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0:
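
Note (not part of the diff): a minimal sketch of how a client could read back the model info that the announcer now stores under the "_petals.models" key. It assumes a hivemind.DHT connected to the same swarm; for a key with subkeys, dht.get() returns a record whose value maps each public dht_prefix to the dict produced by ModelInfo.to_dict() above.

import hivemind

from petals.constants import PUBLIC_INITIAL_PEERS

# Join the swarm as a lightweight client (no blocks served).
dht = hivemind.DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)

# "_petals.models" is a dictionary record: each server wrote its model under its dht_prefix subkey.
entry = dht.get("_petals.models", latest=True)
if entry is not None:
    for dht_prefix, value_with_expiration in entry.value.items():
        model_info_dict = value_with_expiration.value  # dict produced by ModelInfo.to_dict()
        print(dht_prefix, model_info_dict)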