Divide compute throughput by average number of used blocks (#314)

See #192.
pull/315/head
Alexander Borzunov 1 year ago committed by GitHub
parent 6137b1b4b0
commit d9e7bfc949
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -27,7 +27,7 @@ from petals.server.block_utils import get_block_size
from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
from petals.server.throughput import get_dtype_name, get_host_throughput
from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.convert_block import check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@ -193,10 +193,11 @@ class Server:
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput = get_host_throughput(
throughput = get_server_throughput(
self.block_config,
device,
torch_dtype,
num_blocks=num_blocks,
load_in_8bit=load_in_8bit,
tensor_parallel_devices=self.tensor_parallel_devices,
force_eval=(throughput == "eval"),

@ -1,11 +1,12 @@
import fcntl
import json
import math
import os
import time
from collections import Counter
from hashlib import sha256
from pathlib import Path
from typing import Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Union
import torch
from hivemind.utils.logging import get_logger
@ -32,11 +33,12 @@ if not hasattr(speedtest, "Speedtest"):
)
def get_host_throughput(
def get_server_throughput(
config: BloomConfig,
device: torch.device,
dtype: Union[str, torch.dtype],
*,
num_blocks: int,
load_in_8bit: bool,
tensor_parallel_devices: Sequence[torch.device],
force_eval: bool = False,
@ -47,7 +49,7 @@ def get_host_throughput(
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
lock_path = Path(cache_dir, "throughput.lock")
cache_path = Path(cache_dir, "throughput_v2.json")
cache_path = Path(cache_dir, "throughput_v3.json")
# We use the system-wide lock since only one process at a time can measure the host throughput
os.makedirs(lock_path.parent, exist_ok=True)
@ -85,7 +87,16 @@ def get_host_throughput(
except Exception:
logger.exception(f"Failed to save throughput info in {cache_path}")
return cache[cache_key]
throughput_info = cache[cache_key]
# Most requests start at some block hosted by this server and then use all subsequent blocks hosted on it.
# Assuming the start block index is uniformly distributed, the number of blocks used per request is
# uniform on {1, 2, ..., num_blocks}, so its average is E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
average_blocks_used = (num_blocks + 1) / 2
throughput = throughput_info["compute_rps"] / average_blocks_used
throughput = min(throughput, throughput_info.get("network_rps", math.inf))
logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
return throughput
def measure_throughput_info(
@ -95,22 +106,24 @@ def measure_throughput_info(
*,
load_in_8bit: bool,
tensor_parallel_devices: Sequence[torch.device],
) -> float:
) -> Dict[str, float]:
"""Measure network and compute throughput in forward pass tokens per second"""
logger.info(
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
)
result = measure_compute_rps(
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
)
throughput_info = {
"compute_rps": measure_compute_rps(
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
)
}
try:
result = min(result, measure_network_rps(config))
throughput_info["network_rps"] = measure_network_rps(config)
except Exception:
logger.warning("Failed to measure network throughput:", exc_info=True)
logger.warning("Proceeding with the compute throughput only")
return result
return throughput_info
def measure_network_rps(config: BloomConfig) -> Optional[float]:
@ -127,10 +140,9 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
raise ValueError("speedtest has returned network_rps == 0")
logger.info(
f"Network throughput: "
f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
f"{network_rps:.1f} RPS"
f"Network throughput: {network_rps:.1f} RPS "
f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
)
return network_rps
@ -168,7 +180,8 @@ def measure_compute_rps(
devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
logger.info(
f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS"
f"Forward pass throughput: {device_rps:.1f} RPS per block "
f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})"
)
return device_rps

Loading…
Cancel
Save