@ -13,8 +13,10 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import declare_active_modules , BloomConfig
from src . bloom . from_pretrained import DTYPE_MAP , load_pretrained_block
from src . data_structures import CHAIN_DELIMITER , UID_DELIMITER
from src . data_structures import CHAIN_DELIMITER , UID_DELIMITER , ServerState
from src . dht_utils import get_remote_module_infos
from src . server . backend import TransformerBackend
from src . server . block_selection import choose_best_blocks
from src . server . cache import MemoryCache
from src . server . handler import TransformerConnectionHandler
@ -32,19 +34,26 @@ class Server(threading.Thread):
* ,
device : torch . device ,
num_connection_handlers : int = 8 ,
throughput : float ,
update_period : float = 30 ,
expiration : Optional [ float ] = None ,
start : bool ,
* * kwargs ,
) :
threading . Thread . __init__ ( self )
self . dht , self . module_backends , self . update_period = dht , module_backends , update_period
self . dht , self . module_backends = dht , module_backends
self . throughput , self . update_period , self . expiration = throughput , update_period , expiration
self . conn_handlers = [
TransformerConnectionHandler ( dht , self . module_backends ) for _ in range ( num_connection_handlers )
]
self . runtime = Runtime ( self . module_backends , device = device , * * kwargs )
self . dht_handler_thread = ModuleAnnouncerThread (
self . module_backends , dht , update_period , expiration , daemon = True
self . module_backends ,
dht ,
throughput = throughput ,
update_period = update_period ,
expiration = expiration ,
daemon = True ,
)
self . checkpoint_saver = None # no need to save checkpoints since we do not change model state
@ -86,6 +95,7 @@ class Server(threading.Thread):
cls ,
prefix : Optional [ str ] ,
converted_model_name_or_path : str ,
throughput : float ,
num_blocks : Optional [ int ] = None ,
block_indices : Optional [ str ] = None ,
num_handlers : Optional [ int ] = None ,
@ -116,6 +126,9 @@ class Server(threading.Thread):
)
logger . info ( f " Automatic dht prefix: { prefix } " )
assert ( block_indices is None ) != ( num_blocks is None ) , " please specify num_blocks or block_indices, not both "
if expiration is None :
expiration = max ( 2 * update_period , MAX_DHT_TIME_DISCREPANCY_SECONDS )
dht = DHT ( initial_peers = initial_peers , start = True , * * kwargs )
visible_maddrs_str = [ str ( a ) for a in dht . get_visible_maddrs ( ) ]
logger . info ( f " Running DHT node on { visible_maddrs_str } , initial peers = { initial_peers } " )
@ -127,6 +140,10 @@ class Server(threading.Thread):
torch_dtype = DTYPE_MAP [ torch_dtype ]
assert torch_dtype in DTYPE_MAP . values ( ) , f " torch_dtype must be one of { list ( DTYPE_MAP . values ( ) ) } "
block_config = BloomConfig . from_pretrained (
converted_model_name_or_path , use_auth_token = use_auth_token
)
if block_indices is not None :
try :
first_block_index , last_block_index = block_indices . split ( " : " )
@ -137,16 +154,22 @@ class Server(threading.Thread):
block_indices = range ( first_block_index , last_block_index )
else :
assert num_blocks is not None
block_indices = range ( num_blocks ) # TODO replace with proper load balancing
uids = [ f " { prefix } . { block_index } " for block_index in range ( block_config . n_layer ) ]
module_infos = get_remote_module_infos ( dht , uids , expiration_time = float ( " inf " ) )
block_indices = choose_best_blocks ( num_blocks , module_infos )
block_config = BloomConfig . from_pretrained (
converted_model_name_or_path , use_auth_token = use_auth_token
module_uids = [ f " { prefix } . { block_index } " for block_index in block_indices ]
declare_active_modules (
dht ,
module_uids ,
expiration_time = get_dht_time ( ) + expiration ,
state = ServerState . JOINING ,
throughput = throughput ,
)
logger . info ( f " Announced that blocks { block_indices } are joining " )
# initialize modules
blocks = { }
for block_index in block_indices :
module_uid = f " { prefix } . { block_index } "
for module_uid , block_index in zip ( module_uids , block_indices ) :
block = load_pretrained_block (
converted_model_name_or_path ,
block_index ,
@ -173,6 +196,7 @@ class Server(threading.Thread):
return cls (
dht ,
blocks ,
throughput = throughput ,
num_connection_handlers = num_handlers ,
device = device ,
stats_report_interval = stats_report_interval ,
@ -209,6 +233,16 @@ class Server(threading.Thread):
Please note that terminating server otherwise ( e . g . by killing processes ) may result in zombie processes .
If you did already cause a zombie outbreak , your only option is to kill them with - 9 ( SIGKILL ) .
"""
if self . module_backends :
declare_active_modules (
self . dht ,
self . module_backends . keys ( ) ,
expiration_time = get_dht_time ( ) + self . expiration ,
state = ServerState . OFFLINE ,
throughput = self . throughput ,
)
logger . info ( f " Announced that blocks { list ( self . module_backends . keys ( ) ) } are offline " )
self . ready . clear ( )
for process in self . conn_handlers :
@ -230,25 +264,38 @@ class Server(threading.Thread):
logger . debug ( f " Shutting down runtime " )
self . runtime . shutdown ( )
logger . info ( " Server shut down succesfully" )
logger . info ( " Server shut down succesfully" )
class ModuleAnnouncerThread ( threading . Thread ) :
""" Periodically announces that this server hosts the specified modules, visible to all DHT peers """
def __init__ (
self , module_backends , dht : DHT , update_period : float = 30 , expiration : Optional [ int ] = None , * * kwargs
self ,
module_backends : Dict [ str , TransformerBackend ] ,
dht : DHT ,
* ,
throughput : float ,
update_period : float = 30 ,
expiration : float ,
* * kwargs
) :
super ( ) . __init__ ( * * kwargs )
if expiration is None :
expiration = max ( 2 * update_period , MAX_DHT_TIME_DISCREPANCY_SECONDS )
self . module_backends = module_backends
self . dht = dht
self . throughput = throughput
self . update_period = update_period
self . expiration = expiration
self . stop = threading . Event ( )
def run ( self ) - > None :
declare_active_modules ( self . dht , self . module_backends . keys ( ) , get_dht_time ( ) + self . expiration )
while not self . stop . wait ( self . update_period ) :
declare_active_modules ( self . dht , self . module_backends . keys ( ) , get_dht_time ( ) + self . expiration )
while True :
declare_active_modules (
self . dht ,
self . module_backends . keys ( ) ,
expiration_time = get_dht_time ( ) + self . expiration ,
state = ServerState . ONLINE ,
throughput = self . throughput ,
)
if self . stop . wait ( self . update_period ) :
break