diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 562b56e..d084f06 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -264,6 +264,7 @@ class Server: self.mean_balance_check_period = mean_balance_check_period self.mean_block_selection_delay = mean_block_selection_delay + self.module_container = None self.stop = threading.Event() def _choose_num_blocks(self) -> int: @@ -377,7 +378,7 @@ class Server: self._clean_memory_and_fds() def _clean_memory_and_fds(self): - del self.module_container + self.module_container = None gc.collect() # In particular, this closes unused file descriptors if self.device.type == "cuda": @@ -410,8 +411,10 @@ class Server: module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True) return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality) - def shutdown(self): + def shutdown(self, timeout: Optional[float] = 5): self.stop.set() + if self.module_container is not None and self.module_container.is_alive(): + self.module_container.join(timeout) if self.reachability_protocol is not None: self.reachability_protocol.shutdown()