Ignore network RPS if we failed to measure it (#198)

pull/197/head^2
Alexander Borzunov 1 year ago committed by GitHub
parent 487411e87e
commit 127cf66bee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -101,25 +101,25 @@ def measure_throughput_info(
logger.info(
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
)
return min(
measure_network_rps(config),
measure_compute_rps(
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
),
)
def measure_network_rps(config: BloomConfig) -> float:
result = measure_compute_rps(
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
)
try:
s = speedtest.Speedtest()
s.get_servers()
s.get_best_server()
s.download()
s.upload()
network_info = s.results.dict()
except:
logger.error("Failed to measure network throughput:")
raise
result = min(result, measure_network_rps(config))
except Exception:
logger.warning("Failed to measure network throughput:", exc_info=True)
logger.warning("Proceeding with the compute throughput only")
return result
def measure_network_rps(config: BloomConfig) -> Optional[float]:
s = speedtest.Speedtest()
s.get_servers()
s.get_best_server()
s.download()
s.upload()
network_info = s.results.dict()
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request

Loading…
Cancel
Save