Measure and cache network & compute throughput (#21)

pull/22/head
Alexander Borzunov authored 2 years ago, committed by GitHub
parent ac7df18dfa
commit 75856e4769

@@ -41,8 +41,13 @@ def main():
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--throughput', type=float, default=1.0,
help='Expected server throughput')
parser.add_argument('--throughput',
type=lambda value: value if value in ['auto', 'eval'] else float(value),
default='auto',
help='Expected server throughput (a float measured in RPS). '
'If set to "auto" (default), the script evaluates network and compute throughput '
'on the first run and uses these estimates for future runs. '
'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
parser.add_argument('--update_period', type=float, required=False, default=30,
help='Server will report experts to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
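
As an aside, a minimal sketch of how this argument converter behaves (illustrative only, not part of the diff):

import argparse

# Sketch: the converter passes 'auto'/'eval' through unchanged and parses anything else as a float.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--throughput',
    type=lambda value: value if value in ['auto', 'eval'] else float(value),
    default='auto',
)

assert parser.parse_args([]).throughput == 'auto'                        # default
assert parser.parse_args(['--throughput', 'eval']).throughput == 'eval'
assert parser.parse_args(['--throughput', '3.5']).throughput == 3.5      # explicit RPS value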

File diff suppressed because it is too large

@@ -2,4 +2,5 @@ from src.bloom import *
from src.client import *
from src.dht_utils import declare_active_modules, get_remote_module
project_name = "bloomd"
__version__ = "0.2"

@@ -3,6 +3,7 @@ Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
@@ -134,11 +135,10 @@ async def _get_remote_module_infos(
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
server_info = server_info.value
if not (isinstance(server_info, tuple) and len(server_info) == 2 and
isinstance(server_info[0], int) and isinstance(server_info[1], float)):
raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
state, throughput = server_info
state, throughput = server_info.value
if not (isinstance(state, int) and isinstance(throughput, float) and
math.isfinite(throughput) and throughput >= 0.0):
raise ValueError(f"Invalid server info: {server_info}")
servers[peer_id] = ServerInfo(ServerState(state), throughput)
except (TypeError, ValueError) as e:
logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")

@@ -1,8 +1,10 @@
from __future__ import annotations
import multiprocessing as mp
import random
import threading
from typing import Dict, Optional, Sequence, Union
import time
from typing import Dict, Literal, Optional, Sequence, Union
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
@@ -19,6 +21,7 @@ from src.server.backend import TransformerBackend
from src.server.block_selection import choose_best_blocks
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
from src.server.throughput import get_host_throughput
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -95,7 +98,7 @@ class Server(threading.Thread):
cls,
prefix: Optional[str],
converted_model_name_or_path: str,
throughput: float,
throughput: Union[float, Literal['auto', 'eval']],
num_blocks: Optional[int] = None,
block_indices: Optional[str] = None,
num_handlers: Optional[int] = None,
@@ -103,13 +106,14 @@
max_batch_size: int = 4096,
torch_dtype: str = "auto",
cache_size_bytes: Optional[int] = None,
device: Union[str, torch.device] = None,
device: Optional[Union[str, torch.device]] = None,
initial_peers: Sequence[str] = (),
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
custom_module_path=None,
update_period: float = 30,
expiration: Optional[float] = None,
max_block_selection_delay: float = 1,
use_auth_token: Optional[str] = None,
*,
start: bool,
@@ -136,6 +140,10 @@
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
memory_cache = MemoryCache(device, cache_size_bytes)
assert isinstance(throughput, float) or throughput in ['auto', 'eval']
if throughput in ['auto', 'eval']:
throughput = get_host_throughput(device, force_eval=(throughput == 'eval'))
if isinstance(torch_dtype, str):
torch_dtype = DTYPE_MAP[torch_dtype]
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
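
Loosely, the throughput resolution added to Server.create reads as the following sketch (get_host_throughput is the helper added in src/server/throughput.py below; the wrapper function itself is illustrative):

from typing import Literal, Union

import torch

from src.server.throughput import get_host_throughput

def resolve_throughput(throughput: Union[float, Literal['auto', 'eval']],
                       device: Union[str, torch.device]) -> float:
    # Sketch: floats are used as-is; 'auto' reuses a cached measurement if one exists,
    # while 'eval' forces a fresh measurement and overwrites the cache.
    assert isinstance(throughput, float) or throughput in ['auto', 'eval']
    if throughput in ['auto', 'eval']:
        throughput = get_host_throughput(device, force_eval=(throughput == 'eval'))
    return throughput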
@@ -153,6 +161,10 @@
raise
block_indices = range(first_block_index, last_block_index)
else:
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
time.sleep(random.random() * max_block_selection_delay)
assert num_blocks is not None
uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
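
The jitter added above is easiest to see in isolation (a sketch; per the new Server.create parameter, max_block_selection_delay defaults to 1 second):

import random
import time

max_block_selection_delay = 1.0  # seconds, the new Server.create parameter

# Sketch: each server sleeps for a random fraction of the delay before selecting
# blocks, so servers launched at the same moment are unlikely to read the DHT
# simultaneously and pick the same blocks to serve.
time.sleep(random.random() * max_block_selection_delay)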

@@ -0,0 +1,126 @@
import fcntl
import json
import os
import subprocess
import tempfile
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, Union
import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import project_name
from src.bloom.block import BloomBlock
from src.bloom.model import BloomConfig
from src.bloom.ops import build_alibi_tensor
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
DEFAULT_CACHE_PATH = Path(Path.home(), '.cache', project_name, 'throughput.json')
DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, 'throughput.lock')
SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], 'cli', 'speed_test.py')
@dataclass
class ThroughputInfo:
network_rps: float
device_rps: Dict[str, float]
def get_host_throughput(
device: Union[str, torch.device],
force_eval: bool = False,
cache_path: str = DEFAULT_CACHE_PATH,
lock_path: str = DEFAULT_LOCK_PATH,
) -> float:
# We only keep the device type, assuming that the throughput is similar among all of the host's GPUs
device = torch.device(device).type
# We use the system-wide lock since only one process at a time can measure the host throughput
os.makedirs(lock_path.parent, exist_ok=True)
with open(lock_path, 'wb') as lock_fd:
logger.info("Loading throughput info")
fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
# The OS will release the lock when lock_fd is closed or the process is killed
info = None
try:
if not force_eval and os.path.exists(cache_path):
with open(cache_path) as cache_fd:
info = ThroughputInfo(**json.load(cache_fd))
if device not in info.device_rps:
force_eval = True
except Exception:
logger.exception(f"Failed to read throughput info from {cache_path}")
force_eval = True
if force_eval or info is None:
info = measure_throughput_info()
try:
os.makedirs(cache_path.parent, exist_ok=True)
with open(cache_path, 'w') as cache_fd:
json.dump(asdict(info), cache_fd)
except Exception:
logger.exception(f"Failed to save throughput info in {cache_path}")
throughput = min(info.network_rps, info.device_rps[device])
return throughput
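
For reference, a sketch of what the cache written above ends up holding and how the returned value is derived (keys mirror the ThroughputInfo dataclass; the numbers are invented):

import json

# Invented example of the contents of ~/.cache/bloomd/throughput.json:
cached = {'network_rps': 730.0, 'device_rps': {'cpu': 9.4, 'cuda': 410.0}}
print(json.dumps(cached))

# get_host_throughput() returns the bottleneck for the requested device type:
print(min(cached['network_rps'], cached['device_rps']['cuda']))  # 410.0 (limited by GPU compute)
print(min(cached['network_rps'], cached['device_rps']['cpu']))   # 9.4 (limited by CPU compute)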
def measure_throughput_info() -> ThroughputInfo:
logger.info("Measuring network, CPU, and GPU throughput. "
"This takes about a minute and will be cached for future runs")
# We measure throughput in "(inference) requests per second" (RPS) using a fixed model
config = BloomConfig.from_pretrained('bigscience/test-bloomd-6b3')
network_rps = measure_network_rps(config)
device_rps = {'cpu': measure_device_rps('cpu', config)}
if torch.cuda.is_available():
device_rps['cuda'] = measure_device_rps('cuda', config)
return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)
def measure_network_rps(config: BloomConfig) -> float:
proc = subprocess.run([SPEED_TEST_PATH, '--json'], capture_output=True)
if proc.returncode != 0:
raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
network_info = json.loads(proc.stdout)
bits_per_request = config.hidden_size * 32
network_rps = min(network_info['download'], network_info['upload']) / bits_per_request
logger.info(
f"Network throughput: "
f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
f"{network_rps:.2f} RPS"
)
return network_rps
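
To make the conversion above concrete: each inference request transfers one hidden state of config.hidden_size float32 values, hence hidden_size * 32 bits per request. A back-of-the-envelope example (assuming hidden_size = 4096 and a symmetric 100 Mbit/s link; not part of the diff):

# Rough arithmetic behind measure_network_rps (assumed numbers, not measurements):
hidden_size = 4096                      # assumption for illustration
bits_per_request = hidden_size * 32     # one float32 hidden state per request
download = upload = 100e6               # speed_test.py reports bits per second

network_rps = min(download, upload) / bits_per_request
print(bits_per_request, round(network_rps, 1))  # 131072 762.9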
def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n_steps: int = 500) -> float:
with torch.inference_mode():
block = BloomBlock(config, layer_index).to(device)
cache = None
elapsed = 0
for i in range(n_steps):
dummy_input = torch.randn(1, 1, config.hidden_size, device=device)
alibi = build_alibi_tensor(i + 1, config.num_attention_heads, dtype=torch.float32, device=device)
start_time = time.perf_counter()
_, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
elapsed += time.perf_counter() - start_time
device_rps = n_steps / elapsed
device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == 'cuda' else 'CPU'
logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
return device_rps
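
The loop above times single-token forward passes with a growing attention cache and averages only the forward time. A stripped-down version of the same idea with a generic torch module instead of BloomBlock (illustrative only; for CUDA timing you would also want torch.cuda.synchronize() around the timed region):

import time

import torch

def measure_module_rps(module: torch.nn.Module, hidden_size: int,
                       device: str = 'cpu', n_steps: int = 100) -> float:
    # Sketch: time n_steps single-token forward passes and report passes per second.
    module = module.to(device)
    elapsed = 0
    with torch.inference_mode():
        for _ in range(n_steps):
            dummy_input = torch.randn(1, 1, hidden_size, device=device)
            start_time = time.perf_counter()
            module(dummy_input)
            elapsed += time.perf_counter() - start_time
    return n_steps / elapsed

print(f"{measure_module_rps(torch.nn.Linear(1024, 1024), hidden_size=1024):.2f} RPS")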