diff --git a/src/server/block_selection.py b/src/server/block_selection.py
index fe926e9..af875c2 100644
--- a/src/server/block_selection.py
+++ b/src/server/block_selection.py
@@ -64,6 +64,9 @@ def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModule
 def should_choose_other_blocks(
     local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
 ) -> bool:
+    if min_balance_quality > 1.0:
+        return True  # Forces rebalancing on each check (may be used for debugging purposes)
+
     spans, throughputs = _compute_spans(module_infos)
     initial_throughput = throughputs.min()
 
diff --git a/src/server/server.py b/src/server/server.py
index a0c1c09..1ed11ca 100644
--- a/src/server/server.py
+++ b/src/server/server.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import gc
 import multiprocessing as mp
 import random
 import threading
@@ -7,6 +8,7 @@ import time
 from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
+import psutil
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -187,6 +189,17 @@ class Server(threading.Thread):
             finally:
                 self.module_container.shutdown()
 
+            self._clean_memory_and_fds()
+
+    def _clean_memory_and_fds(self):
+        """Drop the module container, run a GC pass, and log how many file descriptors remain open."""
+        del self.module_container
+        gc.collect()  # In particular, this closes unused file descriptors
+
+        cur_proc = psutil.Process()
+        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
+        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
             return self.strict_block_indices
@@ -418,6 +431,11 @@ class ModuleContainer(threading.Thread):
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()
 
+        logger.debug(f"Shutting down pools")
+        for pool in self.runtime.pools:
+            if pool.is_alive():
+                pool.shutdown()
+
         logger.debug(f"Shutting down runtime")
         self.runtime.shutdown()
 
diff --git a/src/server/task_pool.py b/src/server/task_pool.py
index eec80bc..672248f 100644
--- a/src/server/task_pool.py
+++ b/src/server/task_pool.py
@@ -70,6 +70,8 @@ class PrioritizedTaskPool(TaskPoolBase):
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
         self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
         self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
+
+        self._stop = mp.Event()
         if start:
             self.start()
 
@@ -89,10 +91,15 @@ class PrioritizedTaskPool(TaskPoolBase):
         self._prioritizer_thread.start()
         super().start()
 
-    def shutdown(self, timeout: Optional[float] = None):
-        self.submitted_tasks.put(None)
-        self.terminate()
-        self._prioritizer_thread.join(timeout)
+    def shutdown(self, timeout: Optional[float] = 3):
+        """Set the stop event, join for up to `timeout` seconds (None = no limit), terminate if still alive"""
+        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
+        self._stop.set()
+
+        self.join(timeout)
+        if self.is_alive():
+            logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
+            self.terminate()
 
     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
@@ -154,7 +161,7 @@ class PrioritizedTaskPool(TaskPoolBase):
             task.future.set_exception(exception)
 
     def run(self, *args, **kwargs):
-        mp.Event().wait()
+        self._stop.wait()
 
     @property
     def empty(self):