Fix "Too many open files" during rebalancing

pull/83/head
Aleksandr Borzunov 2 years ago
parent f64eb3a665
commit 75ed6ac49c

@ -64,6 +64,9 @@ def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModule
def should_choose_other_blocks( def should_choose_other_blocks(
local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
) -> bool: ) -> bool:
if min_balance_quality > 1.0:
return True # Forces rebalancing on each check (may be used for debugging purposes)
spans, throughputs = _compute_spans(module_infos) spans, throughputs = _compute_spans(module_infos)
initial_throughput = throughputs.min() initial_throughput = throughputs.min()

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import gc
import multiprocessing as mp import multiprocessing as mp
import random import random
import threading import threading
@ -7,6 +8,7 @@ import time
from typing import Dict, List, Optional, Sequence, Union from typing import Dict, List, Optional, Sequence, Union
import numpy as np import numpy as np
import psutil
import torch import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.layers import add_custom_models_from_file
@ -187,6 +189,16 @@ class Server(threading.Thread):
finally: finally:
self.module_container.shutdown() self.module_container.shutdown()
self._clean_memory_and_fds()
def _clean_memory_and_fds(self):
    """Release the old module container and report how many file descriptors remain open.

    Called after the previous container has been shut down (e.g. during rebalancing),
    so that tensors and pipe FDs held by it are reclaimed before a new container starts.
    """
    del self.module_container
    gc.collect()  # In particular, this closes unused file descriptors

    cur_proc = psutil.Process()
    # Count FDs of this process and all of its descendants; reuse `cur_proc`
    # instead of constructing a second psutil.Process() for the children() call.
    num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
    logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
def _choose_blocks(self) -> List[int]: def _choose_blocks(self) -> List[int]:
if self.strict_block_indices is not None: if self.strict_block_indices is not None:
return self.strict_block_indices return self.strict_block_indices
@ -418,6 +430,11 @@ class ModuleContainer(threading.Thread):
self.checkpoint_saver.stop.set() self.checkpoint_saver.stop.set()
self.checkpoint_saver.join() self.checkpoint_saver.join()
logger.debug(f"Shutting down pools")
for pool in self.runtime.pools:
if pool.is_alive():
pool.shutdown()
logger.debug(f"Shutting down runtime") logger.debug(f"Shutting down runtime")
self.runtime.shutdown() self.runtime.shutdown()

@ -70,6 +70,8 @@ class PrioritizedTaskPool(TaskPoolBase):
self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False) self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0) self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
self.priority = float("inf"), float("inf") # (first task priority, first task timestamp) self.priority = float("inf"), float("inf") # (first task priority, first task timestamp)
self._stop = mp.Event()
if start: if start:
self.start() self.start()
@ -89,10 +91,14 @@ class PrioritizedTaskPool(TaskPoolBase):
self._prioritizer_thread.start() self._prioritizer_thread.start()
super().start() super().start()
def shutdown(self, timeout: float = 3):
    """Stop the pool gracefully, escalating to SIGTERM if it does not exit in time.

    :param timeout: seconds to wait for the prioritizer thread and for the pool
        process itself before force-terminating the process.
    """
    self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
    self._stop.set()  # Wakes up the pool process blocked in run()

    self._prioritizer_thread.join(timeout)
    self.join(timeout)
    if self.is_alive():
        logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
        self.terminate()
def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture: def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
"""Add task to this pool's queue, return Future for its output""" """Add task to this pool's queue, return Future for its output"""
@ -154,7 +160,7 @@ class PrioritizedTaskPool(TaskPoolBase):
task.future.set_exception(exception) task.future.set_exception(exception)
def run(self, *args, **kwargs):
    """Keep the pool process alive until shutdown() sets self._stop.

    Blocking on the shared event (rather than an unrelated fresh mp.Event)
    lets shutdown() actually wake this process and end it gracefully.
    """
    self._stop.wait()
@property @property
def empty(self): def empty(self):

Loading…
Cancel
Save