Remove no-op process in PrioritizedTaskPool (#484)

Please revert this if you ever need to make `PrioritizedTaskPool` a process again.
pull/486/head
Alexander Borzunov 8 months ago committed by GitHub
parent 26ebbfe8f0
commit 459933f846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,7 +9,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
import torch
from hivemind import get_logger
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.utils.mpfuture import ALL_STATES, MPFuture
logger = get_logger(__name__)
@ -27,7 +26,7 @@ class Task:
return self.future._uid
class PrioritizedTaskPool(TaskPoolBase):
class PrioritizedTaskPool(threading.Thread):
"""
Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
@ -57,52 +56,41 @@ class PrioritizedTaskPool(TaskPoolBase):
daemon=True,
start=False,
):
super().__init__(process_func, daemon=daemon, name=name)
super().__init__(daemon=daemon, name=name)
self.process_func = process_func
# the lower the priority is, the more urgent it is to process this pool
self._priority = mp.Value(ctypes.c_double, 1.0)
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.device = device
self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers
self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime
self._prioritizer_thread = threading.Thread(
name=self.name + "_prioritizer",
target=self._prioritize_tasks,
args=[self.submitted_tasks, self._ordered_tasks],
daemon=True,
)
self._dispatched_tasks = {}
self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
self.priority = float("inf"), float("inf") # (first task priority, first task timestamp)
self._stop = mp.Event()
if start:
self.start()
@staticmethod
def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
def run(self):
"""Read tasks from incoming queue and put them into a local priority queue"""
while True:
task = submitted_tasks.get()
task = self.submitted_tasks.get()
if task is None:
logger.debug("Shutting down prioritizer thread")
break
ordered_tasks.put(task, block=True)
def start(self):
assert not self.is_alive() and not self._prioritizer_thread.is_alive()
self._prioritizer_thread.start()
super().start()
self._ordered_tasks.put(task, block=True)
def shutdown(self, timeout: float = 3):
self.submitted_tasks.put(None) # Shuts down self._prioritizer_thread
self._stop.set()
def terminate(self):
"""An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
self.shutdown()
self.join(timeout)
if self.is_alive():
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
self.terminate()
def shutdown(self):
self.submitted_tasks.put(None) # Shuts down self.run()
def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
"""Add task to this pool's queue, return Future for its output"""
@ -163,9 +151,6 @@ class PrioritizedTaskPool(TaskPoolBase):
else:
task.future.set_exception(exception)
def run(self, *args, **kwargs):
self._stop.wait()
@property
def empty(self):
return not self.batch_receiver.poll()

Loading…
Cancel
Save