diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 4f2a645..470ac5f 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -27,7 +27,7 @@
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
-from petals.server.throughput import get_dtype_name, get_host_throughput
+from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -193,10 +193,11 @@ class Server:

         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
-            throughput = get_host_throughput(
+            throughput = get_server_throughput(
                 self.block_config,
                 device,
                 torch_dtype,
+                num_blocks=num_blocks,
                 load_in_8bit=load_in_8bit,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index ac43759..a60a24d 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -1,11 +1,12 @@
 import fcntl
 import json
+import math
 import os
 import time
 from collections import Counter
 from hashlib import sha256
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Dict, Optional, Sequence, Union

 import torch
 from hivemind.utils.logging import get_logger
@@ -32,11 +33,12 @@ if not hasattr(speedtest, "Speedtest"):
     )


-def get_host_throughput(
+def get_server_throughput(
     config: BloomConfig,
     device: torch.device,
     dtype: Union[str, torch.dtype],
     *,
+    num_blocks: int,
     load_in_8bit: bool,
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
@@ -47,7 +49,7 @@ def get_host_throughput(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v2.json")
+    cache_path = Path(cache_dir, "throughput_v3.json")

     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
@@ -85,7 +87,16 @@
         except Exception:
             logger.exception(f"Failed to save throughput info in {cache_path}")

-    return cache[cache_key]
+    throughput_info = cache[cache_key]
+
+    # Most requests start at some block hosted by a server, then use all next blocks hosted on this server.
+    # Assuming the start block index is distributed uniformly, the average number of blocks used per request is
+    # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
+    average_blocks_used = (num_blocks + 1) / 2
+    throughput = throughput_info["compute_rps"] / average_blocks_used
+    throughput = min(throughput, throughput_info.get("network_rps", math.inf))
+    logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
+    return throughput


 def measure_throughput_info(
@@ -95,22 +106,24 @@ def measure_throughput_info(
     config: BloomConfig,
     device: torch.device,
     dtype: torch.dtype,
     *,
     load_in_8bit: bool,
     tensor_parallel_devices: Sequence[torch.device],
-) -> float:
+) -> Dict[str, float]:
     """Measure network and compute throughput in forward pass tokens per second"""

     logger.info(
         "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
     )
-    result = measure_compute_rps(
-        config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
-    )
+    throughput_info = {
+        "compute_rps": measure_compute_rps(
+            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+        )
+    }
     try:
-        result = min(result, measure_network_rps(config))
+        throughput_info["network_rps"] = measure_network_rps(config)
     except Exception:
         logger.warning("Failed to measure network throughput:", exc_info=True)
         logger.warning("Proceeding with the compute throughput only")
-    return result
+    return throughput_info


 def measure_network_rps(config: BloomConfig) -> Optional[float]:
@@ -127,10 +140,9 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
         raise ValueError("speedtest has returned network_rps == 0")

     logger.info(
-        f"Network throughput: "
-        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
-        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
-        f"{network_rps:.1f} RPS"
+        f"Network throughput: {network_rps:.1f} RPS "
+        f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
+        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
     )
     return network_rps
@@ -168,7 +180,8 @@ def measure_compute_rps(
     devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())

     logger.info(
-        f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS"
+        f"Forward pass throughput: {device_rps:.1f} RPS per block "
+        f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})"
     )
     return device_rps
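To sanity-check the new reporting logic outside the server, here is a minimal standalone sketch of the formula that `get_server_throughput` now applies to the cached measurements. The `report_throughput` helper and the sample numbers below are hypothetical illustrations, not part of the patch:

```python
import math
from typing import Dict


def report_throughput(throughput_info: Dict[str, float], num_blocks: int) -> float:
    # A request entering this server at a uniformly random block uses
    # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2 blocks on average,
    # so per-request compute throughput is compute_rps / average_blocks_used.
    average_blocks_used = (num_blocks + 1) / 2
    throughput = throughput_info["compute_rps"] / average_blocks_used
    # The network is traversed once per request, so it caps the reported figure
    # (and is skipped entirely if the speedtest failed during measurement).
    return min(throughput, throughput_info.get("network_rps", math.inf))


# Hypothetical measurements: 560 RPS per block of compute, 300 RPS of network.
info = {"compute_rps": 560.0, "network_rps": 300.0}
print(report_throughput(info, num_blocks=1))  # 300.0 -- network-bound
print(report_throughput(info, num_blocks=7))  # 140.0 -- compute-bound: 560 / 4
```

The effect of the change is that a server hosting more blocks now advertises a proportionally lower RPS (until the network cap kicks in), so the reported number reflects whole-server request throughput rather than single-block throughput.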