From 55eb36ef4829d61315361049edf85753c631dab8 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 9 Aug 2023 21:59:56 +0300 Subject: [PATCH] Fix missing torch.cuda.synchronize for computing throughput (#456) --- src/petals/server/throughput.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index d977611..2806183 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -51,7 +51,7 @@ def get_server_throughput( if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR lock_path = Path(cache_dir, "throughput.lock") - cache_path = Path(cache_dir, "throughput_v4.json") + cache_path = Path(cache_dir, "throughput_v5.json") # We use the system-wide lock since only one process at a time can measure the host throughput os.makedirs(lock_path.parent, exist_ok=True) @@ -196,6 +196,7 @@ def measure_compute_rps( n_steps: int, inference: bool, ) -> float: + device = torch.device(device) if not tensor_parallel_devices: tensor_parallel_devices = (device,) with torch.inference_mode(): @@ -204,13 +205,17 @@ def measure_compute_rps( cache = None elapsed = 0 - for step in range(n_steps + 1): - dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype) + dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype) + _, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time + if device.type == "cuda": + torch.cuda.synchronize(device) - start_time = time.perf_counter() + start_time = time.perf_counter() + for step in range(n_steps): _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None) - if step >= 1: # Skip the 1st step to exclude the initialization time - elapsed += time.perf_counter() - start_time + if device.type == "cuda": + torch.cuda.synchronize(device) + elapsed = time.perf_counter() - start_time device_rps = n_steps * n_tokens / elapsed devices_repr = get_device_name(device)