Server is not a thread anymore, so it catches KeyboardInterrupt

pull/93/head
Aleksandr Borzunov 2 years ago
parent 06a8246ae9
commit 52ea24730b

@@ -124,10 +124,9 @@ def main():
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
-    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
+    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
     try:
-        server.join()
+        server.run()
     except KeyboardInterrupt:
         logger.info("Caught KeyboardInterrupt, shutting down")
     finally:

@@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
-class Server(threading.Thread):
+class Server:
     """
     Runs ModuleContainer, periodically checks that the network is balanced,
     restarts the ModuleContainer with other layers if the imbalance is significant
@@ -68,13 +68,10 @@ class Server(threading.Thread):
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
-        start: bool,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
-        super().__init__()
         self.converted_model_name_or_path = converted_model_name_or_path
         self.num_handlers = num_handlers
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
@@ -147,8 +144,6 @@ class Server(threading.Thread):
         self.mean_block_selection_delay = mean_block_selection_delay
         self.stop = threading.Event()
-        if start:
-            self.start()
     def run(self):
         while True:
@@ -312,7 +307,19 @@ class ModuleContainer(threading.Thread):
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                 )
-        finally:
+        except:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+            declare_active_modules(
+                dht,
+                module_uids,
+                expiration_time=get_dht_time() + expiration,
+                state=ServerState.OFFLINE,
+                throughput=throughput,
+            )
+            logger.info(f"Announced that blocks {module_uids} are offline")
+            raise
+        else:
             joining_announcer.stop.set()
             joining_announcer.join()

Loading…
Cancel
Save