diff --git a/cli/run_server.py b/cli/run_server.py index 6ab7aa2..3637338 100644 --- a/cli/run_server.py +++ b/cli/run_server.py @@ -124,10 +124,9 @@ def main(): use_auth_token = args.pop("use_auth_token") args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token - server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True) - + server = Server(**args, compression=compression, attn_cache_size=attn_cache_size) try: - server.join() + server.run() except KeyboardInterrupt: logger.info("Caught KeyboardInterrupt, shutting down") finally: diff --git a/src/server/server.py b/src/server/server.py index 93b884b..00d54b2 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) -class Server(threading.Thread): +class Server: """ Runs ModuleContainer, periodically checks that the network is balanced, restarts the ModuleContainer with other layers if the imbalance is significant @@ -68,13 +68,10 @@ class Server(threading.Thread): mean_block_selection_delay: float = 0.5, use_auth_token: Optional[str] = None, load_in_8bit: bool = False, - start: bool, **kwargs, ): """Create a server with one or more bloom blocks. See run_server.py for documentation.""" - super().__init__() - self.converted_model_name_or_path = converted_model_name_or_path self.num_handlers = num_handlers self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size @@ -147,8 +144,6 @@ class Server(threading.Thread): self.mean_block_selection_delay = mean_block_selection_delay self.stop = threading.Event() - if start: - self.start() def run(self): while True: @@ -312,7 +307,19 @@ class ModuleContainer(threading.Thread): min_batch_size=min_batch_size, max_batch_size=max_batch_size, ) - finally: + except: + joining_announcer.stop.set() + joining_announcer.join() + declare_active_modules( + dht, + module_uids, + expiration_time=get_dht_time() + expiration, + state=ServerState.OFFLINE, + throughput=throughput, + ) + logger.info(f"Announced that blocks {module_uids} are offline") + raise + else: joining_announcer.stop.set() joining_announcer.join()