# petals/src/petals/server/throughput.py
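"""Estimate this server's throughput: the minimum of its network and compute speed,
measured in (inference) requests per second for a single BLOOM block."""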

import fcntl
import json
import os
import subprocess
import tempfile
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, Union

import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig
from petals.bloom.ops import build_alibi_tensor

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput.json")
DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")


@dataclass
class ThroughputInfo:
    """Measured host throughput, in inference requests per second (RPS), for the
    network link and for each available compute device."""

    network_rps: float
    device_rps: Dict[str, float]


def get_host_throughput(
    device: Union[str, torch.device],
    force_eval: bool = False,
    cache_path: Path = DEFAULT_CACHE_PATH,
    lock_path: Path = DEFAULT_LOCK_PATH,
) -> float:
    # We only keep the device type, assuming that the throughput is similar across all of the host's GPUs
    device = torch.device(device).type

    # We use a system-wide lock since only one process at a time can measure the host throughput
    os.makedirs(lock_path.parent, exist_ok=True)
    with open(lock_path, "wb") as lock_fd:
        logger.info("Loading throughput info")
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
        # The OS will release the lock when lock_fd is closed or the process is killed

        info = None
        try:
            if not force_eval and os.path.exists(cache_path):
                with open(cache_path) as cache_fd:
                    info = ThroughputInfo(**json.load(cache_fd))
                if device not in info.device_rps:
                    force_eval = True
        except Exception:
            logger.exception(f"Failed to read throughput info from {cache_path}")
            force_eval = True

        if force_eval or info is None:
            info = measure_throughput_info()
            try:
                os.makedirs(cache_path.parent, exist_ok=True)
                with open(cache_path, "w") as cache_fd:
                    json.dump(asdict(info), cache_fd)
            except Exception:
                logger.exception(f"Failed to save throughput info in {cache_path}")

    throughput = min(info.network_rps, info.device_rps[device])
    return throughput
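
# The cache file stores a JSON dump of ThroughputInfo; illustrative contents:
# {"network_rps": 580.21, "device_rps": {"cpu": 12.73, "cuda": 95.48}}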


def measure_throughput_info() -> ThroughputInfo:
    logger.info(
        "Measuring network, CPU, and GPU throughput. This takes about a minute and will be cached for future runs"
    )

    # We measure throughput in "(inference) requests per second" (RPS) using a fixed model
    config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")

    network_rps = measure_network_rps(config)
    device_rps = {"cpu": measure_device_rps("cpu", config)}
    if torch.cuda.is_available():
        device_rps["cuda"] = measure_device_rps("cuda", config)

    return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)


def measure_network_rps(config: BloomConfig) -> float:
    proc = subprocess.run("python3 -m petals.cli.speed_test --json", shell=True, capture_output=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
    network_info = json.loads(proc.stdout)

    # Each request transfers one float32 hidden state of `hidden_size` values (32 bits each),
    # so the request rate is capped by the slower of the two link directions
    bits_per_request = config.hidden_size * 32
    network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request

    logger.info(
        f"Network throughput: "
        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
        f"{network_rps:.2f} RPS"
    )
    return network_rps
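
# Worked example (illustrative numbers): with hidden_size = 4096, one request moves
# 4096 * 32 = 131072 bits each way, so a symmetric 100 Mbit/s link caps the rate
# at 1e8 / 131072, i.e. about 763 RPS.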


def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n_steps: int = 500) -> float:
    with torch.inference_mode():
        block = BloomBlock(config, layer_index).to(device)
        cache = None
        elapsed = 0
        for i in range(n_steps):
            # Simulate autoregressive inference: feed one token at a time while the attention cache grows
            dummy_input = torch.randn(1, 1, config.hidden_size, device=device)
            alibi = build_alibi_tensor(i + 1, config.num_attention_heads, dtype=torch.float32, device=device)

            start_time = time.perf_counter()
            _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
            elapsed += time.perf_counter() - start_time
        device_rps = n_steps / elapsed

    device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == "cuda" else "CPU"
    logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
    return device_rps
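

if __name__ == "__main__":
    # Minimal usage sketch: measure (or load the cached) throughput for this host
    # and print it, preferring the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Host throughput on {device}: {get_host_throughput(device):.2f} RPS")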