@@ -2,13 +2,15 @@ import ctypes
 import multiprocessing as mp
 import threading
 import time
+from concurrent.futures._base import PENDING
 from dataclasses import dataclass, field
 from queue import PriorityQueue
 from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import MPFuture, get_logger, use_hivemind_log_handler
+from hivemind import get_logger, use_hivemind_log_handler
 from hivemind.moe.server.task_pool import TaskPoolBase
+from hivemind.utils.mpfuture import ALL_STATES, MPFuture
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -102,7 +104,12 @@ class PrioritizedTaskPool(TaskPoolBase):
 
     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
-        task = Task(priority, time.monotonic(), MPFuture(), args)
+        future = MPFuture()
+        # Remove shmem from MPFuture. This disables the .cancel() feature but
+        # saves the server from "could not unlink the shared memory file" crashes during rebalancing
+        future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8)
+
+        task = Task(priority, time.monotonic(), future, args)
         if self.get_task_size(task) > self.max_batch_size:
            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
            task.future.set_exception(exc)
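
Review note, not part of the patch: the sketch below spells out the shmem-stripping trick from the second hunk in isolation. It assumes MPFuture keeps its cross-process state code in the _shared_state_code tensor, as the patched code implies; make_local_future is a hypothetical helper name for illustration, not hivemind API.

    import torch
    from concurrent.futures._base import PENDING
    from hivemind.utils.mpfuture import ALL_STATES, MPFuture

    def make_local_future() -> MPFuture:
        """Build an MPFuture whose state tensor lives in ordinary process memory."""
        future = MPFuture()
        # MPFuture normally backs its state code with a tensor in torch shared
        # memory, which materializes as a shared memory file. Swapping in a
        # plain local tensor that encodes PENDING means no such file is created,
        # so there is nothing to unlink when pools are torn down during
        # rebalancing; the trade-off, per the patch comment, is that .cancel()
        # and other cross-process state updates no longer reach this future.
        future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8)
        return future

    # The future still behaves normally within the server process, which is
    # all the patched submit_task relies on when it rejects oversized tasks:
    fut = make_local_future()
    fut.set_exception(ValueError("task too large"))  # surfaces via fut.exception()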