@@ -16,6 +16,7 @@ from transformers import PretrainedConfig
 from petals.server.block_utils import resolve_block_dtype, get_model_block
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.misc import DUMMY_KEY_PAST
 
 logger = get_logger(__name__)
 
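The new import is the only change in this hunk: `measure_compute_rps` below seeds its attention cache with `DUMMY_KEY_PAST` instead of `None`. As a rough sketch of what that placeholder is (the real definition lives in `petals/utils/misc.py`; the exact shape here is an assumption for illustration), it is just an empty key/value tensor that marks "no past tokens yet":

```python
import torch

# Hedged sketch of the DUMMY_KEY_PAST placeholder used by the hunk below.
# Assumed shape: batch of 1, zero past tokens -- check petals/utils/misc.py
# for the actual definition.
DUMMY_KEY_PAST = torch.empty((1, 0, 0))

def is_dummy(tensor: torch.Tensor) -> bool:
    # A zero-element tensor signals "no real cache yet".
    return tensor.numel() == 0

assert is_dummy(DUMMY_KEY_PAST)
```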
@@ -205,15 +206,21 @@ def measure_compute_rps(
         block = block.to(dtype)
         block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-        cache = None
+        cache = (DUMMY_KEY_PAST, DUMMY_KEY_PAST)
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
-        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+
+        def step(cache_):
+            outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
+            return outputs[1] if inference else None
+
+        cache = step(cache)
+        # Skip the 1st step to exclude the initialization time
         synchronize(device)
 
         start_time = time.perf_counter()
         for _ in range(n_steps):
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
+            cache = step(cache)
         synchronize(device)
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
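This hunk routes both the warm-up call and the timed loop through a single `step()` closure, so the benchmark exercises the same code path whether it measures cached inference (`use_cache=True` with a growing `layer_past`) or plain forward passes (`inference=False`, no cache). If it helps to see the timing pattern in isolation, here is a minimal self-contained sketch; the names `benchmark_rps` and `step_fn` are hypothetical, not part of the Petals API:

```python
import time
import torch

def benchmark_rps(step_fn, device: torch.device, n_tokens: int, n_steps: int = 10) -> float:
    """Warm-up + synchronize + perf_counter pattern, mirroring measure_compute_rps."""
    step_fn()  # 1st step is discarded: it absorbs one-time costs (CUDA context, autotuning)
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # drain queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # wait for the timed kernels to actually finish
    elapsed = time.perf_counter() - start
    return n_steps * n_tokens / elapsed  # tokens processed per second
```

Without the synchronization barriers, CUDA's asynchronous execution would let `perf_counter()` stop the clock while kernels are still running, inflating the reported requests-per-second figure.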