@ -13,9 +13,10 @@ import torch.mps
from hivemind . utils . logging import get_logger
from hivemind . utils . logging import get_logger
from transformers import PretrainedConfig
from transformers import PretrainedConfig
from petals . server . block_utils import resolve_block_dtype
from petals . server . block_utils import get_model_block, resolve_block_dtype
from petals . utils . convert_block import QuantType , convert_block
from petals . utils . convert_block import QuantType , convert_block
from petals . utils . disk_cache import DEFAULT_CACHE_DIR
from petals . utils . disk_cache import DEFAULT_CACHE_DIR
from petals . utils . misc import DUMMY_KEY_PAST
logger = get_logger ( __name__ )
logger = get_logger ( __name__ )
@ -201,18 +202,25 @@ def measure_compute_rps(
if not tensor_parallel_devices :
if not tensor_parallel_devices :
tensor_parallel_devices = ( device , )
tensor_parallel_devices = ( device , )
with torch . inference_mode ( ) :
with torch . inference_mode ( ) :
block = config . block_class ( config ) . to ( dtype )
block = get_model_block ( config )
block = block . to ( dtype )
block = convert_block ( block , 0 , config , tensor_parallel_devices , device , quant_type = quant_type , freeze = True )
block = convert_block ( block , 0 , config , tensor_parallel_devices , device , quant_type = quant_type , freeze = True )
cache = None
cache = ( DUMMY_KEY_PAST . to ( dtype ) , DUMMY_KEY_PAST . to ( dtype ) )
elapsed = 0
elapsed = 0
dummy_input = torch . randn ( 1 , n_tokens , config . hidden_size , device = device , dtype = dtype )
dummy_input = torch . randn ( 1 , n_tokens , config . hidden_size , device = device , dtype = dtype )
_ , cache = block . forward ( dummy_input , use_cache = True ) # Skip the 1st step to exclude the initialization time
# Skip the 1st step to exclude the initialization time
def step ( cache_ ) :
outputs = block . forward ( dummy_input , use_cache = inference , layer_past = cache_ if inference else None )
return outputs [ 1 ] if inference else None
cache = step ( cache )
synchronize ( device )
synchronize ( device )
start_time = time . perf_counter ( )
start_time = time . perf_counter ( )
for _ in range ( n_steps ) :
for _ in range ( n_steps ) :
_ , cache = block . forward ( dummy_input , use_cache = True , layer_past = cache if inference else None )
cache = step ( cach e)
synchronize ( device )
synchronize ( device )
elapsed = time . perf_counter ( ) - start_time
elapsed = time . perf_counter ( ) - start_time
device_rps = n_steps * n_tokens / elapsed
device_rps = n_steps * n_tokens / elapsed