@ -1,7 +1,6 @@
import fcntl
import json
import os
import subprocess
import time
from hashlib import sha256
from pathlib import Path
@ -18,6 +17,19 @@ from petals.utils.disk_cache import DEFAULT_CACHE_DIR
logger = get_logger ( __file__ )
try :
import speedtest
except ImportError :
raise ImportError ( " Please `pip install speedtest-cli==2.1.3` " )
if not hasattr ( speedtest , " Speedtest " ) :
raise ImportError (
" You are using the wrong speedtest module. Please replace speedtest with speedtest-cli. \n "
" To do that, run `pip uninstall -y speedtest`. Depending on your python environment, "
" you may need to run uninstall speedtest two or more times, until it says ' not installed ' . \n "
" After that, please `pip install speedtest-cli==2.1.3` to install the correct version. "
)
def get_host_throughput (
config : BloomConfig ,
@ -88,10 +100,16 @@ def measure_throughput_info(
def measure_network_rps ( config : BloomConfig ) - > float :
proc = subprocess . run ( " python3 -m petals.cli.speed_test --json " , shell = True , capture_output = True )
if proc . returncode != 0 :
raise RuntimeError ( f " Failed to measure network throughput (stdout: { proc . stdout } , stderr: { proc . stderr } ) " )
network_info = json . loads ( proc . stdout )
try :
s = speedtest . Speedtest ( )
s . get_servers ( )
s . get_best_server ( )
s . download ( )
s . upload ( )
network_info = s . results . dict ( )
except :
logger . error ( " Failed to measure network throughput: " )
raise
bits_per_request = config . hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
network_rps = min ( network_info [ " download " ] , network_info [ " upload " ] ) / bits_per_request