From 459933f84664af436b892468ba2aa184b1369306 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 30 Aug 2023 06:07:04 +0400 Subject: [PATCH] Remove no-op process in PrioritizedTaskPool (#484) Please revert this if you ever need to make `PrioritizedTaskPool` a process again. --- src/petals/server/task_pool.py | 43 +++++++++++----------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/src/petals/server/task_pool.py b/src/petals/server/task_pool.py index e027d52..94bad79 100644 --- a/src/petals/server/task_pool.py +++ b/src/petals/server/task_pool.py @@ -9,7 +9,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union import torch from hivemind import get_logger -from hivemind.moe.server.task_pool import TaskPoolBase from hivemind.utils.mpfuture import ALL_STATES, MPFuture logger = get_logger(__name__) @@ -27,7 +26,7 @@ class Task: return self.future._uid -class PrioritizedTaskPool(TaskPoolBase): +class PrioritizedTaskPool(threading.Thread): """ Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then returns results (or exception) to the corresponding ConnectionHandler. Runs a background process. @@ -57,52 +56,41 @@ class PrioritizedTaskPool(TaskPoolBase): daemon=True, start=False, ): - super().__init__(process_func, daemon=daemon, name=name) + super().__init__(daemon=daemon, name=name) + self.process_func = process_func + # the lower the priority is, the more urgent it is to process this pool + self._priority = mp.Value(ctypes.c_double, 1.0) + self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size self.device = device self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime - self._prioritizer_thread = threading.Thread( - name=self.name + "_prioritizer", - target=self._prioritize_tasks, - args=[self.submitted_tasks, self._ordered_tasks], - daemon=True, - ) self._dispatched_tasks = {} self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False) self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0) self.priority = float("inf"), float("inf") # (first task priority, first task timestamp) - self._stop = mp.Event() if start: self.start() - @staticmethod - def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue): + def run(self): """Read tasks from incoming queue and put them into a local priority queue""" while True: - task = submitted_tasks.get() + task = self.submitted_tasks.get() if task is None: logger.debug("Shutting down prioritizer thread") break - ordered_tasks.put(task, block=True) - - def start(self): - assert not self.is_alive() and not self._prioritizer_thread.is_alive() - self._prioritizer_thread.start() - super().start() + self._ordered_tasks.put(task, block=True) - def shutdown(self, timeout: float = 3): - self.submitted_tasks.put(None) # Shuts down self._prioritizer_thread - self._stop.set() + def terminate(self): + """An alias for hivemind.Runtime that assumes that each TaskPool is a process""" + self.shutdown() - self.join(timeout) - if self.is_alive(): - logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM") - self.terminate() + def shutdown(self): + self.submitted_tasks.put(None) # Shuts down self.run() def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture: """Add task to this pool's queue, return Future for its output""" @@ -163,9 +151,6 @@ class PrioritizedTaskPool(TaskPoolBase): else: task.future.set_exception(exception) - def run(self, *args, **kwargs): - self._stop.wait() - @property def empty(self): return not self.batch_receiver.poll()