diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 448b704..9f0a740 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -107,9 +107,10 @@ def main():
     parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
-    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    parser.add_argument('--load_in_8bit', type=bool, default=None,
-                        help="Convert the loaded model into mixed-8bit quantized model. Default: True if GPU is available")
+    parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
+    parser.add_argument('--load_in_8bit', type=str, default=None,
+                        help="Convert the loaded model into mixed-8bit quantized model. "
+                             "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")

     # fmt:on

     args = vars(parser.parse_args())
@@ -133,8 +134,9 @@ def main():
     if args.pop("new_swarm"):
         args["initial_peers"] = []

-    use_auth_token = args.pop("use_auth_token")
-    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
+    load_in_8bit = args.pop("load_in_8bit")
+    if load_in_8bit is not None:
+        args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]

     server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
     try:
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index f9f94cd..0881683 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -34,6 +34,12 @@ def get_host_throughput(
     cache_path: str = DEFAULT_CACHE_PATH,
     lock_path: str = DEFAULT_LOCK_PATH,
 ) -> float:
+    # Resolve default dtypes
+    if dtype == "auto" or dtype is None:
+        dtype = config.torch_dtype
+    if dtype == "auto" or dtype is None:
+        dtype = torch.float32
+
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
     with open(lock_path, "wb") as lock_fd:
@@ -42,8 +48,8 @@
         # The OS will release the lock when lock_fd is closed or the process is killed

         cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
-        cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
-        cache_key += f"_dtype_{_get_dtype_name(dtype, load_in_8bit)}"
+        cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
+        cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"

         cache = {}
         try:
@@ -71,7 +77,7 @@
 def measure_throughput_info(
     config: BloomConfig,
     device: torch.device,
-    dtype: Union[str, torch.dtype],
+    dtype: torch.dtype,
     *,
     load_in_8bit: bool,
 ) -> float:
@@ -107,7 +113,7 @@ def measure_network_rps(config: BloomConfig) -> float:
 def measure_compute_rps(
     config: BloomConfig,
     device: torch.device,
-    dtype: Union[str, torch.dtype],
+    dtype: torch.dtype,
     *,
     load_in_8bit: bool,
     n_tokens: int = 16,
@@ -115,10 +121,7 @@
     layer_index: int = 0,
 ) -> float:
     with torch.inference_mode():
-        block = BloomBlock(config, layer_index)
-        if dtype != "auto":
-            block = block.to(dtype)
-        input_dtype = block.input_layernorm.weight.dtype
+        block = BloomBlock(config, layer_index).to(dtype)
         if load_in_8bit:
             block = replace_8bit_linear(block)
         block = block.to(device)
@@ -126,8 +129,8 @@
         cache = None
         elapsed = 0
         for step in range(n_steps + 1):
-            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=input_dtype)
-            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=input_dtype)
+            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
+            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)

             start_time = time.perf_counter()
             _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
@@ -136,15 +139,15 @@
         device_rps = n_steps * n_tokens / elapsed

     logger.info(
-        f"Forward pass throughput ({_get_device_name(device)}, {_get_dtype_name(dtype, load_in_8bit)}): "
+        f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): "
         f"{device_rps:.1f} RPS"
    )
     return device_rps


-def _get_device_name(device: torch.device) -> str:
+def get_device_name(device: torch.device) -> str:
     return f"{torch.cuda.get_device_name(device)} GPU" if device == "cuda" else "CPU"


-def _get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
+def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
     return "8-bit" if load_in_8bit else str(dtype)
diff --git a/src/petals/utils/linear8bitlt_patch.py b/src/petals/utils/linear8bitlt_patch.py
index a8b3362..523436f 100644
--- a/src/petals/utils/linear8bitlt_patch.py
+++ b/src/petals/utils/linear8bitlt_patch.py
@@ -9,7 +9,7 @@ Based on: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu#
 Exact match tests: see $REPO/tests/test_linear8bitlt.py
 """
 import dataclasses
-import warnings
+import logging
 from typing import Optional, Tuple

 import bitsandbytes.functional as F
@@ -155,7 +155,7 @@ class CustomMatMul8bitLt(MatMul8bitLt):

         # Cast A to fp16
         if A.dtype != torch.float16:
-            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+            logging.debug(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

         # 1. Quantize A
         if len(A.shape) == 3:
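
Note (not part of the diff): a minimal sketch of the two behavior changes above, assuming only the logic visible in the patch. The helper names parse_load_in_8bit and resolve_dtype are illustrative and do not exist in Petals; the real code reads config.torch_dtype from the Hugging Face BloomConfig.

import torch

def parse_load_in_8bit(value):
    # run_server.py now accepts the flag as a string so that `--load_in_8bit False` works;
    # anything other than "true"/"1" (case-insensitive) disables 8-bit loading
    if value is None:
        return None  # keep the server default (8-bit when a GPU is available)
    return value.lower() in ["true", "1"]

def resolve_dtype(dtype, config_torch_dtype):
    # get_host_throughput() now resolves "auto"/None before measuring,
    # falling back first to the config's torch_dtype and then to float32
    if dtype == "auto" or dtype is None:
        dtype = config_torch_dtype
    if dtype == "auto" or dtype is None:
        dtype = torch.float32
    return dtype

assert parse_load_in_8bit("False") is False
assert parse_load_in_8bit("1") is True
assert resolve_dtype("auto", None) is torch.float32
assert resolve_dtype("auto", torch.bfloat16) is torch.bfloat16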