Prioritize short inference, unmerge pools for long inference (#458)

Right now, long inference requests may occupy the Runtime for a few seconds without yielding it to short (i.e., the most latency-sensitive) requests. This PR fixes that by disallowing the merged pool for long requests and prioritizing the short ones.
Alexander Borzunov 9 months ago committed by GitHub
parent 55eb36ef48
commit 056f22515a
No known key found for this signature in database

@ -16,8 +16,15 @@ from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
# We prioritize short inference requests and make them use a *merged* inference pool,
# so they are processed without interruptions and extra overheads
# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
async def run_rpc_forward(
*flat_tensors: torch.Tensor,
@ -127,9 +134,11 @@ async def iterate_rpc_inference(
active_adapter: Optional[str],
input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
cache_handles: Sequence[Sequence[Handle]],
max_length: int,
prioritizer: TaskPrioritizerBase,
points: int,
quant_type: QuantType,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
assert len(cache_handles) == len(requested_backends)
@ -138,6 +147,7 @@ async def iterate_rpc_inference(
async for request, step_metadata in input_iterator:
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
batch_size, length_increment, _ = hidden_states.shape
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
@ -154,34 +164,40 @@ async def iterate_rpc_inference(
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
can_merge_pools = batch_size * length_increment <= merge_max_tokens
priority = prioritizer.prioritize(
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
type="short_inference" if can_merge_pools else "inference",
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
# A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache.
if hidden_states.numel() > 0:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
if can_merge_pools:
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
(hidden_states,) = await backend.inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, prompt, priority=priority
# serialize and send last layer outputs
output_tensors = [

@ -34,6 +34,7 @@ from petals.server.backend import TransformerBackend
from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
from petals.server.memory_cache import Handle
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.convert_block import QuantType
logger = get_logger(__name__)
@ -71,6 +72,7 @@ class TransformerConnectionHandler(ConnectionHandler):
session_timeout: float,
step_timeout: float,
task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
quant_type: QuantType,
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
@ -88,6 +90,7 @@ class TransformerConnectionHandler(ConnectionHandler):
self.request_timeout = request_timeout
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self._prioritizer = task_prioritizer
self.quant_type = quant_type
async def add_p2p_handlers(self, *args, **kwargs) -> None:
if self._listener_task is None:
@ -176,6 +179,7 @@ class TransformerConnectionHandler(ConnectionHandler):
if can_push:
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))

@ -560,6 +560,7 @@ class ModuleContainer(threading.Thread):
for i in range(num_handlers)

@ -13,9 +13,10 @@ class TaskPrioritizerBase(ABC):
class DummyTaskPrioritizer(TaskPrioritizerBase):
"""Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
# Inference steps (especially short ones) go first since they are more latency-sensitive
if kwargs.get("type") == "short_inference":
return 1.0
if kwargs.get("type") == "inference":
return 1.0 # inference steps go first since they are more latency-sensitive
return 2.0 # forward, backward
return 2.0
return 3.0 # Forward, backward

@ -4,6 +4,7 @@ import pytest
import torch
from petals import AutoDistributedConfig, RemoteSequential
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@ -13,26 +14,30 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_sequential = RemoteSequential(config)
for block_index in random.sample(range(config.num_hidden_layers), 3):
remote_block = remote_sequential[block_index]
block_index = random.randint(0, config.num_hidden_layers - 1)
remote_block = remote_sequential[block_index]
inputs = torch.randn(1, 8, config.hidden_size)
outputs_forward = remote_block(inputs)
inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)
outputs_forward = remote_block(inputs)
outputs_inference = []
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = []
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
# Test long inference (unmerged inference pools)
outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))
# test that max length is respected
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
outputs_inference = torch.cat(outputs_inference, dim=1)
# Test short inference (merged inference pools)
for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
# test that max length is respected
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
outputs_inference = torch.cat(outputs_inference, dim=1)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

@ -40,7 +40,7 @@ def test_remote_sequential():
assert hidden.shape == test_inputs.shape
assert hidden.requires_grad
second_half_outputs = second_half(hidden)
assert torch.allclose(second_half_outputs, full_outputs, atol=3e-4)
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)
