Switch to speedtest-cli (#157)

This pull request removes the custom speed_test code in favour of the speedtest-cli module.
This is necessary to ensure that random warnings / print-outs do not mess with our outputs.

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
pull/160/head
justheuristic 1 year ago committed by GitHub
parent 34644f13e1
commit 91898c3c90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -37,6 +37,7 @@ install_requires =
huggingface-hub==0.11.1
transformers==4.25.1
protobuf>=3.20.3,<4.0dev
speedtest-cli==2.1.3
hivemind==1.1.3
humanfriendly
async-timeout>=4.0.2

File diff suppressed because it is too large.

@ -1,7 +1,6 @@
import fcntl
import json
import os
import subprocess
import time
from hashlib import sha256
from pathlib import Path
@ -18,6 +17,19 @@ from petals.utils.disk_cache import DEFAULT_CACHE_DIR
logger = get_logger(__file__)
# Import guard: `speedtest` here must resolve to the speedtest-cli package, not the
# unrelated `speedtest` distribution that shares the same module name on PyPI.
try:
    import speedtest
except ImportError:
    raise ImportError("Please `pip install speedtest-cli==2.1.3`")

# speedtest-cli exposes a `Speedtest` class; its absence indicates the wrong package
# is installed, so fail fast with actionable uninstall/reinstall instructions
# instead of raising a confusing AttributeError later at measurement time.
if not hasattr(speedtest, "Speedtest"):
    raise ImportError(
        "You are using the wrong speedtest module. Please replace speedtest with speedtest-cli.\n"
        "To do that, run `pip uninstall -y speedtest`. Depending on your python environment, "
        "you may need to run uninstall speedtest two or more times, until it says 'not installed'.\n"
        "After that, please `pip install speedtest-cli==2.1.3` to install the correct version."
    )
def get_host_throughput(
config: BloomConfig,
@ -88,10 +100,16 @@ def measure_throughput_info(
def measure_network_rps(config: BloomConfig) -> float:
proc = subprocess.run("python3 -m petals.cli.speed_test --json", shell=True, capture_output=True)
if proc.returncode != 0:
raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
network_info = json.loads(proc.stdout)
try:
s = speedtest.Speedtest()
s.get_servers()
s.get_best_server()
s.download()
s.upload()
network_info = s.results.dict()
except:
logger.error("Failed to measure network throughput:")
raise
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request

@ -3,13 +3,15 @@ import torch
from test_utils import MODEL_NAME
from petals.client import DistributedBloomConfig
from petals.server.throughput import measure_compute_rps
from petals.server.throughput import measure_compute_rps, measure_network_rps
@pytest.mark.forked
def test_throughput_basic():
    """Smoke test: compute and network throughput measurements return positive floats.

    Note: this span contained rendered-diff residue (the pre-change `throughput = ...`
    assignment and its assert interleaved with the new lines), which left the function
    syntactically invalid; this is the reconstructed post-change version.
    """
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
    compute_rps = measure_compute_rps(
        config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
    )
    assert isinstance(compute_rps, float) and compute_rps > 0
    network_rps = measure_network_rps(config)
    assert isinstance(network_rps, float) and network_rps > 0

Loading…
Cancel
Save