|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
import fcntl
|
|
|
|
|
import json
|
|
|
|
|
import math
|
|
|
|
|
import multiprocessing as mp
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
from collections import Counter
|
|
|
|
@ -120,24 +121,26 @@ def measure_throughput_info(
|
|
|
|
|
}
|
|
|
|
|
try:
|
|
|
|
|
throughput_info["network_rps"] = measure_network_rps(config)
|
|
|
|
|
except Exception:
|
|
|
|
|
logger.warning("Failed to measure network throughput:", exc_info=True)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"Failed to measure network throughput: {repr(e)}")
|
|
|
|
|
logger.warning("Proceeding with the compute throughput only")
|
|
|
|
|
return throughput_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def measure_network_rps(config: BloomConfig) -> Optional[float]:
|
|
|
|
|
s = speedtest.Speedtest()
|
|
|
|
|
s.get_servers()
|
|
|
|
|
s.get_best_server()
|
|
|
|
|
s.download()
|
|
|
|
|
s.upload()
|
|
|
|
|
network_info = s.results.dict()
|
|
|
|
|
def measure_network_rps(config: BloomConfig, *, timeout: float = 60) -> Optional[float]:
|
|
|
|
|
pipe_recv, pipe_send = mp.Pipe(duplex=False)
|
|
|
|
|
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
|
|
|
|
|
process.start()
|
|
|
|
|
|
|
|
|
|
if not pipe_recv.poll(timeout):
|
|
|
|
|
process.terminate()
|
|
|
|
|
raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
|
|
|
|
|
network_info = pipe_recv.recv()
|
|
|
|
|
|
|
|
|
|
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
|
|
|
|
|
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
|
|
|
|
|
if network_rps == 0:
|
|
|
|
|
raise ValueError("speedtest has returned network_rps == 0")
|
|
|
|
|
raise RuntimeError("speedtest has returned network_rps == 0")
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Network throughput: {network_rps:.1f} RPS "
|
|
|
|
@ -147,6 +150,15 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
|
|
|
|
|
return network_rps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _measure_bits_per_second(pipe_send: "mp.connection.Connection"):
    """Run a speedtest and send its results dictionary through ``pipe_send``.

    Intended to be used as the target of a separate process, so that the
    caller can give up on a measurement that hangs.

    Args:
        pipe_send: the writable end of a pipe created with
            ``mp.Pipe(duplex=False)``. Exactly one object is sent through it:
            the dictionary returned by ``speedtest``'s ``results.dict()``.
            The visible caller reads its "download" and "upload" entries
            (presumably in bits per second -- confirm against speedtest docs).

    NOTE: if any speedtest step raises, nothing is sent; the exception
    propagates in this process and the receiving side only learns about the
    failure through its own timeout.
    """
    try:
        tester = speedtest.Speedtest()
        # Server discovery must happen before the transfer measurements.
        tester.get_servers()
        tester.get_best_server()
        tester.download()
        tester.upload()
        pipe_send.send(tester.results.dict())
    finally:
        # Release this process's handle on the pipe whether or not we succeeded.
        pipe_send.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def measure_compute_rps(
|
|
|
|
|
config: BloomConfig,
|
|
|
|
|
device: torch.device,
|
|
|
|
|