|
|
|
@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger")
|
|
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Server(threading.Thread):
|
|
|
|
|
class Server:
|
|
|
|
|
"""
|
|
|
|
|
Runs ModuleContainer, periodically checks that the network is balanced,
|
|
|
|
|
restarts the ModuleContainer with other layers if the imbalance is significant
|
|
|
|
@ -68,13 +68,10 @@ class Server(threading.Thread):
|
|
|
|
|
mean_block_selection_delay: float = 0.5,
|
|
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
|
|
load_in_8bit: bool = False,
|
|
|
|
|
start: bool,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
|
|
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
self.converted_model_name_or_path = converted_model_name_or_path
|
|
|
|
|
self.num_handlers = num_handlers
|
|
|
|
|
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
|
|
|
|
@ -147,8 +144,6 @@ class Server(threading.Thread):
|
|
|
|
|
self.mean_block_selection_delay = mean_block_selection_delay
|
|
|
|
|
|
|
|
|
|
self.stop = threading.Event()
|
|
|
|
|
if start:
|
|
|
|
|
self.start()
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
|
while True:
|
|
|
|
@ -312,7 +307,19 @@ class ModuleContainer(threading.Thread):
|
|
|
|
|
min_batch_size=min_batch_size,
|
|
|
|
|
max_batch_size=max_batch_size,
|
|
|
|
|
)
|
|
|
|
|
finally:
|
|
|
|
|
except:
|
|
|
|
|
joining_announcer.stop.set()
|
|
|
|
|
joining_announcer.join()
|
|
|
|
|
declare_active_modules(
|
|
|
|
|
dht,
|
|
|
|
|
module_uids,
|
|
|
|
|
expiration_time=get_dht_time() + expiration,
|
|
|
|
|
state=ServerState.OFFLINE,
|
|
|
|
|
throughput=throughput,
|
|
|
|
|
)
|
|
|
|
|
logger.info(f"Announced that blocks {module_uids} are offline")
|
|
|
|
|
raise
|
|
|
|
|
else:
|
|
|
|
|
joining_announcer.stop.set()
|
|
|
|
|
joining_announcer.join()
|
|
|
|
|
|
|
|
|
|