@@ -43,13 +43,13 @@ def get_server_throughput(
tensor_parallel_devices : Sequence [ torch . device ] ,
force_eval : bool = False ,
cache_dir : Optional [ str ] = None ,
) - > float :
) - > Dict [ str , float ] :
dtype = resolve_block_dtype ( config , dtype )
if cache_dir is None :
cache_dir = DEFAULT_CACHE_DIR
lock_path = Path ( cache_dir , " throughput.lock " )
cache_path = Path ( cache_dir , " throughput_v 3 .json" )
cache_path = Path ( cache_dir , " throughput_v 4 .json" )
# We use the system-wide lock since only one process at a time can measure the host throughput
os . makedirs ( lock_path . parent , exist_ok = True )
@@ -93,10 +93,12 @@ def get_server_throughput(
# Assuming the start block index is distributed uniformly, the average number of blocks used per request is
# E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
average_blocks_used = ( num_blocks + 1 ) / 2
throughput = throughput_info [ " compute _rps" ] / average_blocks_used
throughput = throughput_info [ " forward _rps" ] / average_blocks_used
throughput = min ( throughput , throughput_info . get ( " network_rps " , math . inf ) )
throughput_info [ " throughput " ] = throughput
logger . info ( f " Reporting throughput: { throughput : .1f } RPS for { num_blocks } blocks " )
return throughput
return throughput_info
def measure_throughput_info (
@@ -114,15 +116,31 @@ def measure_throughput_info(
)
throughput_info = {
" compute_rps " : measure_compute_rps (
config , device , dtype , quant_type = quant_type , tensor_parallel_devices = tensor_parallel_devices
)
" inference_rps " : measure_compute_rps (
config ,
device ,
dtype ,
quant_type = quant_type ,
tensor_parallel_devices = tensor_parallel_devices ,
n_tokens = 1 ,
n_steps = 100 ,
inference = True ,
) ,
" forward_rps " : measure_compute_rps (
config ,
device ,
dtype ,
quant_type = quant_type ,
tensor_parallel_devices = tensor_parallel_devices ,
n_tokens = 1024 ,
n_steps = 10 ,
inference = False ,
) ,
}
try :
throughput_info [ " network_rps " ] = measure_network_rps ( config )
except Exception as e :
logger . warning ( f " Failed to measure network throughput: { repr ( e ) } " )
logger . warning ( " Proceeding with the compute throughput only " )
logger . info ( f " Network throughput is not available: { e } " )
return throughput_info
@@ -135,6 +153,8 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Opt
process . terminate ( )
raise RuntimeError ( f " speedtest did not finish in { timeout } seconds " )
network_info = pipe_recv . recv ( )
if " exception " in network_info :
raise RuntimeError ( f " speedtest failed: { network_info [ ' exception ' ] } " )
bits_per_request = config . hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
network_rps = min ( network_info [ " download " ] , network_info [ " upload " ] ) / bits_per_request
@@ -150,12 +170,15 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Opt
def _measure_bits_per_second ( pipe_send : mp . Pipe ) :
s = speedtest . Speedtest ( )
s . get_servers ( )
s . get_best_server ( )
s . download ( )
s . upload ( )
pipe_send . send ( s . results . dict ( ) )
try :
s = speedtest . Speedtest ( )
s . get_servers ( )
s . get_best_server ( )
s . download ( )
s . upload ( )
pipe_send . send ( s . results . dict ( ) )
except Exception as e :
pipe_send . send ( { " exception " : repr ( e ) } )
def measure_compute_rps (
@@ -165,8 +188,9 @@ def measure_compute_rps(
* ,
quant_type : QuantType ,
tensor_parallel_devices : Sequence [ torch . device ] ,
n_tokens : int = 16 ,
n_steps : int = 500 ,
n_tokens : int ,
n_steps : int ,
inference : bool ,
) - > float :
if not tensor_parallel_devices :
tensor_parallel_devices = ( device , )
@@ -180,7 +204,7 @@ def measure_compute_rps(
dummy_input = torch . randn ( n_tokens , 1 , config . hidden_size , device = device , dtype = dtype )
start_time = time . perf_counter ( )
_ , cache = block . forward ( dummy_input , use_cache = True , layer_past = cache )
_ , cache = block . forward ( dummy_input , use_cache = True , layer_past = cache if inference else None )
if step > = 1 : # Skip the 1st step to exclude the initialization time
elapsed + = time . perf_counter ( ) - start_time
device_rps = n_steps * n_tokens / elapsed
@@ -191,8 +215,8 @@ def measure_compute_rps(
devices_repr = " , " . join ( f " { count } x { name } " for name , count in Counter ( device_names ) . most_common ( ) )
logger . info (
f " Forward pass throughput: { device_rps : .1f } RPS per block "
f " ( { devices_repr} , { get_dtype_name ( dtype , quant_type ) } ) "
f " { ' Inference ' if inference else ' Forward pass ' } throughput: { device_rps : .1f } RPS per block "
f " ( { n_tokens} tokens/batch, { devices_repr} , { get_dtype_name ( dtype , quant_type ) } ) "
)
return device_rps
@@ -202,7 +226,7 @@ def get_device_name(device: torch.device) -> str:
def get_dtype_name ( dtype : torch . dtype , quant_type : QuantType ) - > str :
name = str ( dtype )
name = str ( dtype ) . replace ( " torch. " , " " )
if quant_type != QuantType . NONE :
name + = f " , quantized to { quant_type . name . lower ( ) } "
return name