|
|
|
@ -2,11 +2,10 @@ import fcntl
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import subprocess
|
|
|
|
|
import tempfile
|
|
|
|
|
import time
|
|
|
|
|
from hashlib import sha256
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Union
|
|
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
@ -15,15 +14,12 @@ from petals.bloom.block import BloomBlock
|
|
|
|
|
from petals.bloom.model import BloomConfig
|
|
|
|
|
from petals.bloom.ops import build_alibi_tensor
|
|
|
|
|
from petals.utils.convert_8bit import replace_8bit_linear
|
|
|
|
|
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
|
|
|
|
|
|
|
|
|
|
# Select hivemind's log-handler mode at import time, before the module logger below is created.
# NOTE(review): the exact effect of "in_root_logger" is defined by hivemind.utils.logging — confirm there.
use_hivemind_log_handler("in_root_logger")
|
|
|
|
|
# Module-level logger obtained from hivemind's get_logger, keyed by this file's path.
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default value for the `cache_path` argument: <home>/.cache/petals/throughput_v2.json
DEFAULT_CACHE_PATH = Path.home() / ".cache" / "petals" / "throughput_v2.json"
|
|
|
|
|
# Default value for the `lock_path` argument: <system temp dir>/petals/throughput.lock
DEFAULT_LOCK_PATH = Path(tempfile.gettempdir()) / "petals" / "throughput.lock"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_host_throughput(
|
|
|
|
|
config: BloomConfig,
|
|
|
|
|
device: torch.device,
|
|
|
|
@ -31,8 +27,7 @@ def get_host_throughput(
|
|
|
|
|
*,
|
|
|
|
|
load_in_8bit: bool,
|
|
|
|
|
force_eval: bool = False,
|
|
|
|
|
cache_path: str = DEFAULT_CACHE_PATH,
|
|
|
|
|
lock_path: str = DEFAULT_LOCK_PATH,
|
|
|
|
|
cache_dir: Optional[str] = None,
|
|
|
|
|
) -> float:
|
|
|
|
|
# Resolve default dtypes
|
|
|
|
|
if dtype == "auto" or dtype is None:
|
|
|
|
@ -40,6 +35,11 @@ def get_host_throughput(
|
|
|
|
|
if dtype == "auto" or dtype is None:
|
|
|
|
|
dtype = torch.float32
|
|
|
|
|
|
|
|
|
|
if cache_dir is None:
|
|
|
|
|
cache_dir = DEFAULT_CACHE_DIR
|
|
|
|
|
lock_path = Path(cache_dir, "throughput.lock")
|
|
|
|
|
cache_path = Path(cache_dir, "throughput_v2.json")
|
|
|
|
|
|
|
|
|
|
# We use the system-wide lock since only one process at a time can measure the host throughput
|
|
|
|
|
os.makedirs(lock_path.parent, exist_ok=True)
|
|
|
|
|
with open(lock_path, "wb") as lock_fd:
|
|
|
|
|