Priority tasks (#47)

* priority in handlers and backend pools
* simple points system on server side
* prioritize tasks in the handler before submitting them
* fix tests
* s/expert/block/g

Co-authored-by: justheuristic <justheuristic@gmail.com>
Pavel Samygin authored 2 years ago, committed by GitHub
parent 892d18fea7
commit 50535a8435
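
Taken together, the changes below wire client-supplied points into the serving pipeline. The following is an illustrative sketch assembled from the names introduced in this commit (the concrete numbers, such as 1024 tokens and 4 backends, are made up for the example), not an excerpt from the diff:

from hivemind import MSGPackSerializer

from src.server.task_prioritizer import DummyTaskPrioritizer

# 1. The client packs `points` next to `max_length` into the request metadata
#    (see RemoteTransformerBlockInferenceSession below).
serialized_metadata = MSGPackSerializer.dumps(dict(max_length=1024, points=8))

# 2. TransformerConnectionHandler deserializes the metadata and converts the points
#    into a priority via its TaskPrioritizerBase instance.
metadata = MSGPackSerializer.loads(serialized_metadata)
points = metadata.get("points", 0)
priority = DummyTaskPrioritizer().prioritize(points=points / 4, type="forward")

# 3. The handler submits the task to the backend's PrioritizedTaskPool, e.g.
#    backend.forward_pool.submit_task(hidden_states, priority=priority);
#    Runtime then always drains the pool whose first task has the smallest priority value.
print(priority)  # DummyTaskPrioritizer assigns 0.0 to every task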

@@ -31,15 +31,19 @@ def main():
parser.add_argument('--num_handlers', type=int, default=8, required=False,
help='server will use this many processes to handle incoming requests')
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all expert operations')
help='Minimum required batch size for all operations (in total tokens)')
parser.add_argument('--max_batch_size', type=int, default=16384,
help='The total number of tokens in the same batch will not exceed this value')
parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
help='Pre-form this many subsequent batches while GPU is processing the current one')
parser.add_argument('--sender_threads', type=int, default=1, required=False,
help='Use this many threads to pass results/exceptions from Runtime to Pools')
parser.add_argument('--inference_max_length', type=int, default=16384,
help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
parser.add_argument('--device', type=str, default=None, required=False,
help='all experts will use this device in torch notation; default: cuda if available else cpu')
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
@@ -58,7 +62,7 @@ def main():
'on the first run and uses these estimates for future runs. '
'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
parser.add_argument('--update_period', type=float, required=False, default=30,
help='Server will report experts to DHT once in this many seconds')
help='Server will report blocks to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
help='DHT entries will expire after this many seconds')
parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],

@@ -2,3 +2,4 @@ from src.client.inference_session import RemoteSequentialInferenceSession, Remot
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager
from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase

@@ -43,6 +43,7 @@ class RemoteTransformerBlockInferenceSession:
outputs_aiter: AsyncIterator,
*,
max_length: int,
points: int = 0,
):
self.uid, self.rpc_info = uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
@@ -50,7 +51,7 @@ class RemoteTransformerBlockInferenceSession:
# using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length))
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
self.stepped = False
self.closed = False

@@ -0,0 +1,156 @@
"""
Utility functions that call RPC forward or backward on a single remote server
"""
import asyncio
from typing import Iterable, List, Sequence, Tuple
import torch
from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
from hivemind.p2p import StubBase
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming
from src.data_structures import ModuleUID, RPCInfo
async def run_remote_forward(
uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
) -> Tuple[torch.Tensor, ...]:
"""
Serializes input tensors and calls "rpc_forward" on a remote server.
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
# detach to avoid pickling the computation graph
assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (inputs, kwargs)
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
# TODO: rm this assert when support arbitrary number of input tensors
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
if not nested_compare(forward_inputs, forward_schema_with_prompts):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
forward_inputs = nested_flatten(forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
)
)
# call RPC on remote server
size = sum(t.element_size() * t.nelement() for t in inputs)
if size > MAX_UNARY_PAYLOAD_SIZE:
deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
else:
deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
outputs = await stub.rpc_forward_stream(
amap_in_executor(
lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
iter_as_aiter(split),
),
)
tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
return await deserialize_tensor_stream(tensors_stream)
async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
)
return [deserialize_torch_tensor(t) for t in outputs.tensors]
async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
grad_inputs = await stub.rpc_backward_stream(
amap_in_executor(
lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
iter_as_aiter(split),
),
)
tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
return await deserialize_tensor_stream(tensors_stream)
async def run_remote_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
inputs: torch.Tensor,
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
**kwargs,
) -> Sequence[torch.Tensor]:
"""
Serializes grad outputs and calls "rpc_backward" on a remote server.
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
# TODO generalize this
prompts_schema = next(iter(args_schema))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
)
)
size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
if size > MAX_UNARY_PAYLOAD_SIZE:
deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
else:
deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
return deserialized_grad_inputs
async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
)
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
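
Whether forward/backward goes over the unary or the streaming RPC is decided purely by payload size, as in the branches above. A rough standalone illustration of the threshold check (the tensor shape is an arbitrary example; MAX_UNARY_PAYLOAD_SIZE is defined by hivemind and its exact value depends on the hivemind version):

import torch
from hivemind.p2p.p2p_daemon_bindings.control import MAX_UNARY_PAYLOAD_SIZE

# batch x sequence x hidden size, e.g. one long sequence of fp32 activations
hidden_states = torch.empty(1, 2048, 14336, dtype=torch.float32)
payload_bytes = hidden_states.element_size() * hidden_states.nelement()

# run_remote_forward/run_remote_backward apply the same comparison to pick the RPC flavor
print("streaming" if payload_bytes > MAX_UNARY_PAYLOAD_SIZE else "unary")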

@@ -9,6 +9,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src.client.spending_policy import NoSpendingPolicy
from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from src.dht_utils import get_remote_module_infos
from src.server.handler import TransformerConnectionHandler
@@ -24,6 +25,7 @@ class RemoteSequenceManager:
"""
def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
assert len(block_uids) > 0, "Sequences must contain at least one block"
self.dht, self.p2p = dht, p2p
self.block_uids: List[ModuleUID] = list(block_uids)
self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
@@ -39,7 +41,7 @@ class RemoteSequenceManager:
assert info is not None, f"Found no remote peers for block {uid}"
assert self.spans_by_priority and self.spans_containing_block
def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
"""
Form a sequence of remote servers that collectively serve all consecutive layers

@@ -1,102 +1,22 @@
"""
A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
"""
import asyncio
import logging
from typing import List, Optional, Sequence, Tuple
import torch
from hivemind import serialize_torch_tensor
from hivemind.moe.client.expert import expert_backward, expert_forward
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy
MAX_TOKENS_IN_BATCH = 1024
async def run_expert_forward(
uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
) -> Tuple[torch.Tensor, ...]:
"""
Serializes input tensors and calls "expert_forward".
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
# detach to avoid pickling the computation graph
assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (inputs, kwargs)
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
# TODO: rm this assert when support arbitrary number of input tensors
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
if not nested_compare(forward_inputs, forward_schema_with_prompts):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
forward_inputs = nested_flatten(forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
)
)
deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
flat_outputs = tuple(deserialized_outputs)
return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
async def run_expert_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
inputs: torch.Tensor,
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
) -> Sequence[torch.Tensor]:
"""
Serializes grad outputs and calls "expert_backward".
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
# TODO generalize this
prompts_schema = next(iter(args_schema))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
)
)
deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
return deserialized_grad_inputs
async def sequential_forward(
inputs: torch.Tensor,
prompts: torch.Tensor,
@@ -121,16 +41,17 @@ async def sequential_forward(
sequences = sequence_manager.make_sequence(start_index, end_index)
intermediate_inputs = []
done_sequences = []
outputs = inputs
while len(sequences) > 0:
while True:
span = sequences.pop(0)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
try:
span = sequences.pop(0)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
(outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
(outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -171,7 +92,7 @@ async def sequential_backward(
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
try:
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
grad_outputs, *span_grad_prompts = await run_expert_backward(
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
)
grad_outputs = [grad_outputs]

@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from hivemind.proto.runtime_pb2 import ExpertRequest
class SpendingPolicyBase(ABC):
@abstractmethod
def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
pass
class NoSpendingPolicy(SpendingPolicyBase):
def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
return 0.0
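
This commit only ships NoSpendingPolicy, so clients spend zero points by default. As a purely hypothetical illustration of the extension point (not part of this commit), a policy could, for example, spend one point per tensor attached to the outgoing request:

from hivemind.proto.runtime_pb2 import ExpertRequest

from src.client.spending_policy import SpendingPolicyBase


class PerTensorSpendingPolicy(SpendingPolicyBase):
    """Hypothetical policy: spend one point per tensor sent with the request."""

    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
        return float(len(request.tensors))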

@@ -1,45 +1,20 @@
"""Code for serving bloom blocks via hivemind-server"""
from queue import Empty
from typing import Any, Dict, Optional, Sequence, Tuple
import torch
from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.moe.server.task_pool import TaskPool
from hivemind.utils import InvalidStateError, get_logger
from hivemind.utils import get_logger
from src.bloom.from_pretrained import BloomBlock
from src.server.cache import MemoryCache
from src.server.task_pool import PrioritizedTaskPool
from src.utils.misc import is_dummy
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class InferenceTaskPool(TaskPool):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater 1"
def iterate_minibatches(self, *args, **kwargs):
"""Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
while True:
try:
logger.debug(f"{self.name} getting next task")
task = self.tasks.get(timeout=self.timeout)
except Empty:
logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
continue
try:
if task.future.set_running_or_notify_cancel():
yield [task]
except InvalidStateError as e:
logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
class TransformerBackend(ModuleBackend):
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
@@ -52,8 +27,15 @@ class TransformerBackend(ModuleBackend):
for name, buf in self.module.named_buffers():
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
self.inference_pool = InferenceTaskPool(
self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
max_batch_size = self.forward_pool.max_batch_size
self.inference_pool = PrioritizedTaskPool(
self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
)
self.forward_pool = PrioritizedTaskPool(
self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
)
self.backward_pool = PrioritizedTaskPool(
self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
)
self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
self.inference_schema = (
@@ -94,9 +76,9 @@ class TransformerBackend(ModuleBackend):
cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
return (hidden_states,)
def get_pools(self) -> Sequence[TaskPool]:
def get_pools(self) -> Sequence[PrioritizedTaskPool]:
return self.forward_pool, self.backward_pool, self.inference_pool
def get_info(self) -> Dict[str, Any]:
"""Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
"""Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
return dict(super().get_info(), inference_schema=self.inference_schema)

@@ -1,5 +1,5 @@
import contextlib
from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
import torch
from hivemind import (
@@ -7,6 +7,7 @@ from hivemind import (
MSGPackSerializer,
P2PContext,
TensorDescriptor,
deserialize_tensor_stream,
deserialize_torch_tensor,
nested_flatten,
serialize_torch_tensor,
@@ -14,12 +15,13 @@ from hivemind import (
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import as_aiter
from hivemind.utils.asyncio import anext
from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
from hivemind.utils.streaming import split_for_streaming
from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import TransformerBackend
from src.server.task_pool import PrioritizedTaskPool
from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from src.utils.misc import DUMMY, is_dummy
@@ -28,11 +30,41 @@ class TransformerConnectionHandler(ConnectionHandler):
module_backends: Dict[ModuleUID, TransformerBackend]
def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
def __init__(
self,
dht: DHT,
module_backends: Dict[str, TransformerBackend],
inference_max_length: int,
task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
):
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
assert isinstance(module_backend, TransformerBackend)
self.inference_max_length = inference_max_length
self._prioritizer = task_prioritizer
async def _gather_inputs(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> Tuple[str, List[torch.Tensor], Dict]:
block_uid, metadata = None, None
def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
nonlocal block_uid, metadata
if block_uid is None:
block_uid = req.uid
elif block_uid != req.uid:
raise ValueError("Block uids differ in one request")
if metadata is None:
metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}
return req.tensors
tensors_stream = amap_in_executor(_unpack, requests)
inputs = await deserialize_tensor_stream(tensors_stream)
assert isinstance(block_uid, str) and isinstance(metadata, dict)
return block_uid, inputs, metadata
async def rpc_inference(
self,
@@ -47,13 +79,18 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
points = metadata.get("points", 0)
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
assert isinstance(
points, (float, int)
), f"rpc_inference should have number of points as a number or None, got {points}"
if not 0 <= max_length <= self.inference_max_length:
raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
point_per_piece = points / max_length if max_length > 0 else 0.0
batch_size = request.tensors[0].size[0] if request.tensors else 1
cache_metadata = torch.tensor(
@@ -98,8 +135,19 @@ class TransformerConnectionHandler(ConnectionHandler):
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
assert isinstance(
backend.inference_pool, PrioritizedTaskPool
), "petals supports only prioritized pools"
priority = self._prioritizer.prioritize(
cache_metadata,
hidden_states,
hypo_ids,
points=point_per_piece / len(requested_backends),
backend=backend,
type="inference",
)
(hidden_states,) = await backend.inference_pool.submit_task(
cache_metadata, hidden_states, hypo_ids
cache_metadata, hidden_states, hypo_ids, priority=priority
)
# serialize and send last layer outputs
@@ -123,8 +171,15 @@ class TransformerConnectionHandler(ConnectionHandler):
flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
# Serialize output and respond to client
@@ -139,11 +194,17 @@ class TransformerConnectionHandler(ConnectionHandler):
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
# Parse requests and prepare backends
uid_str, flat_inputs = await self._gather_inputs(requests, context)
uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uid_str)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
# Serialize the overall output
@@ -164,8 +225,15 @@ class TransformerConnectionHandler(ConnectionHandler):
flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
grads = await _rpc_backward(
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
@@ -187,11 +255,17 @@ class TransformerConnectionHandler(ConnectionHandler):
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
uids_header, flat_tensors = await self._gather_inputs(requests, context)
uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uids_header)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
grads = await _rpc_backward(
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
@@ -244,7 +318,12 @@ class TransformerConnectionHandler(ConnectionHandler):
yield handles
async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
async def _rpc_forward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -267,7 +346,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
for backend, prompt in zip(requested_backends, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
(hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
priority = prioritizer.prioritize(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
priority=priority,
)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
@@ -278,7 +365,10 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
async def _rpc_backward(
*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
@@ -298,7 +388,12 @@ async def _rpc_backward(
if not is_dummy(prompt):
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
(inputs,) = await backend.forward_pool.submit_task(inputs)
assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
assert isinstance(inputs, torch.Tensor)
if not is_dummy(prompts[-1]):
@@ -309,7 +404,12 @@ async def _rpc_backward(
grad_prompts_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
assert isinstance(backend.backward_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

@@ -0,0 +1,198 @@
import multiprocessing as mp
import multiprocessing.pool
import threading
from collections import defaultdict
from itertools import chain
from queue import SimpleQueue
from selectors import EVENT_READ, DefaultSelector
from statistics import mean
from time import time
from typing import Dict, NamedTuple, Optional
import torch
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from prefetch_generator import BackgroundGenerator
logger = get_logger(__name__)
class Runtime(threading.Thread):
"""
A group of processes that processes incoming requests for multiple module backends on a shared device.
Runtime is usually created and managed by Server, humans need not apply.
For debugging, you can start runtime manually with .start() or .run()
>>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
>>> runtime = Runtime(module_backends)
>>> runtime.start() # start runtime in background thread. To start in current thread, use runtime.run()
>>> runtime.ready.wait() # wait for runtime to load all blocks on device and create request pools
>>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
>>> print("Returned:", future.result())
>>> runtime.shutdown()
:param module_backends: a dict [block uid -> ModuleBackend]
:param prefetch_batches: form up to this many batches in advance
:param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
:param device: if specified, moves all blocks and data to this device via .to(device=device).
If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
:param stats_report_interval: interval to collect and log statistics about runtime performance
"""
SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
def __init__(
self,
module_backends: Dict[str, ModuleBackend],
prefetch_batches: int = 1,
sender_threads: int = 1,
device: torch.device = None,
stats_report_interval: Optional[int] = None,
):
super().__init__()
self.module_backends = module_backends
self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
self.shutdown_trigger = mp.Event()
self.ready = mp.Event() # event is set iff server is currently running and ready to accept batches
self.stats_report_interval = stats_report_interval
if self.stats_report_interval is not None:
self.stats_reporter = StatsReporter(self.stats_report_interval)
def run(self):
for pool in self.pools:
if not pool.is_alive():
pool.start()
if self.device is not None:
for backend in self.module_backends.values():
backend.module.to(self.device)
with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
try:
self.ready.set()
if self.stats_report_interval is not None:
self.stats_reporter.start()
logger.info("Started")
batch_iterator = self.iterate_minibatches_from_pools()
if self.prefetch_batches > 0:
batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
for pool, batch_index, batch in batch_iterator:
logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
start = time()
try:
outputs = pool.process_func(*batch)
output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
batch_processing_time = time() - start
batch_size = outputs[0].size(0)
logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
if self.stats_report_interval is not None:
self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
except KeyboardInterrupt:
raise
except BaseException as exception:
logger.exception(f"Caught {exception}, attempting to recover")
output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
finally:
if not self.shutdown_trigger.is_set():
self.shutdown()
def shutdown(self):
"""Gracefully terminate a running runtime."""
logger.info("Shutting down")
self.ready.clear()
if self.stats_report_interval is not None:
self.stats_reporter.stop.set()
self.stats_reporter.join()
logger.debug("Terminating pools")
for pool in self.pools:
if pool.is_alive():
pool.shutdown()
logger.debug("Pools terminated")
# trigger background thread to shutdown
self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
self.shutdown_trigger.set()
def iterate_minibatches_from_pools(self, timeout=None):
"""
Chooses pool according to priority, then copies exposed batch and frees the buffer
"""
with DefaultSelector() as selector:
for pool in self.pools:
selector.register(pool.batch_receiver, EVENT_READ, pool)
selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
while True:
# wait until at least one batch_receiver becomes available
logger.debug("Waiting for inputs from task pools")
ready_fds = selector.select()
ready_objects = {key.data for (key, events) in ready_fds}
if self.SHUTDOWN_TRIGGER in ready_objects:
break # someone asked us to shutdown, break from the loop
logger.debug("Choosing the pool with first priority")
pool = min(ready_objects, key=lambda pool: pool.priority)
logger.debug(f"Loading batch from {pool.name}")
batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
logger.debug(f"Loaded batch from {pool.name}")
yield pool, batch_index, batch_tensors
BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))
class StatsReporter(threading.Thread):
def __init__(self, report_interval: int):
super().__init__()
self.report_interval = report_interval
self.stop = threading.Event()
self.stats_queue = SimpleQueue()
def run(self):
while not self.stop.wait(self.report_interval):
pool_batch_stats = defaultdict(list)
while not self.stats_queue.empty():
pool_uid, batch_stats = self.stats_queue.get()
pool_batch_stats[pool_uid].append(batch_stats)
total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
for pool_uid, pool_stats in pool_batch_stats.items():
total_batches = len(pool_stats)
total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
batches_to_time = total_batches / total_time
batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")
examples_to_time = total_examples / total_time
example_performance = f"{examples_to_time:.2f} " + (
"examples/s" if examples_to_time > 1 else "s/example"
)
logger.info(
f"{pool_uid}: "
f"{total_batches} batches ({batch_performance}), "
f"{total_examples} examples ({example_performance}), "
f"avg batch size {avg_batch_size:.2f}"
)
def report_stats(self, pool_uid, batch_size, processing_time):
batch_stats = BatchStats(batch_size, processing_time)
self.stats_queue.put_nowait((pool_uid, batch_stats))

@@ -71,9 +71,9 @@ class Server(threading.Thread):
runs Runtime (self.runtime) to process incoming requests.
"""
logger.info(f"Serving {len(self.module_backends)} blocks:")
for expert_name, backend in self.module_backends.items():
for block_name, backend in self.module_backends.items():
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
logger.info(f"{block_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
if not self.dht.is_alive():
self.dht.run_in_background(await_ready=True)
@@ -118,6 +118,8 @@ class Server(threading.Thread):
custom_module_path=None,
update_period: float = 30,
expiration: Optional[float] = None,
prefetch_batches: int = 1,
sender_threads: int = 1,
max_block_selection_delay: float = 1,
use_auth_token: Optional[str] = None,
load_in_8bit: bool = False,
@@ -236,6 +238,8 @@ class Server(threading.Thread):
stats_report_interval=stats_report_interval,
update_period=update_period,
expiration=expiration,
prefetch_batches=prefetch_batches,
sender_threads=sender_threads,
start=start,
)

@@ -0,0 +1,175 @@
import ctypes
import multiprocessing as mp
import threading
import time
from dataclasses import dataclass, field
from queue import PriorityQueue
from typing import Any, Generator, List, Optional, Sequence, Tuple
import torch
from hivemind import MPFuture, get_logger, use_hivemind_log_handler
from hivemind.moe.server.task_pool import TaskPoolBase
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@dataclass(order=True, frozen=True)
class Task:
priority: float
time_submitted: float
future: MPFuture = field(compare=False)
args: Sequence[torch.Tensor] = field(compare=False)
@property
def uid(self) -> int:
return self.future._uid
class PrioritizedTaskPool(TaskPoolBase):
"""
Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
A single PrioritizedTaskPool services a specific function (e.g. layer1.forward, layer2.forward or layer1.backward)
:note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches.
This would require grouping requests of different length.
:param process_func: function to be applied to every formed batch; called by Runtime
Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
:param max_batch_size: process at most this many inputs in a batch (a task contains one or several inputs)
Measured in the total number of tokens (i.e. batch size * sequence length)
:param name: pool name, used for logging
:param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
:param start: if True, start automatically at the end of __init__
"""
def __init__(
self,
process_func: callable,
max_batch_size: int,
name: str,
min_batch_size=1,
daemon=True,
start=False,
):
super().__init__(process_func, daemon=daemon, name=name)
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers
self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime
self._prioritizer_thread = threading.Thread(
name=self.name + "_prioritizer",
target=self._prioritize_tasks,
args=[self.submitted_tasks, self._ordered_tasks],
daemon=True,
)
self._dispatched_tasks = {}
self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
self._priority = mp.Value(ctypes.c_double, 1.0)  # shared value backing the priority property below
self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
self.priority = float("inf"), float("inf") # (first task priority, first task timestamp)
if start:
self.start()
@staticmethod
def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
"""Read tasks from incoming queue and put them into a local priority queue"""
while True:
task = submitted_tasks.get()
if task is None:
logger.debug("Shutting down prioritizer thread")
break
ordered_tasks.put(task, block=True)
def start(self):
assert not self.is_alive() and not self._prioritizer_thread.is_alive()
self._prioritizer_thread.start()
super().start()
def shutdown(self, timeout: Optional[float] = None):
self.submitted_tasks.put(None)
self.terminate()
self._prioritizer_thread.join(timeout)
def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
"""Add task to this pool's queue, return Future for its output"""
task = Task(priority, time.monotonic(), MPFuture(), args)
if self.get_task_size(task) > self.max_batch_size:
exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
task.future.set_exception(exc)
else:
self.submitted_tasks.put(task)
self.batch_sender.send(None) # use this pipe to count the number of unfinished batches
if (task.priority, task.time_submitted) < self.priority:
self.priority = (task.priority, task.time_submitted)
return task.future
def get_task_size(self, task: Task) -> int:
"""compute task processing complexity; defaults to the total number of tokens"""
if task.args and task.args[0].ndim >= 2:
return task.args[0].shape[0] * task.args[0].shape[1]
return 1
def load_batch_to_runtime(
self, timeout: Optional[float] = None, device: Optional[torch.device] = None
) -> Tuple[Any, List[torch.Tensor]]:
"""receive next batch of arrays"""
task = self._ordered_tasks.get(block=True, timeout=timeout)
batch_inputs = [
tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
]
self._dispatched_tasks[task.uid] = task
self.batch_receiver.recv() # reduce the number of active batches
if not self._ordered_tasks.empty():
first_remaining_task: Task = self._ordered_tasks.queue[0]
self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
return task.uid, batch_inputs
def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
"""send results for a processed batch, previously loaded through load_batch_to_runtime"""
batch_outputs = [
tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
for tensor in batch_outputs
]
task = self._dispatched_tasks.pop(uid, None)
if task is None:
logger.error(
f"Internal error: task with index {uid} is missing from the dictionary; could not set result"
)
else:
task.future.set_result(batch_outputs)
def send_exception_from_runtime(self, uid: int, exception: BaseException):
task = self._dispatched_tasks.pop(uid, None)
if task is None:
logger.error(
f"Internal error: task with index {uid} is missing from the dictionary; "
f"could not set exception {exception}"
)
else:
task.future.set_exception(exception)
def run(self, *args, **kwargs):
mp.Event().wait()
@property
def empty(self):
return not self.batch_receiver.poll()
@property
def priority(self) -> Tuple[float, float]:
"""The priority of this pool equals the (priority, timestamp) of the most important task in it."""
return float(self._priority.value), float(self._oldest_undispatched_timestamp.value)
@priority.setter
def priority(self, item: Tuple[float, float]):
assert len(item) == 2
self._priority.value = float(item[0])
self._oldest_undispatched_timestamp.value = float(item[1])
def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
raise NotImplementedError()
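
Note that Runtime picks the next pool with min(pools, key=lambda pool: pool.priority), and priority here is a (first task priority, submission timestamp) tuple, so ties in priority are broken by age. A tiny standalone illustration of that ordering, with plain tuples standing in for pool.priority values:

# Plain tuples standing in for pool.priority; smaller tuples are served first.
pool_priorities = {
    "block3_forward": (1.0, 105.2),    # priority 1, submitted later
    "block3_backward": (1.0, 103.7),   # priority 1, submitted earlier -> wins the tie
    "block3_inference": (5.0, 101.0),  # larger priority value, served last
}
print(min(pool_priorities, key=pool_priorities.get))  # block3_backward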

@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
import torch
from hivemind.moe.server.task_pool import Task
class TaskPrioritizerBase(ABC):
"""Abstract class for task prioritizers whose responsibility is to evaluate task priority"""
@abstractmethod
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
"""Evaluates task priority from the amount of points given, the task input, and additional kwargs. Lower priority values are served first"""
pass
class DummyTaskPrioritizer(TaskPrioritizerBase):
"""Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
return 0.0
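
DummyTaskPrioritizer ignores the points entirely. A hypothetical prioritizer that actually uses them (not part of this commit) could simply map more points to a smaller priority value, which PrioritizedTaskPool serves earlier:

import torch

from src.server.task_prioritizer import TaskPrioritizerBase


class PointsBasedPrioritizer(TaskPrioritizerBase):
    """Hypothetical prioritizer: requests that spend more points are served earlier."""

    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        return -float(points)  # lower priority value = served sooner by PrioritizedTaskPool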

@@ -0,0 +1,71 @@
import multiprocessing as mp
import time
import pytest
import torch
from src.server.runtime import Runtime
from src.server.task_pool import PrioritizedTaskPool
@pytest.mark.forked
def test_priority_pools():
outputs_queue = mp.SimpleQueue()
results_valid = mp.Event()
def dummy_pool_func(x):
time.sleep(0.1)
y = x**2
outputs_queue.put((x, y))
return (y,)
class DummyBackend:
def __init__(self, pools):
self.pools = pools
def get_pools(self):
return self.pools
pools = (
PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1),
PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
)
runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
runtime.start()
def process_tasks():
futures = []
futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
time.sleep(0.01)
futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
for i, f in enumerate(futures):
assert f.result()[0].item() == i**2
results_valid.set()
proc = mp.Process(target=process_tasks)
proc.start()
proc.join()
assert results_valid.is_set()
ordered_outputs = []
while not outputs_queue.empty():
ordered_outputs.append(outputs_queue.get()[0].item())
assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]
# 0 - first batch is loaded immediately, before everything else
# 5 - highest priority task overall
# 1 - first of several tasks with equal lowest priority (1)
# 2 - second earliest task with priority 1, fetched from pool B
# 6 - third earliest task with priority 1, fetched from pool A again
# 8 - last priority-1 task, pool B
# 3 - task with priority 2 from pool A
# 4 - task with priority 10 from pool A
# 7 - task with priority 11 from pool B