"""Measure and cache this host's forward-pass throughput (tokens per second) for a Bloom block."""

import fcntl
import json
import os
import subprocess
import sys
import tempfile
import time
from hashlib import sha256
from pathlib import Path
from typing import Union

import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig
from petals.bloom.ops import build_alibi_tensor
from petals.utils.convert_8bit import replace_8bit_linear

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput_v2.json")
DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")


def get_host_throughput(
    config: BloomConfig,
    device: torch.device,
    dtype: Union[str, torch.dtype],
    *,
    load_in_8bit: bool,
    force_eval: bool = False,
    cache_path: Union[str, os.PathLike] = DEFAULT_CACHE_PATH,
    lock_path: Union[str, os.PathLike] = DEFAULT_LOCK_PATH,
) -> float:
    """Return the host throughput in forward-pass tokens per second, measuring it only when needed.

    :param config: model config; a hash of its string form is part of the cache key
    :param device: device used for the compute benchmark
    :param dtype: dtype to cast the benchmarked block to, or "auto" to keep the block's own dtype
    :param load_in_8bit: whether the benchmarked block uses 8-bit linear layers
    :param force_eval: re-measure even if a cached value exists (other cached entries are kept)
    :param cache_path: JSON file holding cached measurements (str or path-like)
    :param lock_path: lock file ensuring only one process measures at a time (str or path-like)
    :returns: min(network RPS, compute RPS), as computed by measure_throughput_info or read from the cache
    """
    # Accept plain strings as well as Path objects: `.parent` is used below
    cache_path = Path(cache_path)
    lock_path = Path(lock_path)

    # We use the system-wide lock since only one process at a time can measure the host throughput
    os.makedirs(lock_path.parent, exist_ok=True)
    with open(lock_path, "wb") as lock_fd:
        logger.info("Loading throughput info")
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
        # The OS will release the lock when lock_fd is closed or the process is killed

        cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
        cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
        cache_key += f"_dtype_{_get_dtype_name(dtype, load_in_8bit)}"

        # Always load the existing cache so that force_eval re-measures only this key
        # instead of discarding measurements stored for other configs/devices/dtypes
        cache = _load_cache(cache_path)

        if force_eval or cache_key not in cache:
            cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit)
            _save_cache(cache, cache_path)

        return cache[cache_key]


def _load_cache(cache_path: Path) -> dict:
    """Read the JSON throughput cache; return {} if it is missing, unreadable, or not a JSON object."""
    try:
        if not os.path.exists(cache_path):
            return {}
        with open(cache_path) as cache_fd:
            cache = json.load(cache_fd)
        if not isinstance(cache, dict):
            raise TypeError(f"Expected a JSON object, got {type(cache).__name__}")
        return cache
    except Exception:
        # Caching is best-effort: a broken cache file just triggers a fresh measurement
        logger.exception(f"Failed to read throughput info from {cache_path}")
        return {}


def _save_cache(cache: dict, cache_path: Path) -> None:
    """Write the throughput cache as JSON; failures are logged, not raised (caching is best-effort)."""
    try:
        os.makedirs(cache_path.parent, exist_ok=True)
        with open(cache_path, "w") as cache_fd:
            json.dump(cache, cache_fd)
    except Exception:
        logger.exception(f"Failed to save throughput info in {cache_path}")


def measure_throughput_info(
    config: BloomConfig,
    device: torch.device,
    dtype: Union[str, torch.dtype],
    *,
    load_in_8bit: bool,
) -> float:
    """Measure network and compute throughput in forward pass tokens per second"""
    logger.info(
        "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
    )
    # The host can serve no faster than its slower resource
    return min(
        measure_network_rps(config),
        measure_compute_rps(config, device, dtype, load_in_8bit=load_in_8bit),
    )


def measure_network_rps(config: BloomConfig) -> float:
    """Run `petals.cli.speed_test --json` and convert its bandwidth into requests (tokens) per second.

    :raises RuntimeError: if the speed test process exits with a non-zero code
    """
    # Use the current interpreter (not whatever `python3` is on PATH) and avoid going through a shell
    proc = subprocess.run([sys.executable, "-m", "petals.cli.speed_test", "--json"], capture_output=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
    network_info = json.loads(proc.stdout)

    bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
    network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request

    logger.info(
        f"Network throughput: "
        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
        f"{network_rps:.1f} RPS"
    )
    return network_rps


def measure_compute_rps(
    config: BloomConfig,
    device: torch.device,
    dtype: Union[str, torch.dtype],
    *,
    load_in_8bit: bool,
    n_tokens: int = 16,
    n_steps: int = 500,
    layer_index: int = 0,
) -> float:
    """Time incremental forward passes of one BloomBlock and return processed tokens per second.

    Runs n_steps + 1 forward steps, each on an input of shape (n_tokens, 1, hidden_size) while reusing the
    block's returned cache; the first step is excluded from timing to skip one-time initialization.

    :raises ValueError: if n_steps < 1 (nothing would be timed)
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be at least 1, got {n_steps}")
    # CUDA kernels are launched asynchronously, so the clock is only meaningful after synchronizing
    is_cuda = torch.device(device).type == "cuda"

    with torch.inference_mode():
        block = BloomBlock(config, layer_index)
        if dtype != "auto":
            block = block.to(dtype)
        input_dtype = block.input_layernorm.weight.dtype
        if load_in_8bit:
            block = replace_8bit_linear(block)
        block = block.to(device)

        cache = None
        elapsed = 0.0
        for step in range(n_steps + 1):
            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=input_dtype)
            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=input_dtype)

            if is_cuda:
                torch.cuda.synchronize(device)  # don't bill input creation to the forward pass
            start_time = time.perf_counter()
            _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
            if is_cuda:
                torch.cuda.synchronize(device)  # wait for the forward kernels to actually finish
            if step >= 1:  # Skip the 1st step to exclude the initialization time
                elapsed += time.perf_counter() - start_time
        device_rps = n_steps * n_tokens / elapsed

    logger.info(
        f"Forward pass throughput ({_get_device_name(device)}, {_get_dtype_name(dtype, load_in_8bit)}): "
        f"{device_rps:.1f} RPS"
    )
    return device_rps


def _get_device_name(device: Union[str, torch.device]) -> str:
    """Human-readable device name for logs and cache keys: "<CUDA device name> GPU" for CUDA, else "CPU"."""
    # Compare the device *type*: torch.device("cuda:0") (or a torch.device at all) never equals the string "cuda"
    device = torch.device(device)
    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"


def _get_dtype_name(dtype: Union[str, torch.dtype], load_in_8bit: bool) -> str:
    """Dtype label for logs and cache keys; "8-bit" takes precedence when load_in_8bit is set."""
    return "8-bit" if load_in_8bit else str(dtype)