Merge inference pools into one to increase inference speed (#225)

It turns out that using a separate pool for each block led to a significant slowdown; see #224 for details.
pull/219/head
justheuristic 1 year ago committed by GitHub
parent 3189b395f0
commit c4938bc23e
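The gist of the change, as a minimal standalone sketch (simplified names, not the actual Petals/hivemind API): instead of running each transformer block through its own task pool, all blocks now share a single pool whose step function executes the requested chain of blocks in one submission.

from typing import Callable, Dict, Sequence
import torch

class MergedInferenceSketch:
    """Illustrative only: run several per-block step functions inside one pool submission."""

    def __init__(self, block_steps: Dict[str, Callable[[torch.Tensor], torch.Tensor]]):
        self.block_steps = block_steps  # maps block uid -> that block's inference step

    def __call__(self, hidden_states: torch.Tensor, uids: Sequence[str]) -> torch.Tensor:
        for uid in uids:  # the whole chain runs in one task, avoiding a pool round-trip per block
            hidden_states = self.block_steps[uid](hidden_states)
        return hidden_states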

@@ -9,7 +9,7 @@ jobs:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: psf/black@stable
with:
options: "--check --diff"
@@ -17,7 +17,7 @@ jobs:
isort:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: actions/setup-python@v2
with:
python-version: 3.8

@@ -14,7 +14,7 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Docker meta
id: meta

@@ -13,7 +13,7 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Check if the model is cached
id: cache-model
uses: actions/cache@v3
@@ -64,7 +64,7 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v2
with:

@@ -6,6 +6,7 @@ from enum import Enum
from typing import Any, Dict, Tuple
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
from petals.server.memory_cache import Handle
@@ -48,5 +49,6 @@ RPCInfo = Dict[str, Any]
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:
uid: ExpertUID
prefix_length: int
cache_handles: Tuple[Handle, ...]

@@ -3,10 +3,11 @@ from __future__ import annotations
from collections import Counter
from itertools import chain
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
import torch
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from tensor_parallel import TensorParallel
@@ -15,7 +16,7 @@ from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import MemoryCache
from petals.server.memory_cache import Handle, MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy
@@ -39,7 +40,7 @@ class TransformerBackend(ModuleBackend):
device = self.module.devices[self.module.output_device_index]
self.inference_pool = PrioritizedTaskPool(
self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
)
) # note: inference_pools may be merged later, see merge_inference_pools_inplace
self.forward_pool = PrioritizedTaskPool(
self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
)
@@ -79,16 +80,14 @@ class TransformerBackend(ModuleBackend):
cache_tensors.extend((keys, values))
return cache_tensors
@torch.inference_mode()
def inference_step(
self,
hidden_states: torch.Tensor,
hypo_ids: torch.LongTensor,
inference_info: InferenceMetadata,
) -> Tuple[torch.Tensor, ...]:
with torch.inference_mode():
assert (
hidden_states.ndim == 3
), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
self._reorder_cache_inplace(cache_tensors, hypo_ids)
layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
@@ -139,3 +138,39 @@ class TransformerBackend(ModuleBackend):
dummy = torch.tensor([])
for p in self.module.parameters():
p.data = dummy
def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
"""Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
first_pool = next(iter(backends.values())).inference_pool
merged_pool = PrioritizedTaskPool(
_MergedInferenceStep(backends),
max_batch_size=first_pool.max_batch_size,
device=first_pool.device,
name=f"merged_inference",
)
for backend in backends.values():
assert not backend.inference_pool.is_alive()
backend.inference_pool = merged_pool
class _MergedInferenceStep:
def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
self.backends = backends
def __call__(
self,
hidden_states: torch.Tensor,
hypo_ids: torch.LongTensor,
inference_infos: Sequence[InferenceMetadata],
*optional_prompts: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, ...]:
assert len(inference_infos) == len(
optional_prompts
), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
if optional_prompt is not None:
hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
(hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
return (hidden_states,)
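A quick way to see the effect of merge_inference_pools_inplace (a hypothetical sanity check, not part of the commit, assuming `backends` maps block UIDs to TransformerBackend instances whose original per-block pools have already been shut down):

merge_inference_pools_inplace(backends)
assert len({id(b.inference_pool) for b in backends.values()}) == 1  # every block now shares one merged pool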

@@ -141,10 +141,11 @@ class TransformerConnectionHandler(ConnectionHandler):
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
if prompts is None or is_dummy(prompts) or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
if prompts is None or is_dummy(prompts):
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
@@ -156,33 +157,26 @@ class TransformerConnectionHandler(ConnectionHandler):
f" exceeds pre-allocated maximum {max_length}"
)
# run request tensors through all requested modules, update caches
for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
if hidden_states.numel() == 0:
continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles))
assert isinstance(
hidden_states, torch.Tensor
), f"hidden states must be tensor, got {type(hidden_states)}"
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
assert isinstance(
backend.inference_pool, PrioritizedTaskPool
), "petals support only prioritized pools"
priority = self._prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece / len(requested_backends),
backend=backend,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
)
(hidden_states,) = await backend.inference_pool.submit_task(
hidden_states, hypo_ids, metadata, priority=priority
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles))
for uid, handles in zip(requested_uids, cache_handles)
)
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
else:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
(hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
# serialize and send last layer outputs
@@ -444,7 +438,6 @@ async def _rpc_forward(
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
# Serialize the overall output
return hidden_states

@@ -22,7 +22,7 @@ from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from petals.dht_utils import declare_active_modules, get_remote_module_infos
from petals.server import block_selection
from petals.server.backend import TransformerBackend
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size
from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
@@ -453,11 +453,12 @@ class ModuleContainer(threading.Thread):
joining_announcer.stop.set()
joining_announcer.join()
merge_inference_pools_inplace(blocks)
return cls(
dht,
blocks,
throughput=throughput,
device=device,
update_period=update_period,
expiration=expiration,
**kwargs,
@@ -476,7 +477,6 @@ class ModuleContainer(threading.Thread):
request_timeout: float,
session_timeout: float,
step_timeout: float,
device: Union[str, torch.device],
start: bool,
**kwargs,
):
@@ -495,7 +495,7 @@ class ModuleContainer(threading.Thread):
)
for _ in range(num_handlers)
]
self.runtime = Runtime(self.module_backends, device=None, **kwargs)
self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
# note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved them as needed.
self.online_announcer = ModuleAnnouncerThread(
list(self.module_backends.keys()),
@@ -633,3 +633,11 @@ class ModuleAnnouncerThread(threading.Thread):
)
if self.stop.wait(self.update_period):
break
class RuntimeWithDeduplicatedPools(Runtime):
"""A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pools = tuple(set(self.pools))
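Why the dedup matters (a toy illustration, not the hivemind API): after the merge, every backend reports the same inference pool object, so collecting pools per backend would yield many references to one pool; tuple(set(...)) keeps each distinct pool exactly once.

shared_pool = object()                           # stands in for the single merged inference pool
pools = [shared_pool, shared_pool, shared_pool]  # one reference per backend before dedup
assert tuple(set(pools)) == (shared_pool,)       # after dedup, the runtime drives the pool only once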

@@ -16,4 +16,6 @@ class DummyTaskPrioritizer(TaskPrioritizerBase):
"""Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
return 0.0
if kwargs.get("type") == "inference":
return 1.0 # inference steps go first since they are more latency-sensitive
return 2.0 # forward, backward
