diff --git a/cli/run_server.py b/cli/run_server.py index 03055f7..ac7c9da 100644 --- a/cli/run_server.py +++ b/cli/run_server.py @@ -31,15 +31,19 @@ def main(): parser.add_argument('--num_handlers', type=int, default=8, required=False, help='server will use this many processes to handle incoming requests') parser.add_argument('--min_batch_size', type=int, default=1, - help='Minimum required batch size for all expert operations') + help='Minimum required batch size for all operations (in total tokens)') parser.add_argument('--max_batch_size', type=int, default=16384, help='The total number of tokens in the same batch will not exceed this value') + parser.add_argument('--prefetch_batches', type=int, default=1, required=False, + help='Pre-form this many subsequent batches while GPU is processing the current one') + parser.add_argument('--sender_threads', type=int, default=1, required=False, + help='Use this many threads to pass results/exceptions from Runtime to Pools') parser.add_argument('--inference_max_length', type=int, default=16384, help='Maximum total sequence length permitted per inference, defaults to 16384 tokens') parser.add_argument('--cache_dir', type=str, default=None, help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.') parser.add_argument('--device', type=str, default=None, required=False, - help='all experts will use this device in torch notation; default: cuda if available else cpu') + help='all blocks will use this device in torch notation; default: cuda if available else cpu') parser.add_argument("--torch_dtype", type=str, default="auto", help="Use this dtype to store block weights and do computations. " "By default, respect the dtypes in the pre-trained state dict.") @@ -58,7 +62,7 @@ def main(): 'on the first run and uses these estimates for future runs. 
' 'If set to "eval", the script re-evaluates the throughput and overrides the cache.') parser.add_argument('--update_period', type=float, required=False, default=30, - help='Server will report experts to DHT once in this many seconds') + help='Server will report blocks to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, help='DHT entries will expire after this many seconds') parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[], diff --git a/src/client/__init__.py b/src/client/__init__.py index 165de67..e9217b1 100644 --- a/src/client/__init__.py +++ b/src/client/__init__.py @@ -2,3 +2,4 @@ from src.client.inference_session import RemoteSequentialInferenceSession, Remot from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock from src.client.sequence_manager import RemoteSequenceManager +from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase diff --git a/src/client/inference_session.py b/src/client/inference_session.py index bb1455f..812e953 100644 --- a/src/client/inference_session.py +++ b/src/client/inference_session.py @@ -43,6 +43,7 @@ class RemoteTransformerBlockInferenceSession: outputs_aiter: AsyncIterator, *, max_length: int, + points: int = 0, ): self.uid, self.rpc_info = uid, rpc_info self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 @@ -50,7 +51,7 @@ class RemoteTransformerBlockInferenceSession: # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter - self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length)) + self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points)) self.stepped = False self.closed = False diff --git a/src/client/remote_forward_backward.py b/src/client/remote_forward_backward.py new file mode 100644 index 0000000..b8713ff --- /dev/null +++ b/src/client/remote_forward_backward.py @@ -0,0 +1,156 @@ +""" +Utility functions that call RPC forward or backward on a single remote server +""" +import asyncio +from typing import Iterable, List, Sequence, Tuple + +import torch +from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor +from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor +from hivemind.p2p import StubBase +from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE +from hivemind.proto import runtime_pb2 +from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter +from hivemind.utils.streaming import split_for_streaming + +from src.data_structures import ModuleUID, RPCInfo + + +async def run_remote_forward( + uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs +) -> Tuple[torch.Tensor, ...]: + """ + Serializes input tensors and calls "rpc_forward" on a remote server. + Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198 + but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. 
+ """ + + # Note: *inputs are flattened input tensors that follow the expert's info['input_schema'] + # detach to avoid pickling the computation graph + assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}" + kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]} + + # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors + forward_inputs = (inputs, kwargs) + + # Modify forward_schema to support prompts + args_schema, kwargs_schema = rpc_info["forward_schema"] + # TODO: rm this assert when support arbitrary number of input tensors + assert len(args_schema) == 1 and len(inputs) == 2 + forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema) + + if not nested_compare(forward_inputs, forward_schema_with_prompts): + raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?") + + forward_inputs = nested_flatten(forward_inputs) + inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs) + + # Asynchronous serialization + loop = asyncio.get_running_loop() + serialized_tensors = await asyncio.gather( + *( + loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) + for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts)) + ) + ) + + # call RPC on remote server + size = sum(t.element_size() * t.nelement() for t in inputs) + if size > MAX_UNARY_PAYLOAD_SIZE: + deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs) + else: + deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs) + + return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) + + +async def _forward_stream( + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs +) -> List[torch.Tensor]: + split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE)) + + outputs = await stub.rpc_forward_stream( + amap_in_executor( + lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs), + iter_as_aiter(split), + ), + ) + + tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs) + return await deserialize_tensor_stream(tensors_stream) + + +async def _forward_unary( + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs +) -> List[torch.Tensor]: + outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward( + runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs) + ) + return [deserialize_torch_tensor(t) for t in outputs.tensors] + + +async def _backward_stream( + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs +) -> List[torch.Tensor]: + split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)) + + grad_inputs = await stub.rpc_backward_stream( + amap_in_executor( + lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs), + iter_as_aiter(split), + ), + ) + tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs) + return await deserialize_tensor_stream(tensors_stream) + + +async def run_remote_backward( + uid: ModuleUID, + stub: StubBase, + rpc_info: RPCInfo, + inputs: torch.Tensor, + grad_outputs: List[torch.Tensor], + *extra_tensors: torch.Tensor, + **kwargs, +) -> Sequence[torch.Tensor]: + """ + Serializes grad outputs and calls "rpc_backward" on a remote server. 
+ Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221 + but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. + """ + + grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs) + inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors))) + + # Modify forward_schema to support prompts + args_schema, kwargs_schema = rpc_info["forward_schema"] + assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor) + # TODO generalize this + prompts_schema = next(iter(args_schema)) + backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema))) + + # Asynchronous serialization + loop = asyncio.get_running_loop() + serialized_tensors = await asyncio.gather( + *( + loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) + for tensor, proto in zip(inputs_and_grad_outputs, backward_schema) + ) + ) + + size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) + if size > MAX_UNARY_PAYLOAD_SIZE: + deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs) + else: + deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs) + + return deserialized_grad_inputs + + +async def _backward_unary( + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs +) -> List[torch.Tensor]: + grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward( + runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs) + ) + return [deserialize_torch_tensor(t) for t in grad_inputs.tensors] diff --git a/src/client/sequence_manager.py b/src/client/sequence_manager.py index c05ae72..0c15163 100644 --- a/src/client/sequence_manager.py +++ b/src/client/sequence_manager.py @@ -9,6 +9,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.proto import runtime_pb2 from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from src.client.spending_policy import NoSpendingPolicy from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState from src.dht_utils import get_remote_module_infos from src.server.handler import TransformerConnectionHandler @@ -24,6 +25,7 @@ class RemoteSequenceManager: """ def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3): + assert len(block_uids) > 0, "Sequences must contain at least one block" self.dht, self.p2p = dht, p2p self.block_uids: List[ModuleUID] = list(block_uids) self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids) @@ -39,7 +41,7 @@ class RemoteSequenceManager: assert info is not None, f"Found no remote peers for block {uid}" assert self.spans_by_priority and self.spans_containing_block - def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]: + def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]: """ Form a sequence of remote servers that collectively serve all consecutive layers diff --git a/src/client/sequential_autograd.py b/src/client/sequential_autograd.py index 1498236..71dc77a 100644 --- a/src/client/sequential_autograd.py +++ b/src/client/sequential_autograd.py @@ -1,102 +1,22 @@ +""" +A PyTorch autograd function that runs forward/backward on a sequence of remote servers 
in a fault-tolerant manner +""" import asyncio import logging from typing import List, Optional, Sequence, Tuple import torch -from hivemind import serialize_torch_tensor -from hivemind.moe.client.expert import expert_backward, expert_forward from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import StubBase -from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack +from src.client.remote_forward_backward import run_remote_backward, run_remote_forward from src.client.sequence_manager import RemoteSequenceManager -from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo +from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo from src.server.handler import TransformerConnectionHandler from src.utils.misc import DUMMY, is_dummy MAX_TOKENS_IN_BATCH = 1024 -async def run_expert_forward( - uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs -) -> Tuple[torch.Tensor, ...]: - """ - Serializes input tensors and calls "expert_forward". - Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198 - but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. - """ - - # Note: *inputs are flattened input tensors that follow the expert's info['input_schema'] - # detach to avoid pickling the computation graph - assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}" - kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]} - - # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors - forward_inputs = (inputs, kwargs) - - # Modify forward_schema to support prompts - args_schema, kwargs_schema = rpc_info["forward_schema"] - # TODO: rm this assert when support arbitrary number of input tensors - assert len(args_schema) == 1 and len(inputs) == 2 - forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema) - - if not nested_compare(forward_inputs, forward_schema_with_prompts): - raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?") - - forward_inputs = nested_flatten(forward_inputs) - inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs) - - # Asynchronous serialization - loop = asyncio.get_running_loop() - serialized_tensors = await asyncio.gather( - *( - loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts)) - ) - ) - - deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub) - flat_outputs = tuple(deserialized_outputs) - return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"]) - - -async def run_expert_backward( - uid: ModuleUID, - stub: StubBase, - rpc_info: RPCInfo, - inputs: torch.Tensor, - grad_outputs: List[torch.Tensor], - *extra_tensors: torch.Tensor, -) -> Sequence[torch.Tensor]: - """ - Serializes grad outputs and calls "expert_backward". - Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221 - but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. 
- """ - - grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs) - inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors))) - - # Modify forward_schema to support prompts - args_schema, kwargs_schema = rpc_info["forward_schema"] - assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor) - # TODO generalize this - prompts_schema = next(iter(args_schema)) - backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema))) - - # Asynchronous serialization - loop = asyncio.get_running_loop() - serialized_tensors = await asyncio.gather( - *( - loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(inputs_and_grad_outputs, backward_schema) - ) - ) - - deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub) - return deserialized_grad_inputs - - async def sequential_forward( inputs: torch.Tensor, prompts: torch.Tensor, @@ -121,16 +41,17 @@ async def sequential_forward( sequences = sequence_manager.make_sequence(start_index, end_index) intermediate_inputs = [] done_sequences = [] + outputs = inputs while len(sequences) > 0: while True: + span = sequences.pop(0) + span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) try: - span = sequences.pop(0) - span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id) inputs_and_prompts = [inputs, prompts[span.start : span.end]] - (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts) + (outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts) assert isinstance(outputs, torch.Tensor) assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}" @@ -171,7 +92,7 @@ async def sequential_backward( span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) try: stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id) - grad_outputs, *span_grad_prompts = await run_expert_backward( + grad_outputs, *span_grad_prompts = await run_remote_backward( span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end] ) grad_outputs = [grad_outputs] diff --git a/src/client/spending_policy.py b/src/client/spending_policy.py new file mode 100644 index 0000000..770d25a --- /dev/null +++ b/src/client/spending_policy.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod + +from hivemind.proto.runtime_pb2 import ExpertRequest + + +class SpendingPolicyBase(ABC): + @abstractmethod + def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float: + pass + + +class NoSpendingPolicy(SpendingPolicyBase): + def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float: + return 0.0 diff --git a/src/server/backend.py b/src/server/backend.py index 27ee1ad..ed8273d 100644 --- a/src/server/backend.py +++ b/src/server/backend.py @@ -1,45 +1,20 @@ """Code for serving bloom blocks via hivemind-server""" -from queue import Empty from typing import Any, Dict, Optional, Sequence, Tuple import torch from hivemind import BatchTensorDescriptor, use_hivemind_log_handler from hivemind.moe.server.module_backend import ModuleBackend -from hivemind.moe.server.task_pool import TaskPool 
-from hivemind.utils import InvalidStateError, get_logger +from hivemind.utils import get_logger from src.bloom.from_pretrained import BloomBlock from src.server.cache import MemoryCache +from src.server.task_pool import PrioritizedTaskPool from src.utils.misc import is_dummy use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) -class InferenceTaskPool(TaskPool): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater 1" - - def iterate_minibatches(self, *args, **kwargs): - """Form minibatches by grouping one or more tasks together up to self.max_batch_size""" - - while True: - try: - logger.debug(f"{self.name} getting next task") - task = self.tasks.get(timeout=self.timeout) - except Empty: - logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet") - continue - - try: - if task.future.set_running_or_notify_cancel(): - yield [task] - except InvalidStateError as e: - logger.debug(f"Failed to add task to batch: {task.future} raised {e}") - - class TransformerBackend(ModuleBackend): """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward""" @@ -52,8 +27,15 @@ class TransformerBackend(ModuleBackend): for name, buf in self.module.named_buffers(): assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does" - self.inference_pool = InferenceTaskPool( - self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference" + max_batch_size = self.forward_pool.max_batch_size + self.inference_pool = PrioritizedTaskPool( + self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference" + ) + self.forward_pool = PrioritizedTaskPool( + self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward" + ) + self.backward_pool = PrioritizedTaskPool( + self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward" ) self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype self.inference_schema = ( @@ -94,9 +76,9 @@ class TransformerBackend(ModuleBackend): cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length] return (hidden_states,) - def get_pools(self) -> Sequence[TaskPool]: + def get_pools(self) -> Sequence[PrioritizedTaskPool]: return self.forward_pool, self.backward_pool, self.inference_pool def get_info(self) -> Dict[str, Any]: - """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration.""" + """Get module parameters and stats. 
Used by RemoteExpert to check shapes and for DMoE orchestration.""" return dict(super().get_info(), inference_schema=self.inference_schema) diff --git a/src/server/handler.py b/src/server/handler.py index b2e15f7..7d7f76b 100644 --- a/src/server/handler.py +++ b/src/server/handler.py @@ -1,5 +1,5 @@ import contextlib -from typing import AsyncIterator, Dict, List, Optional, Sequence, Union +from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union import torch from hivemind import ( @@ -7,6 +7,7 @@ from hivemind import ( MSGPackSerializer, P2PContext, TensorDescriptor, + deserialize_tensor_stream, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor, @@ -14,12 +15,13 @@ from hivemind import ( from hivemind.moe.server.connection_handler import ConnectionHandler from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE from hivemind.proto import runtime_pb2 -from hivemind.utils import as_aiter -from hivemind.utils.asyncio import anext +from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter from hivemind.utils.streaming import split_for_streaming from src.data_structures import CHAIN_DELIMITER, ModuleUID from src.server.backend import TransformerBackend +from src.server.task_pool import PrioritizedTaskPool +from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase from src.utils.misc import DUMMY, is_dummy @@ -28,11 +30,41 @@ class TransformerConnectionHandler(ConnectionHandler): module_backends: Dict[ModuleUID, TransformerBackend] - def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int): + def __init__( + self, + dht: DHT, + module_backends: Dict[str, TransformerBackend], + inference_max_length: int, + task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(), + ): super().__init__(dht, module_backends) for module_backend in self.module_backends.values(): assert isinstance(module_backend, TransformerBackend) self.inference_max_length = inference_max_length + self._prioritizer = task_prioritizer + + async def _gather_inputs( + self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext + ) -> Tuple[str, List[torch.Tensor], Dict]: + block_uid, metadata = None, None + + def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]: + nonlocal block_uid, metadata + + if block_uid is None: + block_uid = req.uid + elif block_uid != req.uid: + raise ValueError("Block uids differ in one request") + + if metadata is None: + metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {} + + return req.tensors + + tensors_stream = amap_in_executor(_unpack, requests) + inputs = await deserialize_tensor_stream(tensors_stream) + assert isinstance(block_uid, str) and isinstance(metadata, dict) + return block_uid, inputs, metadata async def rpc_inference( self, @@ -47,13 +79,18 @@ class TransformerConnectionHandler(ConnectionHandler): metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") + points = metadata.get("points", 0) if not requested_uids: raise ValueError("User must specify at least one block for inference, but got none") assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}" + assert isinstance( + points, (float, int) + ), f"rpc_inference should have number of points as a number or None, got {points}" if not 0 <= max_length <= 
self.inference_max_length: raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}") + point_per_piece = points / max_length if max_length > 0 else 0.0 batch_size = request.tensors[0].size[0] if request.tensors else 1 cache_metadata = torch.tensor( @@ -98,8 +135,19 @@ class TransformerConnectionHandler(ConnectionHandler): assert ( hidden_states.ndim == 3 ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" + assert isinstance( + backend.inference_pool, PrioritizedTaskPool + ), "petals support only prioritized pools" + priority = self._prioritizer.prioritize( + cache_metadata, + hidden_states, + hypo_ids, + points=point_per_piece / len(requested_backends), + backend=backend, + type="inference", + ) (hidden_states,) = await backend.inference_pool.submit_task( - cache_metadata, hidden_states, hypo_ids + cache_metadata, hidden_states, hypo_ids, priority=priority ) # serialize and send last layer outputs @@ -123,8 +171,15 @@ class TransformerConnectionHandler(ConnectionHandler): flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors] requested_uids = self._check_uids(request.uid) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) - - hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends) + metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} + points = metadata.get("points", 0) + assert isinstance( + points, (float, int) + ), f"rpc_forward should have number of points as number or None, got {points}" + + hidden_states = await _rpc_forward( + *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + ) assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3 # Serialize output and respond to client @@ -139,11 +194,17 @@ class TransformerConnectionHandler(ConnectionHandler): self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext ) -> AsyncIterator[runtime_pb2.ExpertRequest]: # Parse requests and prepare backends - uid_str, flat_inputs = await self._gather_inputs(requests, context) + uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context) requested_uids = self._check_uids(uid_str) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) + points = metadata.get("points", 0) + assert isinstance( + points, (float, int) + ), f"rpc_forward_stream should have number of points as number or None, got {points}" - hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends) + hidden_states = await _rpc_forward( + *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + ) assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor" # Serialize the overall output @@ -164,8 +225,15 @@ class TransformerConnectionHandler(ConnectionHandler): flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors] requested_uids = self._check_uids(request.uid) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) - - grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends) + metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} + points = metadata.get("points", 0) + assert isinstance( + points, (float, int) + ), f"rpc_backward should have number of points as number or None, got {points}" + + 
grads = await _rpc_backward( + *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + ) # Modify grad_inputs_schema to support grad_prompts assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize @@ -187,11 +255,17 @@ class TransformerConnectionHandler(ConnectionHandler): self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext ) -> AsyncIterator[runtime_pb2.ExpertResponse]: - uids_header, flat_tensors = await self._gather_inputs(requests, context) + uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context) requested_uids = self._check_uids(uids_header) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) + points = metadata.get("points", 0) + assert isinstance( + points, (float, int) + ), f"rpc_backward_stream should have number of points as number or None, got {points}" - grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends) + grads = await _rpc_backward( + *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + ) # Modify grad_inputs_schema to support grad_prompts assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize @@ -244,7 +318,12 @@ class TransformerConnectionHandler(ConnectionHandler): yield handles -async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor: +async def _rpc_forward( + *flat_tensors: torch.Tensor, + requested_backends: Sequence[TransformerBackend], + prioritizer: TaskPrioritizerBase, + points: int = 0, +) -> torch.Tensor: """ Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream @@ -267,7 +346,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence for backend, prompt in zip(requested_backends, prompts): if not is_dummy(prompt): hidden_states[:, : prompt.shape[1]] += prompt - (hidden_states,) = await backend.forward_pool.submit_task(hidden_states) + + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + hidden_states, points=points / len(requested_backends), backend=backend, type="forward" + ) + (hidden_states,) = await backend.forward_pool.submit_task( + hidden_states, + priority=priority, + ) assert isinstance(hidden_states, torch.Tensor) assert ( hidden_states.ndim == 3 @@ -278,7 +365,10 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence async def _rpc_backward( - *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend] + *flat_tensors: torch.Tensor, + requested_backends: Sequence[TransformerBackend], + prioritizer: TaskPrioritizerBase, + points: int = 0, ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: inputs, grad_outputs, prompts = flat_tensors # Cast inputs & grad outputs to backend dtype @@ -298,7 +388,12 @@ async def _rpc_backward( if not is_dummy(prompt): inputs[:, : prompt.shape[1]] += prompt inter_inputs.append(inputs) - (inputs,) = await backend.forward_pool.submit_task(inputs) + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" + ) + (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority) + assert 
isinstance(inputs, torch.Tensor) if not is_dummy(prompts[-1]): @@ -309,7 +404,12 @@ async def _rpc_backward( grad_prompts_reversed = [] # Run a chain of requested backends for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): - (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs) + assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + priority = prioritizer.prioritize( + inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" + ) + (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority) + assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) diff --git a/src/server/runtime.py b/src/server/runtime.py new file mode 100644 index 0000000..11547aa --- /dev/null +++ b/src/server/runtime.py @@ -0,0 +1,198 @@ +import multiprocessing as mp +import multiprocessing.pool +import threading +from collections import defaultdict +from itertools import chain +from queue import SimpleQueue +from selectors import EVENT_READ, DefaultSelector +from statistics import mean +from time import time +from typing import Dict, NamedTuple, Optional + +import torch +from hivemind.moe.server.module_backend import ModuleBackend +from hivemind.utils import get_logger +from prefetch_generator import BackgroundGenerator + +logger = get_logger(__name__) + + +class Runtime(threading.Thread): + """ + A group of processes that processes incoming requests for multiple module backends on a shared device. + Runtime is usually created and managed by Server, humans need not apply. + + For debugging, you can start runtime manually with .start() or .run() + + >>> module_backends = {'block_uid': ModuleBackend(**kwargs)} + >>> runtime = Runtime(module_backends) + >>> runtime.start() # start runtime in background thread. To start in current thread, use runtime.run() + >>> runtime.ready.wait() # await for runtime to load all blocks on device and create request pools + >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs) + >>> print("Returned:", future.result()) + >>> runtime.shutdown() + + :param module_backends: a dict [block uid -> ModuleBackend] + :param prefetch_batches: form up to this many batches in advance + :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads + :param device: if specified, moves all blocks and data to this device via .to(device=device). 
+ If you want to manually specify devices for each block (in their forward pass), leave device=None (default) + + :param stats_report_interval: interval to collect and log statistics about runtime performance + """ + + SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED" + + def __init__( + self, + module_backends: Dict[str, ModuleBackend], + prefetch_batches: int = 1, + sender_threads: int = 1, + device: torch.device = None, + stats_report_interval: Optional[int] = None, + ): + super().__init__() + self.module_backends = module_backends + self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values()))) + self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads + self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False) + self.shutdown_trigger = mp.Event() + self.ready = mp.Event() # event is set iff server is currently running and ready to accept batches + + self.stats_report_interval = stats_report_interval + if self.stats_report_interval is not None: + self.stats_reporter = StatsReporter(self.stats_report_interval) + + def run(self): + for pool in self.pools: + if not pool.is_alive(): + pool.start() + if self.device is not None: + for backend in self.module_backends.values(): + backend.module.to(self.device) + + with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool: + try: + self.ready.set() + if self.stats_report_interval is not None: + self.stats_reporter.start() + logger.info("Started") + + batch_iterator = self.iterate_minibatches_from_pools() + if self.prefetch_batches > 0: + batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches) + + for pool, batch_index, batch in batch_iterator: + logger.debug(f"Processing batch {batch_index} from pool {pool.name}") + + start = time() + try: + outputs = pool.process_func(*batch) + output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs]) + + batch_processing_time = time() - start + + batch_size = outputs[0].size(0) + logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}") + + if self.stats_report_interval is not None: + self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time) + + except KeyboardInterrupt: + raise + except BaseException as exception: + logger.exception(f"Caught {exception}, attempting to recover") + output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception]) + + finally: + if not self.shutdown_trigger.is_set(): + self.shutdown() + + def shutdown(self): + """Gracefully terminate a running runtime.""" + logger.info("Shutting down") + self.ready.clear() + + if self.stats_report_interval is not None: + self.stats_reporter.stop.set() + self.stats_reporter.join() + + logger.debug("Terminating pools") + for pool in self.pools: + if pool.is_alive(): + pool.shutdown() + logger.debug("Pools terminated") + + # trigger background thread to shutdown + self.shutdown_send.send(self.SHUTDOWN_TRIGGER) + self.shutdown_trigger.set() + + def iterate_minibatches_from_pools(self, timeout=None): + """ + Chooses pool according to priority, then copies exposed batch and frees the buffer + """ + with DefaultSelector() as selector: + for pool in self.pools: + selector.register(pool.batch_receiver, EVENT_READ, pool) + selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER) + + while True: + # wait until at least one batch_receiver becomes available + logger.debug("Waiting for inputs from task pools") + ready_fds = 
selector.select() + ready_objects = {key.data for (key, events) in ready_fds} + if self.SHUTDOWN_TRIGGER in ready_objects: + break # someone asked us to shutdown, break from the loop + + logger.debug("Choosing the pool with first priority") + + pool = min(ready_objects, key=lambda pool: pool.priority) + + logger.debug(f"Loading batch from {pool.name}") + batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device) + logger.debug(f"Loaded batch from {pool.name}") + yield pool, batch_index, batch_tensors + + +BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float))) + + +class StatsReporter(threading.Thread): + def __init__(self, report_interval: int): + super().__init__() + self.report_interval = report_interval + self.stop = threading.Event() + self.stats_queue = SimpleQueue() + + def run(self): + while not self.stop.wait(self.report_interval): + pool_batch_stats = defaultdict(list) + while not self.stats_queue.empty(): + pool_uid, batch_stats = self.stats_queue.get() + pool_batch_stats[pool_uid].append(batch_stats) + + total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values()) + logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:") + for pool_uid, pool_stats in pool_batch_stats.items(): + total_batches = len(pool_stats) + total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats) + avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats) + total_time = sum(batch_stats.processing_time for batch_stats in pool_stats) + batches_to_time = total_batches / total_time + batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch") + + examples_to_time = total_examples / total_time + example_performance = f"{examples_to_time:.2f} " + ( + "examples/s" if examples_to_time > 1 else "s/example" + ) + + logger.info( + f"{pool_uid}: " + f"{total_batches} batches ({batch_performance}), " + f"{total_examples} examples ({example_performance}), " + f"avg batch size {avg_batch_size:.2f}" + ) + + def report_stats(self, pool_uid, batch_size, processing_time): + batch_stats = BatchStats(batch_size, processing_time) + self.stats_queue.put_nowait((pool_uid, batch_stats)) diff --git a/src/server/server.py b/src/server/server.py index 5d92bd9..efa1787 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -71,9 +71,9 @@ class Server(threading.Thread): runs Runtime (self.runtime) to process incoming requests. 
""" logger.info(f"Serving {len(self.module_backends)} blocks:") - for expert_name, backend in self.module_backends.items(): + for block_name, backend in self.module_backends.items(): num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad) - logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters") + logger.info(f"{block_name}: {backend.module.__class__.__name__}, {num_parameters} parameters") if not self.dht.is_alive(): self.dht.run_in_background(await_ready=True) @@ -118,6 +118,8 @@ class Server(threading.Thread): custom_module_path=None, update_period: float = 30, expiration: Optional[float] = None, + prefetch_batches: int = 1, + sender_threads: int = 1, max_block_selection_delay: float = 1, use_auth_token: Optional[str] = None, load_in_8bit: bool = False, @@ -236,6 +238,8 @@ class Server(threading.Thread): stats_report_interval=stats_report_interval, update_period=update_period, expiration=expiration, + prefetch_batches=prefetch_batches, + sender_threads=sender_threads, start=start, ) diff --git a/src/server/task_pool.py b/src/server/task_pool.py new file mode 100644 index 0000000..2bf65c0 --- /dev/null +++ b/src/server/task_pool.py @@ -0,0 +1,175 @@ +import ctypes +import multiprocessing as mp +import threading +import time +from dataclasses import dataclass, field +from queue import PriorityQueue +from typing import Any, Generator, List, Optional, Sequence, Tuple + +import torch +from hivemind import MPFuture, get_logger, use_hivemind_log_handler +from hivemind.moe.server.task_pool import TaskPoolBase + +use_hivemind_log_handler("in_root_logger") +logger = get_logger(__file__) + + +@dataclass(order=True, frozen=True) +class Task: + priority: float + time_submitted: float + future: MPFuture = field(compare=False) + args: Sequence[torch.Tensor] = field(compare=False) + + @property + def uid(self) -> int: + return self.future._uid + + +class PrioritizedTaskPool(TaskPoolBase): + """ + Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then + returns results (or exception) to the corresponding ConnectionHandler. Runs a background process. + A single PrioritizedTaskPool services a specific function (e.g. layer1.forward, layer2.forward or layer1.backward) + + :note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches. + This would require grouping requests of different length. + + :param process_func: function to be applied to every formed batch; called by Runtime + Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors + :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs) + Measured in the total number of tokens (i.e. 
batch size * sequence length) + + :param name: pool name, used for logging + :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more + :param start: if True, start automatically at the end of __init__ + """ + + def __init__( + self, + process_func: callable, + max_batch_size: int, + name: str, + min_batch_size=1, + daemon=True, + start=False, + ): + super().__init__(process_func, daemon=daemon, name=name) + self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size + + self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers + self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime + + self._prioritizer_thread = threading.Thread( + name=self.name + "_prioritizer", + target=self._prioritize_tasks, + args=[self.submitted_tasks, self._ordered_tasks], + daemon=True, + ) + self._dispatched_tasks = {} + self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False) + self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0) + self.priority = float("inf"), float("inf") # (first task priority, first task timestamp) + if start: + self.start() + + @staticmethod + def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue): + """Read tasks from incoming queue and put them into a local priority queue""" + while True: + task = submitted_tasks.get() + if task is None: + logger.debug("Shutting down prioritizer thread") + break + + ordered_tasks.put(task, block=True) + + def start(self): + assert not self.is_alive() and not self._prioritizer_thread.is_alive() + self._prioritizer_thread.start() + super().start() + + def shutdown(self, timeout: Optional[float] = None): + self.submitted_tasks.put(None) + self.terminate() + self._prioritizer_thread.join(timeout) + + def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture: + """Add task to this pool's queue, return Future for its output""" + task = Task(priority, time.monotonic(), MPFuture(), args) + if self.get_task_size(task) > self.max_batch_size: + exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed") + task.future.set_exception(exc) + else: + self.submitted_tasks.put(task) + self.batch_sender.send(None) # use this pipe to count the number of unfinished batches + if (task.priority, task.time_submitted) < self.priority: + self.priority = (task.priority, task.time_submitted) + return task.future + + def get_task_size(self, task: Task) -> int: + """compute task processing complexity; defaults to the total number of tokens""" + if task.args and task.args[0].ndim >= 2: + return task.args[0].shape[0] * task.args[0].shape[1] + return 1 + + def load_batch_to_runtime( + self, timeout: Optional[float] = None, device: Optional[torch.device] = None + ) -> Tuple[Any, List[torch.Tensor]]: + """receive next batch of arrays""" + task = self._ordered_tasks.get(block=True, timeout=timeout) + batch_inputs = [ + tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args + ] + self._dispatched_tasks[task.uid] = task + self.batch_receiver.recv() # reduce the number of active batches + if not self._ordered_tasks.empty(): + first_remaining_task: Task = self._ordered_tasks.queue[0] + self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted) + return task.uid, batch_inputs + + def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]): + """send results for a processed 
batch, previously loaded through load_batch_to_runtime""" + batch_outputs = [ + tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad) + for tensor in batch_outputs + ] + + task = self._dispatched_tasks.pop(uid, None) + if task is None: + logger.error( + f"Internal error: task with index {uid} is missing from the dictionary; " f"Could not set result" + ) + else: + task.future.set_result(batch_outputs) + + def send_exception_from_runtime(self, uid: int, exception: BaseException): + task = self._dispatched_tasks.pop(uid, None) + if task is None: + logger.error( + f"Internal error: task with index {uid} is missing from the dictionary; " + f"Could not set exception {exception}" + ) + else: + task.future.set_exception(exception) + + def run(self, *args, **kwargs): + mp.Event().wait() + + @property + def empty(self): + return not self.batch_receiver.poll() + + @property + def priority(self) -> Tuple[float, float]: + """The priority of this pool equals the (priority, timestamp) of the most important task in it.""" + return float(self._priority.value), float(self._oldest_undispatched_timestamp.value) + + @priority.setter + def priority(self, item: Tuple[float, float]): + assert len(item) == 2 + self._priority.value = float(item[0]) + self._oldest_undispatched_timestamp.value = float(item[1]) + + def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]: + raise NotImplementedError() diff --git a/src/server/task_prioritizer.py b/src/server/task_prioritizer.py new file mode 100644 index 0000000..6e3b886 --- /dev/null +++ b/src/server/task_prioritizer.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + +import torch +from hivemind.moe.server.task_pool import Task + + +class TaskPrioritizerBase(ABC): + """Abstract class for TaskPrioritizer whose responsibility is to evaluate task priority""" + + @abstractmethod + def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float: + """Evaluates task value by the amount of points given, task input and additional kwargs.
Lower priority is better""" + pass + + +class DummyTaskPrioritizer(TaskPrioritizerBase): + """Simple implementation of TaskPrioritizer which gives constant zero priority for every task""" + + def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float: + return 0.0 diff --git a/tests/test_priority_pool.py b/tests/test_priority_pool.py new file mode 100644 index 0000000..21dd74e --- /dev/null +++ b/tests/test_priority_pool.py @@ -0,0 +1,71 @@ +import multiprocessing as mp +import time + +import pytest +import torch + +from src.server.runtime import Runtime +from src.server.task_pool import PrioritizedTaskPool + + +@pytest.mark.forked +def test_priority_pools(): + outputs_queue = mp.SimpleQueue() + results_valid = mp.Event() + + def dummy_pool_func(x): + time.sleep(0.1) + y = x**2 + outputs_queue.put((x, y)) + return (y,) + + class DummyBackend: + def __init__(self, pools): + self.pools = pools + + def get_pools(self): + return self.pools + + pools = ( + PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1), + PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1), + ) + + runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0) + runtime.start() + + def process_tasks(): + futures = [] + futures.append(pools[0].submit_task(torch.tensor([0]), priority=1)) + futures.append(pools[0].submit_task(torch.tensor([1]), priority=1)) + time.sleep(0.01) + futures.append(pools[1].submit_task(torch.tensor([2]), priority=1)) + futures.append(pools[0].submit_task(torch.tensor([3]), priority=2)) + futures.append(pools[0].submit_task(torch.tensor([4]), priority=10)) + futures.append(pools[0].submit_task(torch.tensor([5]), priority=0)) + futures.append(pools[0].submit_task(torch.tensor([6]), priority=1)) + futures.append(pools[1].submit_task(torch.tensor([7]), priority=11)) + futures.append(pools[1].submit_task(torch.tensor([8]), priority=1)) + for i, f in enumerate(futures): + assert f.result()[0].item() == i**2 + results_valid.set() + + proc = mp.Process(target=process_tasks) + proc.start() + proc.join() + assert results_valid.is_set() + + ordered_outputs = [] + while not outputs_queue.empty(): + ordered_outputs.append(outputs_queue.get()[0].item()) + + assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7] + # 0 - first batch is loaded immediately, before everything else + # 5 - highest priority task overall + # 1 - first of several tasks with equal lowest priority (1) + # 2 - second earliest task with priority 1, fetched from pool B + # 6 - third earliest task with priority 1, fetched from pool A again + # 8 - last priority-1 task, pool B + # 3 - task with priority 2 from pool A + # 4 - task with priority 10 from pool A + # 7 - task with priority 11 from pool B
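
Reviewer note: to make the intended flow of the new "points" plumbing easier to follow, below is a minimal sketch of how a client-side SpendingPolicyBase subclass and a server-side TaskPrioritizerBase subclass could fit together. The subclass names (TokenCountSpendingPolicy, PointsFirstPrioritizer) and the exact place where get_points() would be invoked are assumptions, not part of this diff; only the base classes, the "points" metadata field, and the task_prioritizer argument of TransformerConnectionHandler come from the changes above.

# Illustrative sketch only; TokenCountSpendingPolicy and PointsFirstPrioritizer are hypothetical subclasses.
import torch
from hivemind.proto.runtime_pb2 import ExpertRequest

from src.client.spending_policy import SpendingPolicyBase
from src.server.task_prioritizer import TaskPrioritizerBase


class TokenCountSpendingPolicy(SpendingPolicyBase):
    """Client side: pledge a fixed number of points per token in the outgoing request."""

    def __init__(self, points_per_token: float = 1.0):
        self.points_per_token = points_per_token

    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
        if not request.tensors:
            return 0.0
        # assumes the first tensor is the [batch, seq_len, hidden] hidden states, as in rpc_forward
        batch_size, seq_len = request.tensors[0].size[0], request.tensors[0].size[1]
        return self.points_per_token * batch_size * seq_len


class PointsFirstPrioritizer(TaskPrioritizerBase):
    """Server side: tasks that pledge more points get a smaller (i.e. better) priority value."""

    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        # PrioritizedTaskPool and Runtime treat lower values as more urgent
        return -float(points)

On the server, such a prioritizer would be passed as TransformerConnectionHandler(dht, module_backends, inference_max_length, task_prioritizer=PointsFirstPrioritizer()); the handler then splits the request's points across the requested backends before calling prioritize(). On the client, this diff only ships NoSpendingPolicy and forwards points through request metadata (next to max_length in RemoteTransformerBlockInferenceSession), so wiring a custom policy into RemoteSequenceManager is left as an assumption here.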