Divide compute throughput by average no. of used blocks (#314)

See #192.
pull/315/head
Alexander Borzunov 1 year ago committed by GitHub
parent 6137b1b4b0
commit d9e7bfc949
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -27,7 +27,7 @@ from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
-from petals.server.throughput import get_dtype_name, get_host_throughput
+from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -193,10 +193,11 @@ class Server:
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
-            throughput = get_host_throughput(
+            throughput = get_server_throughput(
                 self.block_config,
                 device,
                 torch_dtype,
+                num_blocks=num_blocks,
                 load_in_8bit=load_in_8bit,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),

@@ -1,11 +1,12 @@
 import fcntl
 import json
+import math
 import os
 import time
 from collections import Counter
 from hashlib import sha256
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Dict, Optional, Sequence, Union

 import torch
 from hivemind.utils.logging import get_logger
@@ -32,11 +33,12 @@ if not hasattr(speedtest, "Speedtest"):
     )

-def get_host_throughput(
+def get_server_throughput(
     config: BloomConfig,
     device: torch.device,
     dtype: Union[str, torch.dtype],
     *,
+    num_blocks: int,
     load_in_8bit: bool,
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
@@ -47,7 +49,7 @@ def get_server_throughput(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v2.json")
+    cache_path = Path(cache_dir, "throughput_v3.json")

     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
@@ -85,7 +87,16 @@ def get_server_throughput(
         except Exception:
             logger.exception(f"Failed to save throughput info in {cache_path}")

-    return cache[cache_key]
+    throughput_info = cache[cache_key]
+
+    # Most requests start at some block hosted by a server, then use all next blocks hosted on this server.
+    # Assuming the start block index is distributed uniformly, the average number of blocks used per request is
+    # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
+    average_blocks_used = (num_blocks + 1) / 2
+    throughput = throughput_info["compute_rps"] / average_blocks_used
+
+    throughput = min(throughput, throughput_info.get("network_rps", math.inf))
+    logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
+    return throughput

 def measure_throughput_info(
@@ -95,22 +106,24 @@ def measure_throughput_info(
     *,
     load_in_8bit: bool,
     tensor_parallel_devices: Sequence[torch.device],
-) -> float:
+) -> Dict[str, float]:
     """Measure network and compute throughput in forward pass tokens per second"""

     logger.info(
         "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
     )
-    result = measure_compute_rps(
-        config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
-    )
+    throughput_info = {
+        "compute_rps": measure_compute_rps(
+            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+        )
+    }
     try:
-        result = min(result, measure_network_rps(config))
+        throughput_info["network_rps"] = measure_network_rps(config)
     except Exception:
         logger.warning("Failed to measure network throughput:", exc_info=True)
         logger.warning("Proceeding with the compute throughput only")
-    return result
+    return throughput_info

 def measure_network_rps(config: BloomConfig) -> Optional[float]:
@@ -127,10 +140,9 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
         raise ValueError("speedtest has returned network_rps == 0")

     logger.info(
-        f"Network throughput: "
-        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
-        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
-        f"{network_rps:.1f} RPS"
+        f"Network throughput: {network_rps:.1f} RPS "
+        f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
+        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
     )
     return network_rps
@@ -168,7 +180,8 @@ def measure_compute_rps(
     devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
     logger.info(
-        f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS"
+        f"Forward pass throughput: {device_rps:.1f} RPS per block "
+        f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})"
     )
     return device_rps

Loading…
Cancel
Save