Remove no-op process in PrioritizedTaskPool (#484)

Please revert this if you ever need to make `PrioritizedTaskPool` a process again.
pull/486/head
Alexander Borzunov 9 months ago committed by GitHub
parent 26ebbfe8f0
commit 459933f846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,7 +9,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
import torch import torch
from hivemind import get_logger from hivemind import get_logger
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.utils.mpfuture import ALL_STATES, MPFuture from hivemind.utils.mpfuture import ALL_STATES, MPFuture
logger = get_logger(__name__) logger = get_logger(__name__)
@ -27,7 +26,7 @@ class Task:
return self.future._uid return self.future._uid
class PrioritizedTaskPool(TaskPoolBase): class PrioritizedTaskPool(threading.Thread):
""" """
Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
returns results (or exception) to the corresponding ConnectionHandler. Runs a background process. returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
@ -57,52 +56,41 @@ class PrioritizedTaskPool(TaskPoolBase):
daemon=True, daemon=True,
start=False, start=False,
): ):
super().__init__(process_func, daemon=daemon, name=name) super().__init__(daemon=daemon, name=name)
self.process_func = process_func
# the lower the priority is, the more urgent it is to process this pool
self._priority = mp.Value(ctypes.c_double, 1.0)
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.device = device self.device = device
self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers
self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime
self._prioritizer_thread = threading.Thread(
name=self.name + "_prioritizer",
target=self._prioritize_tasks,
args=[self.submitted_tasks, self._ordered_tasks],
daemon=True,
)
self._dispatched_tasks = {} self._dispatched_tasks = {}
self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False) self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0) self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
self.priority = float("inf"), float("inf") # (first task priority, first task timestamp) self.priority = float("inf"), float("inf") # (first task priority, first task timestamp)
self._stop = mp.Event()
if start: if start:
self.start() self.start()
@staticmethod def run(self):
def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
"""Read tasks from incoming queue and put them into a local priority queue""" """Read tasks from incoming queue and put them into a local priority queue"""
while True: while True:
task = submitted_tasks.get() task = self.submitted_tasks.get()
if task is None: if task is None:
logger.debug("Shutting down prioritizer thread") logger.debug("Shutting down prioritizer thread")
break break
ordered_tasks.put(task, block=True) self._ordered_tasks.put(task, block=True)
def start(self):
assert not self.is_alive() and not self._prioritizer_thread.is_alive()
self._prioritizer_thread.start()
super().start()
def shutdown(self, timeout: float = 3): def terminate(self):
self.submitted_tasks.put(None) # Shuts down self._prioritizer_thread """An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
self._stop.set() self.shutdown()
self.join(timeout) def shutdown(self):
if self.is_alive(): self.submitted_tasks.put(None) # Shuts down self.run()
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
self.terminate()
def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture: def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
"""Add task to this pool's queue, return Future for its output""" """Add task to this pool's queue, return Future for its output"""
@ -163,9 +151,6 @@ class PrioritizedTaskPool(TaskPoolBase):
else: else:
task.future.set_exception(exception) task.future.set_exception(exception)
def run(self, *args, **kwargs):
self._stop.wait()
@property @property
def empty(self): def empty(self):
return not self.batch_receiver.poll() return not self.batch_receiver.poll()

Loading…
Cancel
Save