Add dry_run option to --throughput

no_qkv_merge
Max Ryabinin 9 months ago
parent c666a975d0
commit b2ab84cc33

@@ -106,12 +106,13 @@ def main():
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
     parser.add_argument('--throughput',
-                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
+                        type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
                         default='auto',
                         help='Expected server throughput (a float measured in RPS). '
                              'If set to "auto" (default), the script evaluates network and compute throughput '
                              'on the first run and uses these estimates for future runs. '
-                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
+                             'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
+                             'If set to "dry_run", the script re-evaluates the throughput and exits.')
     parser.add_argument('--update_period', type=float, required=False, default=120,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
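
For reference, a minimal standalone sketch (not part of the commit) of how the updated --throughput argument parses: 'auto', 'eval', and 'dry_run' pass through as strings, while anything else must be a float in RPS. The wrapper around the argument is illustrative; only the lambda mirrors the hunk above.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--throughput',
                    type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
                    default='auto')

print(parser.parse_args(['--throughput', 'dry_run']).throughput)  # 'dry_run' (kept as a string)
print(parser.parse_args(['--throughput', '3.5']).throughput)      # 3.5 (parsed as a float)
print(parser.parse_args([]).throughput)                           # 'auto' (the default)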

@@ -5,6 +5,7 @@ import math
 import multiprocessing as mp
 import os
 import random
+import sys
 import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
@@ -234,8 +235,9 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")

-        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-        if throughput in ["auto", "eval"]:
+        assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
+        if throughput in ["auto", "eval", "dry_run"]:
+            force_eval = throughput in ["eval", "dry_run"]
             throughput_info = get_server_throughput(
                 converted_model_name_or_path,
                 self.block_config,
@@ -245,9 +247,12 @@
                 quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 reachable_via_relay=reachable_via_relay,
-                force_eval=(throughput == "eval"),
+                force_eval=force_eval,
                 cache_dir=cache_dir,
             )
+            if throughput == "dry_run":
+                logger.info("Finished estimating throughput, exiting")
+                sys.exit(0)
         else:
             throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
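
Taken together, the server-side hunks reduce to the control flow sketched below (a simplified, standalone approximation; estimate_throughput is a hypothetical stand-in for the get_server_throughput call and its arguments).

import sys

def resolve_throughput(throughput, estimate_throughput):
    # Either an explicit RPS value or one of the recognized modes.
    assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
    if throughput in ["auto", "eval", "dry_run"]:
        # "eval" and "dry_run" both bypass the cached estimate and re-measure.
        force_eval = throughput in ["eval", "dry_run"]
        throughput_info = estimate_throughput(force_eval=force_eval)
        if throughput == "dry_run":
            # Measure, log, and stop before the server starts serving blocks.
            print("Finished estimating throughput, exiting")
            sys.exit(0)
    else:
        throughput_info = {"throughput": throughput}
    return throughput_info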
