|
|
|
@ -101,25 +101,25 @@ def measure_throughput_info(
|
|
|
|
|
logger.info(
|
|
|
|
|
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
|
|
|
|
|
)
|
|
|
|
|
return min(
|
|
|
|
|
measure_network_rps(config),
|
|
|
|
|
measure_compute_rps(
|
|
|
|
|
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def measure_network_rps(config: BloomConfig) -> float:
|
|
|
|
|
result = measure_compute_rps(
|
|
|
|
|
config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
s = speedtest.Speedtest()
|
|
|
|
|
s.get_servers()
|
|
|
|
|
s.get_best_server()
|
|
|
|
|
s.download()
|
|
|
|
|
s.upload()
|
|
|
|
|
network_info = s.results.dict()
|
|
|
|
|
except:
|
|
|
|
|
logger.error("Failed to measure network throughput:")
|
|
|
|
|
raise
|
|
|
|
|
result = min(result, measure_network_rps(config))
|
|
|
|
|
except Exception:
|
|
|
|
|
logger.warning("Failed to measure network throughput:", exc_info=True)
|
|
|
|
|
logger.warning("Proceeding with the compute throughput only")
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def measure_network_rps(config: BloomConfig) -> Optional[float]:
|
|
|
|
|
s = speedtest.Speedtest()
|
|
|
|
|
s.get_servers()
|
|
|
|
|
s.get_best_server()
|
|
|
|
|
s.download()
|
|
|
|
|
s.upload()
|
|
|
|
|
network_info = s.results.dict()
|
|
|
|
|
|
|
|
|
|
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
|
|
|
|
|
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
|
|
|
|
|