|
|
|
@ -1,11 +1,12 @@
|
|
|
|
|
import fcntl
|
|
|
|
|
import json
|
|
|
|
|
import math
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
from collections import Counter
|
|
|
|
|
from hashlib import sha256
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Optional, Sequence, Union
|
|
|
|
|
from typing import Dict, Optional, Sequence, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
@ -32,11 +33,12 @@ if not hasattr(speedtest, "Speedtest"):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_host_throughput(
|
|
|
|
|
def get_server_throughput(
|
|
|
|
|
config: BloomConfig,
|
|
|
|
|
device: torch.device,
|
|
|
|
|
dtype: Union[str, torch.dtype],
|
|
|
|
|
*,
|
|
|
|
|
num_blocks: int,
|
|
|
|
|
load_in_8bit: bool,
|
|
|
|
|
tensor_parallel_devices: Sequence[torch.device],
|
|
|
|
|
force_eval: bool = False,
|
|
|
|
@ -47,7 +49,7 @@ def get_host_throughput(
|
|
|
|
|
if cache_dir is None:
|
|
|
|
|
cache_dir = DEFAULT_CACHE_DIR
|
|
|
|
|
lock_path = Path(cache_dir, "throughput.lock")
|
|
|
|
|
cache_path = Path(cache_dir, "throughput_v2.json")
|
|
|
|
|
cache_path = Path(cache_dir, "throughput_v3.json")
|
|
|
|
|
|
|
|
|
|
# We use the system-wide lock since only one process at a time can measure the host throughput
|
|
|
|
|
os.makedirs(lock_path.parent, exist_ok=True)
|
|
|
|
@ -85,7 +87,16 @@ def get_host_throughput(
|
|
|
|
|
except Exception:
|
|
|
|
|
logger.exception(f"Failed to save throughput info in {cache_path}")
|
|
|
|
|
|
|
|
|
|
return cache[cache_key]
|
|
|
|
|
throughput_info = cache[cache_key]
|
|
|
|
|
|
|
|
|
|
# Most requests start at some block hosted by a server, then use all subsequent blocks hosted on this server.
|
|
|
|
|
# Assuming the start block index is distributed uniformly, the average number of blocks used per request is
|
|
|
|
|
# E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
|
|
|
|
|
average_blocks_used = (num_blocks + 1) / 2
|
|
|
|
|
throughput = throughput_info["compute_rps"] / average_blocks_used
|
|
|
|
|
throughput = min(throughput, throughput_info.get("network_rps", math.inf))
|
|
|
|
|
logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
|
|
|
|
|
return throughput
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def measure_throughput_info(
|
|
|
|
@ -95,22 +106,24 @@ def measure_throughput_info(
|
|
|
|
|
*,
|
|
|
|
|
load_in_8bit: bool,
|
|
|
|
|
tensor_parallel_devices: Sequence[torch.device],
|
|
|
|
|
) -> float:
|
|
|
|
|
) -> Dict[str, float]:
|
|
|
|
|
"""Measure network and compute throughput in forward pass tokens per second"""
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
result = measure_compute_rps(
|
|
|
|
|
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
|
|
|
|
|
)
|
|
|
|
|
throughput_info = {
|
|
|
|
|
"compute_rps": measure_compute_rps(
|
|
|
|
|
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
try:
|
|
|
|
|
result = min(result, measure_network_rps(config))
|
|
|
|
|
throughput_info["network_rps"] = measure_network_rps(config)
|
|
|
|
|
except Exception:
|
|
|
|
|
logger.warning("Failed to measure network throughput:", exc_info=True)
|
|
|
|
|
logger.warning("Proceeding with the compute throughput only")
|
|
|
|
|
return result
|
|
|
|
|
return throughput_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def measure_network_rps(config: BloomConfig) -> Optional[float]:
|
|
|
|
@ -127,10 +140,9 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
|
|
|
|
|
raise ValueError("speedtest has returned network_rps == 0")
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Network throughput: "
|
|
|
|
|
f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
|
|
|
|
|
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
|
|
|
|
|
f"{network_rps:.1f} RPS"
|
|
|
|
|
f"Network throughput: {network_rps:.1f} RPS "
|
|
|
|
|
f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
|
|
|
|
|
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
|
|
|
|
|
)
|
|
|
|
|
return network_rps
|
|
|
|
|
|
|
|
|
@ -168,7 +180,8 @@ def measure_compute_rps(
|
|
|
|
|
devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS"
|
|
|
|
|
f"Forward pass throughput: {device_rps:.1f} RPS per block "
|
|
|
|
|
f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})"
|
|
|
|
|
)
|
|
|
|
|
return device_rps
|
|
|
|
|
|
|
|
|
|