Fix missing torch.cuda.synchronize for computing throughput (#456)

justheuristic authored 9 months ago; committed by GitHub
parent 0e7189b3ed
commit 55eb36ef48

@@ -51,7 +51,7 @@ def get_server_throughput(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v4.json")
+    cache_path = Path(cache_dir, "throughput_v5.json")
 
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
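
The rename from throughput_v4.json to throughput_v5.json bumps the cache version: throughput numbers measured before this fix were timed without CUDA synchronization and are likely overestimated, so they should not be reused. Pointing the code at a new file name means the stale results are simply never read again and the benchmark runs once more. Below is a minimal sketch of that read-or-remeasure pattern; the helper name and the stand-in benchmark value are illustrative, not the repository's code.

import json
from pathlib import Path

def load_or_measure(cache_path: Path, measure) -> dict:
    # Reuse the versioned cache if it exists; otherwise benchmark once and persist the result.
    if cache_path.exists():
        return json.loads(cache_path.read_text())
    info = measure()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    cache_path.write_text(json.dumps(info))
    return info

# Reading "throughput_v5.json" leaves any stale "throughput_v4.json" untouched and unread.
info = load_or_measure(Path("cache_dir") / "throughput_v5.json", lambda: {"compute_rps": 123.0})
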
@@ -196,6 +196,7 @@ def measure_compute_rps(
     n_steps: int,
     inference: bool,
 ) -> float:
+    device = torch.device(device)
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
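
The added device = torch.device(device) line normalizes the argument: callers may pass a plain string such as "cuda:0", and the device.type == "cuda" checks introduced below only exist on a torch.device object. A short illustration (constructing the device object does not require a GPU; passing an existing torch.device returns an equivalent device):

import torch

device = torch.device("cuda:0")  # also accepts an existing torch.device unchanged
assert device.type == "cuda"     # the attribute used to guard torch.cuda.synchronize(device)
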
@@ -204,13 +205,17 @@ def measure_compute_rps(
         cache = None
         elapsed = 0
-        for step in range(n_steps + 1):
-            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
+        dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
+        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+        if device.type == "cuda":
+            torch.cuda.synchronize(device)
 
-            start_time = time.perf_counter()
+        start_time = time.perf_counter()
+        for step in range(n_steps):
             _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-            if step >= 1:  # Skip the 1st step to exclude the initialization time
-                elapsed += time.perf_counter() - start_time
+        if device.type == "cuda":
+            torch.cuda.synchronize(device)
+        elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
 
     devices_repr = get_device_name(device)
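
For context on the fix itself: CUDA kernels are launched asynchronously, so block.forward() returns as soon as the work is queued, and a bare time.perf_counter() difference mostly measures kernel-launch overhead rather than compute time. Calling torch.cuda.synchronize(device) before reading the clock forces the host to wait until all queued kernels have finished. The sketch below shows the same timing discipline the patch adopts (warm up, synchronize, time n_steps forward passes, synchronize again); the function and argument names are illustrative, not part of the repository.

import time
import torch

def timed_rps(block, dummy_input, n_steps: int, device: torch.device) -> float:
    # Warm-up pass: excludes one-time initialization (CUDA context, cuDNN autotuning, allocator growth).
    block(dummy_input)
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # wait for the warm-up kernels before starting the clock

    start_time = time.perf_counter()
    for _ in range(n_steps):
        block(dummy_input)
    if device.type == "cuda":
        # Without this barrier the loop only measures how fast kernels are queued,
        # which inflates the reported tokens-per-second figure.
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start_time
    return n_steps * dummy_input.shape[1] / elapsed  # tokens processed per second

# Example usage with a stand-in module (shapes chosen arbitrarily):
layer = torch.nn.Linear(1024, 1024)
x = torch.randn(1, 16, 1024)
print(f"{timed_rps(layer, x, n_steps=10, device=torch.device('cpu')):.1f} tokens/s")
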
