@@ -3,6 +3,7 @@ from __future__ import annotations
import gc
import gc
import math
import math
import multiprocessing as mp
import multiprocessing as mp
import os
import random
import random
import threading
import threading
import time
import time
@@ -21,7 +22,7 @@ from transformers import PretrainedConfig
import petals
import petals
from petals . constants import DTYPE_MAP , PUBLIC_INITIAL_PEERS
from petals . constants import DTYPE_MAP , PUBLIC_INITIAL_PEERS
from petals . data_structures import CHAIN_DELIMITER , UID_DELIMITER , ServerInfo, ServerState
from petals . data_structures import CHAIN_DELIMITER , UID_DELIMITER , ModelInfo, ServerInfo, ServerState
from petals . server import block_selection
from petals . server import block_selection
from petals . server . backend import TransformerBackend , merge_inference_pools_inplace
from petals . server . backend import TransformerBackend , merge_inference_pools_inplace
from petals . server . block_utils import get_block_size , resolve_block_dtype
from petals . server . block_utils import get_block_size , resolve_block_dtype
@@ -259,6 +260,9 @@ class Server:
using_relay = reachable_via_relay ,
using_relay = reachable_via_relay ,
* * throughput_info ,
* * throughput_info ,
)
)
self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
if not os.path.isdir(converted_model_name_or_path):
    self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
self . balance_quality = balance_quality
self . balance_quality = balance_quality
self . mean_balance_check_period = mean_balance_check_period
self . mean_balance_check_period = mean_balance_check_period
@@ -330,6 +334,7 @@ class Server:
block_config = self . block_config ,
block_config = self . block_config ,
attn_cache_bytes = self . attn_cache_bytes ,
attn_cache_bytes = self . attn_cache_bytes ,
server_info = self . server_info ,
server_info = self . server_info ,
model_info = self . model_info ,
block_indices = block_indices ,
block_indices = block_indices ,
num_handlers = self . num_handlers ,
num_handlers = self . num_handlers ,
min_batch_size = self . min_batch_size ,
min_batch_size = self . min_batch_size ,
@@ -436,6 +441,7 @@ class ModuleContainer(threading.Thread):
block_config : PretrainedConfig ,
block_config : PretrainedConfig ,
attn_cache_bytes : int ,
attn_cache_bytes : int ,
server_info : ServerInfo ,
server_info : ServerInfo ,
model_info : ModelInfo ,
block_indices : List [ int ] ,
block_indices : List [ int ] ,
min_batch_size : int ,
min_batch_size : int ,
max_batch_size : int ,
max_batch_size : int ,
@@ -463,6 +469,7 @@ class ModuleContainer(threading.Thread):
module_uids ,
module_uids ,
dht ,
dht ,
server_info ,
server_info ,
model_info ,
block_config = block_config ,
block_config = block_config ,
memory_cache = memory_cache ,
memory_cache = memory_cache ,
update_period = update_period ,
update_period = update_period ,
@@ -671,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread):
module_uids : List [ str ] ,
module_uids : List [ str ] ,
dht : DHT ,
dht : DHT ,
server_info : ServerInfo ,
server_info : ServerInfo ,
model_info : ModelInfo ,
* ,
* ,
block_config : PretrainedConfig ,
block_config : PretrainedConfig ,
memory_cache : MemoryCache ,
memory_cache : MemoryCache ,
@@ -683,6 +691,7 @@ class ModuleAnnouncerThread(threading.Thread):
self . module_uids = module_uids
self . module_uids = module_uids
self . dht = dht
self . dht = dht
self . server_info = server_info
self . server_info = server_info
self . model_info = model_info
self . memory_cache = memory_cache
self . memory_cache = memory_cache
self . bytes_per_token = block_config . hidden_size * get_size_in_bytes ( DTYPE_MAP [ server_info . torch_dtype ] )
self . bytes_per_token = block_config . hidden_size * get_size_in_bytes ( DTYPE_MAP [ server_info . torch_dtype ] )
@@ -693,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread):
self . trigger = threading . Event ( )
self . trigger = threading . Event ( )
self . max_pinged = max_pinged
self . max_pinged = max_pinged
dht_prefix = module_uids [ 0 ] . split ( UID_DELIMITER ) [ 0 ]
self . dht_prefix = module_uids [ 0 ] . split ( UID_DELIMITER ) [ 0 ]
block_indices = [ int ( uid . split ( UID_DELIMITER ) [ - 1 ] ) for uid in module_uids ]
block_indices = [ int ( uid . split ( UID_DELIMITER ) [ - 1 ] ) for uid in module_uids ]
start_block , end_block = min ( block_indices ) , max ( block_indices ) + 1
start_block , end_block = min ( block_indices ) , max ( block_indices ) + 1
self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
self . ping_aggregator = PingAggregator ( self . dht )
self . ping_aggregator = PingAggregator ( self . dht )
def run(self) -> None:
def run(self) -> None:
@@ -720,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread):
)
)
if self . server_info . state == ServerState . OFFLINE :
if self . server_info . state == ServerState . OFFLINE :
break
break
if not self.dht_prefix.startswith("_"):  # Not private
    self.dht.store(
        key="_petals.models",
        subkey=self.dht_prefix,
        value=self.model_info.to_dict(),
        expiration_time=get_dht_time() + self.expiration,
    )
delay = self . update_period - ( time . perf_counter ( ) - start_time )
delay = self . update_period - ( time . perf_counter ( ) - start_time )
if delay < 0 :
if delay < 0 :