From b2ab84cc3347575dae4d43224caa127b3a53736c Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Sun, 3 Sep 2023 00:41:03 +0300
Subject: [PATCH] Add dry_run option to --throughput

---
 src/petals/cli/run_server.py |  5 +++--
 src/petals/server/server.py  | 11 ++++++++---
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 3728c16..94f5c2e 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -106,12 +106,13 @@
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")

     parser.add_argument('--throughput',
-                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
+                        type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
                         default='auto',
                         help='Expected server throughput (a float measured in RPS). '
                              'If set to "auto" (default), the script evaluates network and compute throughput '
                              'on the first run and uses these estimates for future runs. '
-                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
+                             'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
+                             'If set to "dry_run", the script re-evaluates the throughput and exits.')
     parser.add_argument('--update_period', type=float, required=False, default=120,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index ab646a5..a5f2ba0 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -5,6 +5,7 @@ import math
 import multiprocessing as mp
 import os
 import random
+import sys
 import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
@@ -234,8 +235,9 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")

-        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-        if throughput in ["auto", "eval"]:
+        assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
+        if throughput in ["auto", "eval", "dry_run"]:
+            force_eval = throughput in ["eval", "dry_run"]
             throughput_info = get_server_throughput(
                 converted_model_name_or_path,
                 self.block_config,
@@ -245,9 +247,12 @@
                 quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 reachable_via_relay=reachable_via_relay,
-                force_eval=(throughput == "eval"),
+                force_eval=force_eval,
                 cache_dir=cache_dir,
             )
+            if throughput == "dry_run":
+                logger.info("Finished estimating throughput, exiting")
+                sys.exit(0)
         else:
             throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
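
Note (not part of the patch): a minimal, self-contained sketch of how the updated type= lambda from run_server.py handles the new value. The parser below is a stripped-down stand-in for the real CLI, kept only to illustrate that the keywords 'auto', 'eval', and 'dry_run' pass through as strings while any other value must parse as a float (RPS).

import argparse

# Stand-in parser reproducing only the --throughput argument from run_server.py
parser = argparse.ArgumentParser()
parser.add_argument('--throughput',
                    type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
                    default='auto')

print(parser.parse_args(['--throughput', 'dry_run']).throughput)  # -> 'dry_run' (kept as a string)
print(parser.parse_args(['--throughput', '12.5']).throughput)     # -> 12.5 (converted to float)
print(parser.parse_args([]).throughput)                           # -> 'auto' (default)

In practice, the new value would be passed to the server launch command (the petals.cli.run_server entry point) the same way as 'auto' or 'eval'; with 'dry_run', the server measures and caches throughput, then exits via sys.exit(0) instead of joining the swarm.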