Fix "Too many open files" during rebalancing

pull/83/head
Aleksandr Borzunov 2 years ago
parent f64eb3a665
commit 75ed6ac49c

@@ -64,6 +64,9 @@ def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModule
 def should_choose_other_blocks(
     local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
 ) -> bool:
+    if min_balance_quality > 1.0:
+        return True  # Forces rebalancing on each check (may be used for debugging purposes)
+
     spans, throughputs = _compute_spans(module_infos)
     initial_throughput = throughputs.min()
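
For context: the balance quality compares the swarm's current bottleneck throughput against the best throughput achievable after a reshuffle, so it cannot exceed 1.0. A threshold above 1.0 can therefore never be satisfied, which is why it forces a move on every check. A hedged sketch of that comparison (variable names assumed, not Petals' exact code):

# balance_quality is a ratio of achieved to optimal throughput, so it is <= 1.0
balance_quality = initial_throughput / optimal_throughput
needs_rebalancing = balance_quality < min_balance_quality  # always True if the threshold is > 1.0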

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import gc
 import multiprocessing as mp
 import random
 import threading
@@ -7,6 +8,7 @@ import time
 from typing import Dict, List, Optional, Sequence, Union

 import numpy as np
+import psutil
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -187,6 +189,16 @@ class Server(threading.Thread):
         finally:
             self.module_container.shutdown()
+            self._clean_memory_and_fds()
+
+    def _clean_memory_and_fds(self):
+        del self.module_container
+        gc.collect()  # In particular, this closes unused file descriptors
+
+        cur_proc = psutil.Process()
+        num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
+        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")

     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
             return self.strict_block_indices
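
The fd audit above uses psutil calls that exist as written: Process(), children(recursive=True), and num_fds() (the latter is POSIX-only). A self-contained sketch of the same measurement, usable to verify a cleanup outside of Petals:

import gc
import psutil

def total_open_fds() -> int:
    # Sum open file descriptors across this process and all of its children (POSIX only).
    cur_proc = psutil.Process()
    return sum(proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True))

before = total_open_fds()
gc.collect()  # dropping unreferenced objects also closes the descriptors they held
print(f"open fds: {before} -> {total_open_fds()}")
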
@@ -418,6 +430,11 @@ class ModuleContainer(threading.Thread):
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()

+        logger.debug(f"Shutting down pools")
+        for pool in self.runtime.pools:
+            if pool.is_alive():
+                pool.shutdown()
+
         logger.debug(f"Shutting down runtime")
         self.runtime.shutdown()

@@ -70,6 +70,8 @@ class PrioritizedTaskPool(TaskPoolBase):
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
         self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
         self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
+
+        self._stop = mp.Event()

         if start:
             self.start()
@@ -89,10 +91,14 @@ class PrioritizedTaskPool(TaskPoolBase):
             self._prioritizer_thread.start()
         super().start()

-    def shutdown(self, timeout: Optional[float] = None):
-        self.submitted_tasks.put(None)
-        self.terminate()
-        self._prioritizer_thread.join(timeout)
+    def shutdown(self, timeout: float = 3):
+        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
+        self._stop.set()
+
+        self.join(timeout)
+        if self.is_alive():
+            logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
+            self.terminate()

     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
@@ -154,7 +160,7 @@
             task.future.set_exception(exception)

     def run(self, *args, **kwargs):
-        mp.Event().wait()
+        self._stop.wait()

     @property
     def empty(self):
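
Taken together, the _stop event, the new shutdown(), and the run() change implement a standard graceful-stop pattern for a multiprocessing worker: block run() on a shared Event (the old mp.Event().wait() created a fresh Event that nothing could ever set), request an exit by setting it, and escalate to SIGTERM only on timeout. A minimal self-contained sketch of the pattern (illustrative, not Petals code):

import multiprocessing as mp

class StoppableWorker(mp.Process):
    def __init__(self):
        super().__init__(daemon=True)
        self._stop = mp.Event()  # shared between parent and child

    def run(self):
        self._stop.wait()  # the child idles until the parent requests shutdown

    def shutdown(self, timeout: float = 3):
        self._stop.set()      # request a graceful exit
        self.join(timeout)
        if self.is_alive():   # escalate only if the child missed the deadline
            self.terminate()  # sends SIGTERM

if __name__ == "__main__":
    worker = StoppableWorker()
    worker.start()
    worker.shutdown()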
