From c4938bc23efe22e3ab6d638261bfd56c6ad807a9 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Thu, 19 Jan 2023 18:38:21 +0300
Subject: [PATCH] Merge inference pools into one to increase inference speed
 (#225)

It turns out that using a separate pool for each block has led to a
significant slowdown; see #224 for details.
---
 .github/workflows/check-style.yaml       |  4 +-
 .github/workflows/push-docker-image.yaml |  2 +-
 .github/workflows/run-tests.yaml         |  4 +-
 src/petals/data_structures.py            |  2 +
 src/petals/server/backend.py             | 61 +++++++++++++++++++-----
 src/petals/server/handler.py             | 53 +++++++++-----------
 src/petals/server/server.py              | 16 +++++--
 src/petals/server/task_prioritizer.py    |  4 +-
 8 files changed, 93 insertions(+), 53 deletions(-)

diff --git a/.github/workflows/check-style.yaml b/.github/workflows/check-style.yaml
index 94b9517..42e1460 100644
--- a/.github/workflows/check-style.yaml
+++ b/.github/workflows/check-style.yaml
@@ -9,7 +9,7 @@ jobs:
   black:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: psf/black@stable
         with:
           options: "--check --diff"
@@ -17,7 +17,7 @@ jobs:
   isort:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: actions/setup-python@v2
         with:
           python-version: 3.8
diff --git a/.github/workflows/push-docker-image.yaml b/.github/workflows/push-docker-image.yaml
index cbad1b2..345b8f2 100644
--- a/.github/workflows/push-docker-image.yaml
+++ b/.github/workflows/push-docker-image.yaml
@@ -14,7 +14,7 @@ jobs:
     steps:
       - name: Checkout
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
 
       - name: Docker meta
         id: meta
diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
index 50509dc..c1a01e8 100644
--- a/.github/workflows/run-tests.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -13,7 +13,7 @@ jobs:
     timeout-minutes: 15
     steps:
       - name: Checkout
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
       - name: Check if the model is cached
         id: cache-model
         uses: actions/cache@v3
@@ -64,7 +64,7 @@ jobs:
     timeout-minutes: 15
     steps:
       - name: Checkout
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
       - name: Set up Python
         uses: actions/setup-python@v2
         with:
diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py
index d5a7181..5d85f07 100644
--- a/src/petals/data_structures.py
+++ b/src/petals/data_structures.py
@@ -6,6 +6,7 @@ from enum import Enum
 from typing import Any, Dict, Tuple
 
 from hivemind import PeerID
+from hivemind.moe.expert_uid import ExpertUID
 
 from petals.server.memory_cache import Handle
 
@@ -48,5 +49,6 @@ RPCInfo = Dict[str, Any]
 
 @dataclasses.dataclass(frozen=True)
 class InferenceMetadata:
+    uid: ExpertUID
     prefix_length: int
     cache_handles: Tuple[Handle, ...]
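
The `uid` field added to `InferenceMetadata` above is what lets a single merged pool route each inference step to the right block: the handler now submits one task carrying metadata for the whole chain, and the merged step looks up each backend by its uid. Below is a minimal sketch of that dispatch idea, not the actual server code; `run_merged_chain` is a hypothetical helper and omits the deep-prompt and error handling done by the real `_MergedInferenceStep` introduced later in this patch.

# Sketch only (hypothetical helper, simplified from _MergedInferenceStep below):
# a single task walks the whole chain of blocks, using each InferenceMetadata.uid
# to pick the backend that owns that block and its attention-cache handles.
import torch
from typing import Dict, Sequence

from petals.data_structures import InferenceMetadata


def run_merged_chain(
    backends: Dict[str, "TransformerBackend"],  # uid -> backend, as on the server
    hidden_states: torch.Tensor,
    hypo_ids: torch.LongTensor,
    inference_infos: Sequence[InferenceMetadata],
) -> torch.Tensor:
    for info in inference_infos:
        # each backend reads/updates its own KV cache through info.cache_handles
        (hidden_states,) = backends[info.uid].inference_step(hidden_states, hypo_ids, info)
    return hidden_states
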
diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 4f9a3bb..81f3a33 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -3,10 +3,11 @@ from __future__ import annotations
 
 from collections import Counter
 from itertools import chain
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
+from hivemind.moe.expert_uid import ExpertUID
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 from tensor_parallel import TensorParallel
@@ -15,7 +16,7 @@ from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
 from petals.data_structures import InferenceMetadata
-from petals.server.memory_cache import MemoryCache
+from petals.server.memory_cache import Handle, MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
 
@@ -39,7 +40,7 @@ class TransformerBackend(ModuleBackend):
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
             self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
-        )
+        )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
             self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
@@ -79,22 +80,20 @@ class TransformerBackend(ModuleBackend):
                 cache_tensors.extend((keys, values))
         return cache_tensors
 
+    @torch.inference_mode()
     def inference_step(
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
-        with torch.inference_mode():
-            assert (
-                hidden_states.ndim == 3
-            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-            with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
-                self._reorder_cache_inplace(cache_tensors, hypo_ids)
-                layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
-                hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
-                self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
-                return (hidden_states,)
+        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+            self._reorder_cache_inplace(cache_tensors, hypo_ids)
+            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
+            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
+            return (hidden_states,)
 
     def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
         """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
@@ -139,3 +138,39 @@ class TransformerBackend(ModuleBackend):
         dummy = torch.tensor([])
         for p in self.module.parameters():
             p.data = dummy
+
+
+def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
+    """Replace each backend's rpc_inference pool with a combined pool that runs multiple blocks in one call"""
+    assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
+    first_pool = next(iter(backends.values())).inference_pool
+    merged_pool = PrioritizedTaskPool(
+        _MergedInferenceStep(backends),
+        max_batch_size=first_pool.max_batch_size,
+        device=first_pool.device,
+        name=f"merged_inference",
+    )
+    for backend in backends.values():
+        assert not backend.inference_pool.is_alive()
+        backend.inference_pool = merged_pool
+
+
+class _MergedInferenceStep:
+    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
+        self.backends = backends
+
+    def __call__(
+        self,
+        hidden_states: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        inference_infos: Sequence[InferenceMetadata],
+        *optional_prompts: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, ...]:
+        assert len(inference_infos) == len(
+            optional_prompts
+        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
+        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
+            if optional_prompt is not None:
+                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
+            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
+        return (hidden_states,)
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 3c889f6..b1c36ed 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -141,10 +141,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                     assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
 
                     # parse deep prompts (optional argument)
-                    if prompts is None or is_dummy(prompts) or is_dummy(prompts):
-                        prompts = [DUMMY] * len(requested_backends)
+                    if prompts is None or is_dummy(prompts):
+                        prompts = [None] * len(requested_backends)
                     else:
                         prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+                        prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
 
                     if not (len(requested_backends) == len(prompts)):
                         raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
@@ -156,33 +157,26 @@ class TransformerConnectionHandler(ConnectionHandler):
                            f" exceeds pre-allocated maximum {max_length}"
                        )
 
-                    # run request tensors through all requested modules, update caches
-                    for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts):
-                        if not is_dummy(prompt):
-                            hidden_states[:, : prompt.shape[1]] += prompt
-                        if hidden_states.numel() == 0:
-                            continue  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-                            # when user wants to pre-allocate cache or check that server *can* allocate that cache
-
-                        metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles))
-                        assert isinstance(
-                            hidden_states, torch.Tensor
-                        ), f"hidden states must be tensor, got {type(hidden_states)}"
-                        assert (
-                            hidden_states.ndim == 3
-                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                        assert isinstance(
-                            backend.inference_pool, PrioritizedTaskPool
-                        ), "petals support only prioritized pools"
-                        priority = self._prioritizer.prioritize(
-                            hidden_states,
-                            hypo_ids,
-                            points=point_per_piece / len(requested_backends),
-                            backend=backend,
-                            type="inference",
-                        )
-                        (hidden_states,) = await backend.inference_pool.submit_task(
-                            hidden_states, hypo_ids, metadata, priority=priority
+                    priority = self._prioritizer.prioritize(
+                        hidden_states,
+                        hypo_ids,
+                        points=point_per_piece,
+                        requested_uids=requested_uids,
+                        type="inference",
+                    )
+
+                    inference_infos = tuple(
+                        InferenceMetadata(uid, prefix_length, tuple(handles))
+                        for uid, handles in zip(requested_uids, cache_handles)
+                    )
+
+                    if hidden_states.numel() == 0:
+                        pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
+                        # when user wants to pre-allocate cache or check that server *can* allocate that cache
+                    else:
+                        assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
+                        (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
+                            hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
                        )
 
                    # serialize and send last layer outputs
@@ -444,7 +438,6 @@ async def _rpc_forward(
             hidden_states.ndim == 3
         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
 
-    # Serialize the overall output
     return hidden_states
 
 
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index a411fd3..a25fce6 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -22,7 +22,7 @@ from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
-from petals.server.backend import TransformerBackend
+from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
@@ -453,11 +453,12 @@ class ModuleContainer(threading.Thread):
             joining_announcer.stop.set()
             joining_announcer.join()
 
+        merge_inference_pools_inplace(blocks)
+
         return cls(
             dht,
             blocks,
             throughput=throughput,
-            device=device,
             update_period=update_period,
             expiration=expiration,
             **kwargs,
@@ -476,7 +477,6 @@ class ModuleContainer(threading.Thread):
         request_timeout: float,
         session_timeout: float,
         step_timeout: float,
-        device: Union[str, torch.device],
         start: bool,
         **kwargs,
     ):
@@ -495,7 +495,7 @@ class ModuleContainer(threading.Thread):
             )
             for _ in range(num_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, device=None, **kwargs)
+        self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
@@ -633,3 +633,11 @@ class ModuleAnnouncerThread(threading.Thread):
             )
             if self.stop.wait(self.update_period):
                 break
+
+
+class RuntimeWithDeduplicatedPools(Runtime):
+    """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pools = tuple(set(self.pools))
diff --git a/src/petals/server/task_prioritizer.py b/src/petals/server/task_prioritizer.py
index 3ec5a90..6490fc5 100644
--- a/src/petals/server/task_prioritizer.py
+++ b/src/petals/server/task_prioritizer.py
@@ -16,4 +16,6 @@ class DummyTaskPrioritizer(TaskPrioritizerBase):
     """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
 
     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
-        return 0.0
+        if kwargs.get("type") == "inference":
+            return 1.0  # inference steps go first since they are more latency-sensitive
+        return 2.0  # forward, backward
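
A usage note on the prioritizer change above: as the inline comments indicate, smaller priority values are served first by PrioritizedTaskPool, so returning 1.0 for inference and 2.0 for everything else lets latency-sensitive inference steps overtake queued forward/backward passes. A minimal sketch of how a caller exercises this (the tensors here are placeholder values, not real activations, and the shapes are illustrative only):

import torch

from petals.server.task_prioritizer import DummyTaskPrioritizer

# Placeholder inputs; only the keyword `type` affects the returned priority.
prioritizer = DummyTaskPrioritizer()
hidden_states = torch.zeros(1, 4, 14336)
hypo_ids = torch.zeros(1, dtype=torch.int64)

inference_priority = prioritizer.prioritize(hidden_states, hypo_ids, points=0.0, type="inference")
forward_priority = prioritizer.prioritize(hidden_states, points=0.0, type="forward")
assert inference_priority < forward_priority  # lower value -> scheduled earlier by PrioritizedTaskPool
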