|
|
|
@ -9,7 +9,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind import get_logger
|
|
|
|
|
from hivemind.moe.server.task_pool import TaskPoolBase
|
|
|
|
|
from hivemind.utils.mpfuture import ALL_STATES, MPFuture
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
@ -27,7 +26,7 @@ class Task:
|
|
|
|
|
return self.future._uid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PrioritizedTaskPool(TaskPoolBase):
|
|
|
|
|
class PrioritizedTaskPool(threading.Thread):
|
|
|
|
|
"""
|
|
|
|
|
Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
|
|
|
|
|
returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
|
|
|
|
@ -57,52 +56,41 @@ class PrioritizedTaskPool(TaskPoolBase):
|
|
|
|
|
daemon=True,
|
|
|
|
|
start=False,
|
|
|
|
|
):
|
|
|
|
|
super().__init__(process_func, daemon=daemon, name=name)
|
|
|
|
|
super().__init__(daemon=daemon, name=name)
|
|
|
|
|
self.process_func = process_func
|
|
|
|
|
# the lower the priority is, the more urgent it is to process this pool
|
|
|
|
|
self._priority = mp.Value(ctypes.c_double, 1.0)
|
|
|
|
|
|
|
|
|
|
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
|
|
|
|
|
self.device = device
|
|
|
|
|
|
|
|
|
|
self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers
|
|
|
|
|
self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime
|
|
|
|
|
|
|
|
|
|
self._prioritizer_thread = threading.Thread(
|
|
|
|
|
name=self.name + "_prioritizer",
|
|
|
|
|
target=self._prioritize_tasks,
|
|
|
|
|
args=[self.submitted_tasks, self._ordered_tasks],
|
|
|
|
|
daemon=True,
|
|
|
|
|
)
|
|
|
|
|
self._dispatched_tasks = {}
|
|
|
|
|
self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
|
|
|
|
|
self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
|
|
|
|
|
self.priority = float("inf"), float("inf") # (first task priority, first task timestamp)
|
|
|
|
|
|
|
|
|
|
self._stop = mp.Event()
|
|
|
|
|
if start:
|
|
|
|
|
self.start()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
|
|
|
|
|
def run(self):
|
|
|
|
|
"""Read tasks from incoming queue and put them into a local priority queue"""
|
|
|
|
|
while True:
|
|
|
|
|
task = submitted_tasks.get()
|
|
|
|
|
task = self.submitted_tasks.get()
|
|
|
|
|
if task is None:
|
|
|
|
|
logger.debug("Shutting down prioritizer thread")
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
ordered_tasks.put(task, block=True)
|
|
|
|
|
self._ordered_tasks.put(task, block=True)
|
|
|
|
|
|
|
|
|
|
def start(self):
|
|
|
|
|
assert not self.is_alive() and not self._prioritizer_thread.is_alive()
|
|
|
|
|
self._prioritizer_thread.start()
|
|
|
|
|
super().start()
|
|
|
|
|
def terminate(self):
|
|
|
|
|
"""An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
|
|
|
|
|
self.shutdown()
|
|
|
|
|
|
|
|
|
|
def shutdown(self, timeout: float = 3):
|
|
|
|
|
self.submitted_tasks.put(None) # Shuts down self._prioritizer_thread
|
|
|
|
|
self._stop.set()
|
|
|
|
|
|
|
|
|
|
self.join(timeout)
|
|
|
|
|
if self.is_alive():
|
|
|
|
|
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
|
|
|
|
|
self.terminate()
|
|
|
|
|
def shutdown(self):
|
|
|
|
|
self.submitted_tasks.put(None) # Shuts down self.run()
|
|
|
|
|
|
|
|
|
|
def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
|
|
|
|
|
"""Add task to this pool's queue, return Future for its output"""
|
|
|
|
@ -163,9 +151,6 @@ class PrioritizedTaskPool(TaskPoolBase):
|
|
|
|
|
else:
|
|
|
|
|
task.future.set_exception(exception)
|
|
|
|
|
|
|
|
|
|
def run(self, *args, **kwargs):
|
|
|
|
|
self._stop.wait()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def empty(self):
|
|
|
|
|
return not self.batch_receiver.poll()
|
|
|
|
|