@@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class Server(threading.Thread):
+class Server:
    """
    Runs ModuleContainer, periodically checks that the network is balanced,
    restarts the ModuleContainer with other layers if the imbalance is significant
@@ -68,13 +68,10 @@ class Server(threading.Thread):
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
-        start: bool,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
-        super().__init__()
         self.converted_model_name_or_path = converted_model_name_or_path
         self.num_handlers = num_handlers
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
@@ -147,8 +144,6 @@ class Server(threading.Thread):
         self.mean_block_selection_delay = mean_block_selection_delay
 
         self.stop = threading.Event()
-        if start:
-            self.start()
 
     def run(self):
         while True:
@@ -231,6 +226,118 @@
 class ModuleContainer(threading.Thread):
     """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
 
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        *,
+        dht: DHT,
+        prefix: str,
+        converted_model_name_or_path: str,
+        block_config: BloomConfig,
+        memory_cache: MemoryCache,
+        throughput: float,
+        block_indices: List[int],
+        num_handlers: Optional[int],
+        min_batch_size: int,
+        max_batch_size: int,
+        inference_max_length: int,
+        torch_dtype: torch.dtype,
+        cache_dir: Optional[str],
+        device: Union[str, torch.device],
+        compression: CompressionType,
+        stats_report_interval: Optional[int],
+        update_period: float,
+        expiration: Optional[float],
+        prefetch_batches: int,
+        sender_threads: int,
+        use_auth_token: Optional[str],
+        load_in_8bit: bool,
+        start: bool,
+    ) -> ModuleContainer:
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
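+        # Announce the blocks as JOINING from a background thread, so other peers
+        # learn about them early but do not yet route requests to this server.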
+        joining_announcer = ModuleAnnouncerThread(
+            module_uids,
+            dht,
+            ServerState.JOINING,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        joining_announcer.start()
+        logger.info(f"Announced that blocks {block_indices} are joining")
+
+        try:
+            blocks = {}
+            for module_uid, block_index in zip(module_uids, block_indices):
+                block = load_pretrained_block(
+                    converted_model_name_or_path,
+                    block_index,
+                    block_config,
+                    torch_dtype=torch_dtype,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                )
+
+                if load_in_8bit:
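+                    # Replace the block's nn.Linear layers with 8-bit quantized equivalents
+                    # (assumption: replace_8bit_linear wraps bitsandbytes Int8 layers) to cut GPU memory.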
+                    dtype = block.input_layernorm.weight.dtype
+                    block = replace_8bit_linear(block)
+
+                block = block.to(device)
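+                # The container runs inference/forward/backward on behalf of clients but
+                # never updates the weights itself, so parameters are frozen.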
+                for param in block.parameters():
+                    param.requires_grad = False
+
+                blocks[module_uid] = TransformerBackend(
+                    module_uid,
+                    block,
+                    memory_cache=memory_cache,
+                    backend_dtype=None if torch_dtype == "auto" else torch_dtype,
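+                    # Schemas describe the hidden-state tensors ([batch, up to 2048 tokens, hidden_size])
+                    # and the wire compression that hivemind applies when (de)serializing them.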
+                    args_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    kwargs_schema={},
+                    outputs_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    min_batch_size=min_batch_size,
+                    max_batch_size=max_batch_size,
+                )
+        except:
+            joining_announcer.stop.set()
+            joining_announcer.join()
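+            # Loading failed: replace the JOINING announcement with OFFLINE so that
+            # other peers stop waiting for these blocks, then re-raise.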
+            declare_active_modules(
+                dht,
+                module_uids,
+                expiration_time=get_dht_time() + expiration,
+                state=ServerState.OFFLINE,
+                throughput=throughput,
+            )
+            logger.info(f"Announced that blocks {module_uids} are offline")
+            raise
+        else:
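+            # All blocks loaded: retire the JOINING announcer; the ONLINE announcer
+            # created in __init__ takes over once the container starts.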
+            joining_announcer.stop.set()
+            joining_announcer.join()
+
+        return cls(
+            dht,
+            blocks,
+            throughput=throughput,
+            num_connection_handlers=num_handlers,
+            inference_max_length=inference_max_length,
+            device=device,
+            stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            expiration=expiration,
+            prefetch_batches=prefetch_batches,
+            sender_threads=sender_threads,
+            start=start,
+        )
+
     def __init__(
         self,
         dht: DHT,
@@ -253,9 +360,10 @@ class ModuleContainer(threading.Thread):
             for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends,
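+        # Renamed from dht_handler_thread: the announcer now receives plain module UIDs
+        # and an explicit ServerState, so the same thread class serves JOINING and ONLINE.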
+        self.online_announcer = ModuleAnnouncerThread(
+            list(self.module_backends.keys()),
             dht,
+            ServerState.ONLINE,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -279,8 +387,7 @@ class ModuleContainer(threading.Thread):
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)
 
         if self.module_backends:
-            self.dht_handler_thread.start()
+            self.online_announcer.start()
 
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
@@ -290,99 +397,6 @@ class ModuleContainer(threading.Thread):
         self.runtime.run()
 
-    # noinspection PyMethodOverriding
-    @classmethod
-    def create(
-        cls,
-        *,
-        dht: DHT,
-        prefix: str,
-        converted_model_name_or_path: str,
-        block_config: BloomConfig,
-        memory_cache: MemoryCache,
-        throughput: float,
-        block_indices: List[int],
-        num_handlers: Optional[int],
-        min_batch_size: int,
-        max_batch_size: int,
-        inference_max_length: int,
-        torch_dtype: torch.dtype,
-        cache_dir: Optional[str],
-        device: Union[str, torch.device],
-        compression: CompressionType,
-        stats_report_interval: Optional[int],
-        update_period: float,
-        expiration: Optional[float],
-        prefetch_batches: int,
-        sender_threads: int,
-        use_auth_token: Optional[str],
-        load_in_8bit: bool,
-        start: bool,
-    ) -> ModuleContainer:
-        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
-        declare_active_modules(
-            dht,
-            module_uids,
-            expiration_time=get_dht_time() + expiration,
-            state=ServerState.JOINING,
-            throughput=throughput,
-        )
-        logger.info(f"Announced that blocks {block_indices} are joining")
-
-        blocks = {}
-        for module_uid, block_index in zip(module_uids, block_indices):
-            block = load_pretrained_block(
-                converted_model_name_or_path,
-                block_index,
-                block_config,
-                torch_dtype=torch_dtype,
-                use_auth_token=use_auth_token,
-                cache_dir=cache_dir,
-            )
-
-            if load_in_8bit:
-                dtype = block.input_layernorm.weight.dtype
-                block = replace_8bit_linear(block)
-
-            block = block.to(device)
-            for param in block.parameters():
-                param.requires_grad = False
-
-            blocks[module_uid] = TransformerBackend(
-                module_uid,
-                block,
-                memory_cache=memory_cache,
-                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
-                args_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                kwargs_schema={},
-                outputs_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        return cls(
-            dht,
-            blocks,
-            throughput=throughput,
-            num_connection_handlers=num_handlers,
-            inference_max_length=inference_max_length,
-            device=device,
-            stats_report_interval=stats_report_interval,
-            update_period=update_period,
-            expiration=expiration,
-            prefetch_batches=prefetch_batches,
-            sender_threads=sender_threads,
-            start=start,
-        )
     def run_in_background(self, await_ready=True, timeout=None):
         """
         Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
@@ -411,18 +425,17 @@ class ModuleContainer(threading.Thread):
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
         if self.module_backends:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
+            self.online_announcer.stop.set()
+            self.online_announcer.join()
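+            # The OFFLINE announcement now happens only if this container actually served
+            # blocks, presumably to avoid declaring an empty list of module UIDs.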
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.OFFLINE,
+                throughput=self.throughput,
+            )
+            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
-        declare_active_modules(
-            self.dht,
-            self.module_backends.keys(),
-            expiration_time=get_dht_time() + self.expiration,
-            state=ServerState.OFFLINE,
-            throughput=self.throughput,
-        )
-        logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
         self.ready.clear()
@@ -450,8 +463,9 @@ class ModuleAnnouncerThread(threading.Thread):
     def __init__(
         self,
-        module_backends: Dict[str, TransformerBackend],
+        module_uids: List[str],
         dht: DHT,
+        state: ServerState,
         *,
         throughput: float,
         update_period: float = 30,
@@ -459,8 +473,9 @@ class ModuleAnnouncerThread(threading.Thread):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.module_backends = module_backends
+        self.module_uids = module_uids
         self.dht = dht
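+        # Which state this thread announces, e.g. JOINING while blocks load, ONLINE while serving.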
+        self.state = state
         self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
@@ -470,9 +485,9 @@ class ModuleAnnouncerThread(threading.Thread):
         while True:
             declare_active_modules(
                 self.dht,
-                self.module_backends.keys(),
+                self.module_uids,
                 expiration_time=get_dht_time() + self.expiration,
-                state=ServerState.ONLINE,
+                state=self.state,
                 throughput=self.throughput,
             )
             if self.stop.wait(self.update_period):