Priority tasks (#47)
* priority in handlers and backend pools
* simple points system on server side
* prioritize task in handler before submitting it
* fix tests
* s/expert/block/g

Co-authored-by: justheuristic <justheuristic@gmail.com>
parent 892d18fea7
commit 50535a8435
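For context, here is a minimal sketch of how the new pieces are meant to fit together. The wiring below is illustrative, not code from this PR; it assumes DummyTaskPrioritizer lives in a src.server.task_prioritizer module (the pool and runtime paths match the test imports below):

import torch
from src.server.task_pool import PrioritizedTaskPool
from src.server.task_prioritizer import DummyTaskPrioritizer

def square(x):
    return (x ** 2,)

# a prioritizer turns client points into a priority value (lower = more urgent),
# and the task goes into a pool that the Runtime drains in priority order
pool = PrioritizedTaskPool(square, max_batch_size=1024, name="forward", start=True)
prioritizer = DummyTaskPrioritizer()

x = torch.ones(1, 8)
priority = prioritizer.prioritize(x, points=0.0)
future = pool.submit_task(x, priority=priority)
# a Runtime (see src/server/runtime.py below) would normally pick up the task
# and fulfill the future; without one, future.result() would block forever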
@@ -0,0 +1,156 @@
"""
Utility functions that call RPC forward or backward on a single remote server
"""
import asyncio
from typing import Iterable, List, Sequence, Tuple

import torch
from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
from hivemind.p2p import StubBase
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import ModuleUID, RPCInfo


async def run_remote_forward(
    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
) -> Tuple[torch.Tensor, ...]:
    """
    Serializes input tensors and calls "rpc_forward" on a remote server.
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
    but without the RemoteExpertWorker.run_coroutine() call that would lead to a deadlock here.
    """

    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    # detach to avoid pickling the computation graph
    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}

    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
    forward_inputs = (inputs, kwargs)

    # Modify forward_schema to support prompts
    args_schema, kwargs_schema = rpc_info["forward_schema"]
    # TODO: remove this assert when we support an arbitrary number of input tensors
    assert len(args_schema) == 1 and len(inputs) == 2
    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)

    if not nested_compare(forward_inputs, forward_schema_with_prompts):
        raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

    forward_inputs = nested_flatten(forward_inputs)
    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)

    # Asynchronous serialization
    loop = asyncio.get_running_loop()
    serialized_tensors = await asyncio.gather(
        *(
            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
        )
    )

    # call RPC on the remote server, streaming if the payload is too large for a unary call
    size = sum(t.element_size() * t.nelement() for t in inputs)
    if size > MAX_UNARY_PAYLOAD_SIZE:
        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
    else:
        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)

    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])


async def _forward_stream(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))

    outputs = await stub.rpc_forward_stream(
        amap_in_executor(
            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
            iter_as_aiter(split),
        ),
    )

    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
    return await deserialize_tensor_stream(tensors_stream)


async def _forward_unary(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
    )
    return [deserialize_torch_tensor(t) for t in outputs.tensors]


async def _backward_stream(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))

    grad_inputs = await stub.rpc_backward_stream(
        amap_in_executor(
            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
            iter_as_aiter(split),
        ),
    )
    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
    return await deserialize_tensor_stream(tensors_stream)


async def run_remote_backward(
    uid: ModuleUID,
    stub: StubBase,
    rpc_info: RPCInfo,
    inputs: torch.Tensor,
    grad_outputs: List[torch.Tensor],
    *extra_tensors: torch.Tensor,
    **kwargs,
) -> Sequence[torch.Tensor]:
    """
    Serializes grad outputs and calls "rpc_backward" on a remote server.
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
    but without the RemoteExpertWorker.run_coroutine() call that would lead to a deadlock here.
    """

    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))

    # Modify forward_schema to support prompts
    args_schema, kwargs_schema = rpc_info["forward_schema"]
    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
    # TODO: generalize this
    prompts_schema = next(iter(args_schema))
    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))

    # Asynchronous serialization
    loop = asyncio.get_running_loop()
    serialized_tensors = await asyncio.gather(
        *(
            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
        )
    )

    size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
    if size > MAX_UNARY_PAYLOAD_SIZE:
        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
    else:
        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)

    return deserialized_grad_inputs


async def _backward_unary(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
    )
    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
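To make the unary-versus-streaming dispatch above concrete, here is a small standalone sketch. It uses the same hivemind constants imported in this file; the tensor shape is just an illustrative value chosen to exceed the unary limit under typical defaults:

import torch
from hivemind import serialize_torch_tensor
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
from hivemind.utils.streaming import split_for_streaming

tensor = torch.randn(1024, 1024)  # ~4 MiB of float32 data
size = tensor.element_size() * tensor.nelement()
serialized = serialize_torch_tensor(tensor)

if size > MAX_UNARY_PAYLOAD_SIZE:
    # oversized payloads are cut into message-sized chunks and sent over a streaming RPC
    parts = list(split_for_streaming(serialized, DEFAULT_MAX_MSG_SIZE))
    print(f"streaming: {len(parts)} parts")
else:
    print("unary: single message")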
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod

from hivemind.proto.runtime_pb2 import ExpertRequest


class SpendingPolicyBase(ABC):
    @abstractmethod
    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
        pass


class NoSpendingPolicy(SpendingPolicyBase):
    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
        return 0.0
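A custom policy only needs to implement get_points. As a hypothetical illustration (not part of this PR), a policy that charges one point per mebibyte of serialized payload, using the buffer bytes field of hivemind's runtime_pb2.Tensor:

class PayloadSizeSpendingPolicy(SpendingPolicyBase):
    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
        # charge one point per mebibyte of serialized tensor data in the request
        payload_bytes = sum(len(tensor.buffer) for tensor in request.tensors)
        return payload_bytes / 2**20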
@@ -0,0 +1,198 @@
import multiprocessing as mp
import multiprocessing.pool
import threading
from collections import defaultdict
from itertools import chain
from queue import SimpleQueue
from selectors import EVENT_READ, DefaultSelector
from statistics import mean
from time import time
from typing import Dict, NamedTuple, Optional

import torch
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from prefetch_generator import BackgroundGenerator

logger = get_logger(__name__)


class Runtime(threading.Thread):
    """
    A group of processes that processes incoming requests for multiple module backends on a shared device.
    Runtime is usually created and managed by Server, humans need not apply.

    For debugging, you can start runtime manually with .start() or .run()

    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
    >>> runtime = Runtime(module_backends)
    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
    >>> runtime.ready.wait()  # wait for runtime to load all blocks on the device and create request pools
    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
    >>> print("Returned:", future.result())
    >>> runtime.shutdown()

    :param module_backends: a dict [block uid -> ModuleBackend]
    :param prefetch_batches: form up to this many batches in advance
    :param sender_threads: dispatch outputs from finished batches using this many asynchronous threads
    :param device: if specified, moves all blocks and data to this device via .to(device=device).
      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
    :param stats_report_interval: interval to collect and log statistics about runtime performance
    """

    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

    def __init__(
        self,
        module_backends: Dict[str, ModuleBackend],
        prefetch_batches: int = 1,
        sender_threads: int = 1,
        device: torch.device = None,
        stats_report_interval: Optional[int] = None,
    ):
        super().__init__()
        self.module_backends = module_backends
        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
        self.shutdown_trigger = mp.Event()
        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches

        self.stats_report_interval = stats_report_interval
        if self.stats_report_interval is not None:
            self.stats_reporter = StatsReporter(self.stats_report_interval)

    def run(self):
        for pool in self.pools:
            if not pool.is_alive():
                pool.start()
        if self.device is not None:
            for backend in self.module_backends.values():
                backend.module.to(self.device)

        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                self.ready.set()
                if self.stats_report_interval is not None:
                    self.stats_reporter.start()
                logger.info("Started")

                batch_iterator = self.iterate_minibatches_from_pools()
                if self.prefetch_batches > 0:
                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)

                for pool, batch_index, batch in batch_iterator:
                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

                    start = time()
                    try:
                        outputs = pool.process_func(*batch)
                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])

                        batch_processing_time = time() - start

                        batch_size = outputs[0].size(0)
                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

                        if self.stats_report_interval is not None:
                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

                    except KeyboardInterrupt:
                        raise
                    except BaseException as exception:
                        logger.exception(f"Caught {exception}, attempting to recover")
                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])

            finally:
                if not self.shutdown_trigger.is_set():
                    self.shutdown()

    def shutdown(self):
        """Gracefully terminate a running runtime."""
        logger.info("Shutting down")
        self.ready.clear()

        if self.stats_report_interval is not None:
            self.stats_reporter.stop.set()
            self.stats_reporter.join()

        logger.debug("Terminating pools")
        for pool in self.pools:
            if pool.is_alive():
                pool.shutdown()
        logger.debug("Pools terminated")

        # trigger background thread to shut down
        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
        self.shutdown_trigger.set()

    def iterate_minibatches_from_pools(self, timeout=None):
        """
        Chooses a pool according to priority, then copies the exposed batch and frees the buffer
        """
        with DefaultSelector() as selector:
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)
            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)

            while True:
                # wait until at least one batch_receiver becomes available
                logger.debug("Waiting for inputs from task pools")
                ready_fds = selector.select()
                ready_objects = {key.data for (key, events) in ready_fds}
                if self.SHUTDOWN_TRIGGER in ready_objects:
                    break  # someone asked us to shutdown, break from the loop

                logger.debug("Choosing the pool with the most urgent priority")
                pool = min(ready_objects, key=lambda pool: pool.priority)

                logger.debug(f"Loading batch from {pool.name}")
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                logger.debug(f"Loaded batch from {pool.name}")
                yield pool, batch_index, batch_tensors


BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))


class StatsReporter(threading.Thread):
    def __init__(self, report_interval: int):
        super().__init__()
        self.report_interval = report_interval
        self.stop = threading.Event()
        self.stats_queue = SimpleQueue()

    def run(self):
        while not self.stop.wait(self.report_interval):
            pool_batch_stats = defaultdict(list)
            while not self.stats_queue.empty():
                pool_uid, batch_stats = self.stats_queue.get()
                pool_batch_stats[pool_uid].append(batch_stats)

            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
            logger.info(f"Processed {total_processed_batches} batches in the last {self.report_interval} seconds:")
            for pool_uid, pool_stats in pool_batch_stats.items():
                total_batches = len(pool_stats)
                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
                batches_to_time = total_batches / total_time
                batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")

                examples_to_time = total_examples / total_time
                example_performance = f"{examples_to_time:.2f} " + (
                    "examples/s" if examples_to_time > 1 else "s/example"
                )

                logger.info(
                    f"{pool_uid}: "
                    f"{total_batches} batches ({batch_performance}), "
                    f"{total_examples} examples ({example_performance}), "
                    f"avg batch size {avg_batch_size:.2f}"
                )

    def report_stats(self, pool_uid, batch_size, processing_time):
        batch_stats = BatchStats(batch_size, processing_time)
        self.stats_queue.put_nowait((pool_uid, batch_stats))
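The pool-selection loop above multiplexes one pipe per pool through a selector. Below is a minimal standalone sketch of that pattern (Unix-only, since selectors need objects with a fileno(), which multiprocessing pipe connections provide there):

import multiprocessing as mp
from selectors import EVENT_READ, DefaultSelector

recv_a, send_a = mp.Pipe(duplex=False)
recv_b, send_b = mp.Pipe(duplex=False)

with DefaultSelector() as selector:
    # register each pipe's receiving end, tagging it with an identifier
    selector.register(recv_a, EVENT_READ, "pool_A")
    selector.register(recv_b, EVENT_READ, "pool_B")

    send_b.send(None)  # pretend pool_B announced a new batch
    ready = {key.data for key, _ in selector.select()}  # blocks until some pipe is readable
    assert ready == {"pool_B"}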
@@ -0,0 +1,175 @@
import ctypes
import multiprocessing as mp
import threading
import time
from dataclasses import dataclass, field
from queue import PriorityQueue
from typing import Any, Generator, List, Optional, Sequence, Tuple

import torch
from hivemind import MPFuture, get_logger, use_hivemind_log_handler
from hivemind.moe.server.task_pool import TaskPoolBase

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


@dataclass(order=True, frozen=True)
class Task:
    priority: float
    time_submitted: float
    future: MPFuture = field(compare=False)
    args: Sequence[torch.Tensor] = field(compare=False)

    @property
    def uid(self) -> int:
        return self.future._uid


class PrioritizedTaskPool(TaskPoolBase):
    """
    Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
    returns results (or exceptions) to the corresponding ConnectionHandler. Runs a background process.
    A single PrioritizedTaskPool services a specific function (e.g. layer1.forward, layer2.forward or layer1.backward)

    :note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches.
      This would require grouping requests of different lengths.

    :param process_func: function to be applied to every formed batch; called by Runtime
      Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
    :param max_batch_size: process at most this many inputs in a batch (a task contains one or several inputs)
      Measured in the total number of tokens (i.e. batch size * sequence length)

    :param name: pool name, used for logging
    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
    :param start: if True, start automatically at the end of __init__
    """

    def __init__(
        self,
        process_func: callable,
        max_batch_size: int,
        name: str,
        min_batch_size=1,
        daemon=True,
        start=False,
    ):
        super().__init__(process_func, daemon=daemon, name=name)
        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size

        self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
        self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime

        self._prioritizer_thread = threading.Thread(
            name=self.name + "_prioritizer",
            target=self._prioritize_tasks,
            args=[self.submitted_tasks, self._ordered_tasks],
            daemon=True,
        )
        self._dispatched_tasks = {}
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
        self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
        self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
        if start:
            self.start()

    @staticmethod
    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
        """Read tasks from the incoming queue and put them into a local priority queue"""
        while True:
            task = submitted_tasks.get()
            if task is None:
                logger.debug("Shutting down prioritizer thread")
                break

            ordered_tasks.put(task, block=True)

    def start(self):
        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
        self._prioritizer_thread.start()
        super().start()

    def shutdown(self, timeout: Optional[float] = None):
        self.submitted_tasks.put(None)
        self.terminate()
        self._prioritizer_thread.join(timeout)

    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
        """Add a task to this pool's queue, return a Future for its output"""
        task = Task(priority, time.monotonic(), MPFuture(), args)
        if self.get_task_size(task) > self.max_batch_size:
            exc = ValueError(f"Task size is greater than max_batch_size ({self.max_batch_size}), it can't be processed")
            task.future.set_exception(exc)
        else:
            self.submitted_tasks.put(task)
            self.batch_sender.send(None)  # use this pipe to count the number of unfinished batches
            if (task.priority, task.time_submitted) < self.priority:
                self.priority = (task.priority, task.time_submitted)
        return task.future

    def get_task_size(self, task: Task) -> int:
        """compute task processing complexity; defaults to the total number of tokens"""
        if task.args and task.args[0].ndim >= 2:
            return task.args[0].shape[0] * task.args[0].shape[1]
        return 1

    def load_batch_to_runtime(
        self, timeout: Optional[float] = None, device: Optional[torch.device] = None
    ) -> Tuple[Any, List[torch.Tensor]]:
        """receive the next batch of arrays"""
        task = self._ordered_tasks.get(block=True, timeout=timeout)
        batch_inputs = [
            tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
        ]
        self._dispatched_tasks[task.uid] = task
        self.batch_receiver.recv()  # reduce the number of active batches
        if not self._ordered_tasks.empty():
            first_remaining_task: Task = self._ordered_tasks.queue[0]
            self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
        return task.uid, batch_inputs

    def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
        """send results for a processed batch, previously loaded through load_batch_to_runtime"""
        batch_outputs = [
            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
            for tensor in batch_outputs
        ]

        task = self._dispatched_tasks.pop(uid, None)
        if task is None:
            logger.error(f"Internal error: task with index {uid} is missing from the dictionary; could not set result")
        else:
            task.future.set_result(batch_outputs)

    def send_exception_from_runtime(self, uid: int, exception: BaseException):
        task = self._dispatched_tasks.pop(uid, None)
        if task is None:
            logger.error(
                f"Internal error: task with index {uid} is missing from the dictionary; "
                f"could not set exception {exception}"
            )
        else:
            task.future.set_exception(exception)

    def run(self, *args, **kwargs):
        mp.Event().wait()

    @property
    def empty(self):
        return not self.batch_receiver.poll()

    @property
    def priority(self) -> Tuple[float, float]:
        """The priority of this pool equals the (priority, timestamp) of the most important task in it."""
        return float(self._priority.value), float(self._oldest_undispatched_timestamp.value)

    @priority.setter
    def priority(self, item: Tuple[float, float]):
        assert len(item) == 2
        self._priority.value = float(item[0])  # note: self._priority is the mp.Value inherited from TaskPoolBase
        self._oldest_undispatched_timestamp.value = float(item[1])

    def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
        raise NotImplementedError()
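Since Task is a frozen dataclass with order=True and its future/args fields excluded from comparison, tasks compare by (priority, time_submitted): the queue pops the lowest priority value first and breaks ties by submission time. A quick sketch, assuming Task is importable from src.server.task_pool as in the tests below:

import time
from queue import PriorityQueue

from hivemind import MPFuture
from src.server.task_pool import Task

queue = PriorityQueue()
for priority in (2.0, 0.0, 1.0):
    queue.put(Task(priority, time.monotonic(), MPFuture(), args=()))

# the lowest priority value comes out first; equal priorities fall back to submission time
assert [queue.get().priority for _ in range(3)] == [0.0, 1.0, 2.0]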
@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod

import torch
from hivemind.moe.server.task_pool import Task


class TaskPrioritizerBase(ABC):
    """Abstract class for TaskPrioritizer whose responsibility is to evaluate task priority"""

    @abstractmethod
    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        """Evaluates task value by the amount of points given, task input and additional kwargs. Lower priority is better"""
        pass


class DummyTaskPrioritizer(TaskPrioritizerBase):
    """Simple implementation of TaskPrioritizer which gives constant zero priority to every task"""

    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        return 0.0
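The "simple points system" from the commit message can plug in here. As a hypothetical illustration (not shipped in this PR), a prioritizer that maps more spent points to a more urgent priority value:

class PointsBasedPrioritizer(TaskPrioritizerBase):
    """Gives tasks that spent more points a lower (i.e. more urgent) priority value"""

    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        return -points  # lower priority is better, so more points => processed sooner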
@@ -0,0 +1,71 @@
import multiprocessing as mp
import time

import pytest
import torch

from src.server.runtime import Runtime
from src.server.task_pool import PrioritizedTaskPool


@pytest.mark.forked
def test_priority_pools():
    outputs_queue = mp.SimpleQueue()
    results_valid = mp.Event()

    def dummy_pool_func(x):
        time.sleep(0.1)
        y = x**2
        outputs_queue.put((x, y))
        return (y,)

    class DummyBackend:
        def __init__(self, pools):
            self.pools = pools

        def get_pools(self):
            return self.pools

    pools = (
        PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1),
        PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
    )

    runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
    runtime.start()

    def process_tasks():
        futures = []
        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
        time.sleep(0.01)
        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
        for i, f in enumerate(futures):
            assert f.result()[0].item() == i**2
        results_valid.set()

    proc = mp.Process(target=process_tasks)
    proc.start()
    proc.join()
    assert results_valid.is_set()

    ordered_outputs = []
    while not outputs_queue.empty():
        ordered_outputs.append(outputs_queue.get()[0].item())

    assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]
    # 0 - first batch is loaded immediately, before everything else
    # 5 - highest priority task overall
    # 1 - first of several tasks with equal lowest priority (1)
    # 2 - second earliest task with priority 1, fetched from pool B
    # 6 - third earliest task with priority 1, fetched from pool A again
    # 8 - last priority-1 task, pool B
    # 3 - task with priority 2 from pool A
    # 4 - task with priority 10 from pool A
    # 7 - task with priority 11 from pool B