|
|
|
@ -5,6 +5,7 @@ import math
|
|
|
|
|
import multiprocessing as mp
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import sys
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
|
from typing import Dict, List, Optional, Sequence, Union
|
|
|
|
@ -234,8 +235,9 @@ class Server:
|
|
|
|
|
self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
|
|
|
|
|
logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
|
|
|
|
|
|
|
|
|
|
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
|
|
|
|
|
if throughput in ["auto", "eval"]:
|
|
|
|
|
assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
|
|
|
|
|
if throughput in ["auto", "eval", "dry_run"]:
|
|
|
|
|
force_eval = throughput in ["eval", "dry_run"]
|
|
|
|
|
throughput_info = get_server_throughput(
|
|
|
|
|
converted_model_name_or_path,
|
|
|
|
|
self.block_config,
|
|
|
|
@ -245,9 +247,12 @@ class Server:
|
|
|
|
|
quant_type=quant_type,
|
|
|
|
|
tensor_parallel_devices=self.tensor_parallel_devices,
|
|
|
|
|
reachable_via_relay=reachable_via_relay,
|
|
|
|
|
force_eval=(throughput == "eval"),
|
|
|
|
|
force_eval=force_eval,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
)
|
|
|
|
|
if throughput == "dry_run":
|
|
|
|
|
logger.info("Finished estimating throughput, exiting")
|
|
|
|
|
sys.exit(0)
|
|
|
|
|
else:
|
|
|
|
|
throughput_info = {"throughput": throughput}
|
|
|
|
|
self.server_info = ServerInfo(
|
|
|
|
|