diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py
index 9208deb..c1f1d93 100644
--- a/src/petals/server/block_functions.py
+++ b/src/petals/server/block_functions.py
@@ -16,8 +16,15 @@
 from petals.server.backend import TransformerBackend
 from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
+from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
 
+# We prioritize short inference requests and make them use a *merged* inference pool,
+# so they are processed without interruptions and extra overheads
+# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
+MAX_SHORT_INFERENCE_TOKENS = 128
+MAX_NF4_SHORT_INFERENCE_TOKENS = 1
+
 
 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
@@ -127,9 +134,11 @@ async def iterate_rpc_inference(
     active_adapter: Optional[str],
     input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
     cache_handles: Sequence[Sequence[Handle]],
+    *,
     max_length: int,
     prioritizer: TaskPrioritizerBase,
     points: int,
+    quant_type: QuantType,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -138,6 +147,7 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+        batch_size, length_increment, _ = hidden_states.shape
 
         # Cast inputs to backend dtype
         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -154,34 +164,40 @@
         if not (len(requested_backends) == len(prompts)):
             raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
 
-        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
         if prefix_length + length_increment > max_length:
             raise ValueError(
                 f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                 f" exceeds pre-allocated maximum {max_length}"
             )
 
+        merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
+        can_merge_pools = batch_size * length_increment <= merge_max_tokens
         priority = prioritizer.prioritize(
             hidden_states, hypo_ids, points=point_per_piece, requested_uids=requested_uids,
-            type="inference",
-        )
-
-        inference_infos = tuple(
-            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
-            for uid, handles in zip(requested_uids, cache_handles)
+            type="short_inference" if can_merge_pools else "inference",
         )
 
-        if hidden_states.numel() == 0:
-            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-            # when user wants to pre-allocate cache or check that server *can* allocate that cache
-        else:
+        # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
+        # when user wants to pre-allocate cache or check that server *can* allocate that cache.
+        if hidden_states.numel() > 0:
             assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
-            (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
-                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
-            )
+            if can_merge_pools:
+                inference_infos = tuple(
+                    InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
+                    for uid, handles in zip(requested_uids, cache_handles)
+                )
+                (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
+                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+                )
+            else:
+                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
+                    inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
+                    (hidden_states,) = await backend.inference_pool.submit_task(
+                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
+                    )
 
         # serialize and send last layer outputs
         output_tensors = [
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index b9be294..00df0d5 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -34,6 +34,7 @@
 from petals.server.backend import TransformerBackend
 from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
 from petals.server.memory_cache import Handle
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
+from petals.utils.convert_block import QuantType
 
 logger = get_logger(__name__)
 
@@ -71,6 +72,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         session_timeout: float,
         step_timeout: float,
         task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
+        quant_type: QuantType,
     ):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
@@ -88,6 +90,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
         self._prioritizer = task_prioritizer
+        self.quant_type = quant_type
 
     async def add_p2p_handlers(self, *args, **kwargs) -> None:
         if self._listener_task is None:
@@ -176,6 +179,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     max_length=max_length,
                     prioritizer=self._prioritizer,
                     points=points,
+                    quant_type=self.quant_type,
                 ):
                     if can_push:
                         task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index bf7470a..405dd9b 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -560,6 +560,7 @@ class ModuleContainer(threading.Thread):
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
                 step_timeout=step_timeout,
+                quant_type=QuantType[server_info.quant_type.upper()],
             )
             for i in range(num_handlers)
         ]
diff --git a/src/petals/server/task_prioritizer.py b/src/petals/server/task_prioritizer.py
index 6490fc5..4a575b1 100644
--- a/src/petals/server/task_prioritizer.py
+++ b/src/petals/server/task_prioritizer.py
@@ -13,9 +13,10 @@ class TaskPrioritizerBase(ABC):
 
 
 class DummyTaskPrioritizer(TaskPrioritizerBase):
-    """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
-
     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
+        # Inference steps (especially short ones) go first since they are more latency-sensitive
+        if kwargs.get("type") == "short_inference":
+            return 1.0
         if kwargs.get("type") == "inference":
-            return 1.0  # inference steps go first since they are more latency-sensitive
-        return 2.0  # forward, backward
+            return 2.0
+        return 3.0  # Forward, backward
diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py
index d98918b..80c695f 100644
--- a/tests/test_block_exact_match.py
+++ b/tests/test_block_exact_match.py
@@ -4,6 +4,7 @@ import pytest
 import torch
 
 from petals import AutoDistributedConfig, RemoteSequential
+from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
@@ -13,26 +14,30 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
     config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_sequential = RemoteSequential(config)
 
-    for block_index in random.sample(range(config.num_hidden_layers), 3):
-        remote_block = remote_sequential[block_index]
+    block_index = random.randint(0, config.num_hidden_layers - 1)
+    remote_block = remote_sequential[block_index]
 
-        inputs = torch.randn(1, 8, config.hidden_size)
-        outputs_forward = remote_block(inputs)
+    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)
+    outputs_forward = remote_block(inputs)
 
-        outputs_inference = []
-        with torch.inference_mode():
-            with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
-                for i in range(inputs.shape[1]):
-                    outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+    outputs_inference = []
+    with torch.inference_mode():
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+            # Test long inference (unmerged inference pools)
+            outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))
 
-                # test that max length is respected
-                with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
-                    sess.step(inputs[:, -1:, :])
-                assert "Maximum length exceeded" in repr(exc_info.value)
-        outputs_inference = torch.cat(outputs_inference, dim=1)
+            # Test short inference (merged inference pools)
+            for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):
+                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
 
-        ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
-        (outputs_local,) = ref_block(inputs)
+            # test that max length is respected
+            with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
+                sess.step(inputs[:, -1:, :])
+            assert "Maximum length exceeded" in repr(exc_info.value)
+    outputs_inference = torch.cat(outputs_inference, dim=1)
 
-        assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
-        assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
+    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+    (outputs_local,) = ref_block(inputs)
+
+    assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
+    assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py
index 9189e68..30698c5 100644
--- a/tests/test_remote_sequential.py
+++ b/tests/test_remote_sequential.py
@@ -40,7 +40,7 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs, atol=3e-4)
+    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
 
     (second_half_outputs * grad_proj).sum().backward()
     assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)
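
Below is a minimal, self-contained sketch (not part of the patch) of the routing rule that the block_functions.py change introduces: an inference step goes to the merged inference pool only when its total token count stays within the short-inference threshold, which is much stricter under NF4 quantization. The QuantType enum and the can_merge_pools helper here are illustrative stand-ins, not the actual petals classes.

# Sketch only: mirrors the merged-vs-unmerged pool decision from block_functions.py above.
# QuantType is a simplified stand-in for petals.utils.convert_block.QuantType;
# the thresholds mirror the constants added in this patch.
from enum import Enum

MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1


class QuantType(Enum):
    NONE = 0
    INT8 = 1
    NF4 = 2


def can_merge_pools(batch_size: int, length_increment: int, quant_type: QuantType) -> bool:
    """True -> submit the whole step to the first backend's merged inference pool;
    False -> submit it block-by-block to each backend's own pool."""
    merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
    return batch_size * length_increment <= merge_max_tokens


# Single-token decoding steps stay merged; long prefills (or any NF4 step with >1 token) do not
assert can_merge_pools(1, 1, QuantType.NF4)
assert not can_merge_pools(1, 2, QuantType.NF4)
assert can_merge_pools(1, MAX_SHORT_INFERENCE_TOKENS, QuantType.NONE)
assert not can_merge_pools(1, MAX_SHORT_INFERENCE_TOKENS + 1, QuantType.NONE)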