Suppress quantization warning and fix dtype defaults in compute benchmark (#117)

Alexander Borzunov · 1 year ago · committed by GitHub
parent 643a054170
commit f72c220404

@@ -107,9 +107,10 @@ def main():
     parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
-    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    parser.add_argument('--load_in_8bit', type=bool, default=None,
-                        help="Convert the loaded model into mixed-8bit quantized model. Default: True if GPU is available")
+    parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
+    parser.add_argument('--load_in_8bit', type=str, default=None,
+                        help="Convert the loaded model into mixed-8bit quantized model. "
+                             "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
     # fmt:on
     args = vars(parser.parse_args())
@@ -133,8 +134,9 @@ def main():
     if args.pop("new_swarm"):
         args["initial_peers"] = []
 
-    use_auth_token = args.pop("use_auth_token")
-    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
+    load_in_8bit = args.pop("load_in_8bit")
+    if load_in_8bit is not None:
+        args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
 
     server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
     try:
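
Note: the switch from `type=bool` to `type=str` matters because argparse applies `bool()` to the raw string, and any non-empty string (including "False") is truthy. A minimal standalone sketch of the new parsing, outside this diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--load_in_8bit", type=str, default=None)
args = vars(parser.parse_args(["--load_in_8bit", "False"]))

# Same string-to-bool logic as in the hunk above
load_in_8bit = args.pop("load_in_8bit")
if load_in_8bit is not None:
    print(load_in_8bit.lower() in ["true", "1"])  # False (whereas bool("False") is True)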

@@ -34,6 +34,12 @@ def get_host_throughput(
     cache_path: str = DEFAULT_CACHE_PATH,
     lock_path: str = DEFAULT_LOCK_PATH,
 ) -> float:
+    # Resolve default dtypes
+    if dtype == "auto" or dtype is None:
+        dtype = config.torch_dtype
+    if dtype == "auto" or dtype is None:
+        dtype = torch.float32
+
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
     with open(lock_path, "wb") as lock_fd:
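
Note: the check is written twice because `config.torch_dtype` itself may be `None` or the string "auto". A standalone sketch of the fallback, where `resolve_dtype` is a hypothetical helper mirroring the inline logic:

import torch

def resolve_dtype(dtype, config_torch_dtype):  # hypothetical helper, not in the codebase
    if dtype == "auto" or dtype is None:
        dtype = config_torch_dtype  # may itself be "auto" or None
    if dtype == "auto" or dtype is None:
        dtype = torch.float32  # final fallback
    return dtype

assert resolve_dtype("auto", None) is torch.float32
assert resolve_dtype(None, torch.bfloat16) is torch.bfloat16
assert resolve_dtype(torch.float16, None) is torch.float16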
@@ -42,8 +48,8 @@ def get_host_throughput(
         # The OS will release the lock when lock_fd is closed or the process is killed
 
         cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
-        cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
-        cache_key += f"_dtype_{_get_dtype_name(dtype, load_in_8bit)}"
+        cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
+        cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
 
         cache = {}
         try:
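
Note: with dtype now always resolved to a concrete `torch.dtype` before this point, benchmarks run under different dtypes get distinct cache keys instead of sharing an "auto" entry. An illustrative key (all values made up):

from hashlib import sha256

config_str = "BloomConfig { ... }"  # stand-in for str(config)
cache_key = f"config_{sha256(config_str.encode()).hexdigest()[-16:]}"
cache_key += "_device_" + "NVIDIA_A100-SXM4-40GB_GPU"  # get_device_name(device), spaces replaced
cache_key += "_dtype_" + "torch.bfloat16"              # get_dtype_name(dtype, load_in_8bit=False)
print(cache_key)  # config_<16 hex chars>_device_NVIDIA_A100-SXM4-40GB_GPU_dtype_torch.bfloat16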
@@ -71,7 +77,7 @@ def get_host_throughput(
 def measure_throughput_info(
     config: BloomConfig,
     device: torch.device,
-    dtype: Union[str, torch.dtype],
+    dtype: torch.dtype,
     *,
     load_in_8bit: bool,
 ) -> float:
@@ -107,7 +113,7 @@ def measure_network_rps(config: BloomConfig) -> float:
 def measure_compute_rps(
     config: BloomConfig,
     device: torch.device,
-    dtype: Union[str, torch.dtype],
+    dtype: torch.dtype,
     *,
     load_in_8bit: bool,
     n_tokens: int = 16,
@@ -115,10 +121,7 @@ def measure_compute_rps(
     layer_index: int = 0,
 ) -> float:
     with torch.inference_mode():
-        block = BloomBlock(config, layer_index)
-        if dtype != "auto":
-            block = block.to(dtype)
-        input_dtype = block.input_layernorm.weight.dtype
+        block = BloomBlock(config, layer_index).to(dtype)
         if load_in_8bit:
             block = replace_8bit_linear(block)
         block = block.to(device)
@@ -126,8 +129,8 @@ def measure_compute_rps(
         cache = None
         elapsed = 0
         for step in range(n_steps + 1):
-            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=input_dtype)
-            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=input_dtype)
+            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
+            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
 
             start_time = time.perf_counter()
             _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
@@ -136,15 +139,15 @@ def measure_compute_rps(
         device_rps = n_steps * n_tokens / elapsed
 
     logger.info(
-        f"Forward pass throughput ({_get_device_name(device)}, {_get_dtype_name(dtype, load_in_8bit)}): "
+        f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): "
         f"{device_rps:.1f} RPS"
     )
     return device_rps
 
 
-def _get_device_name(device: torch.device) -> str:
+def get_device_name(device: torch.device) -> str:
     return f"{torch.cuda.get_device_name(device)} GPU" if device == "cuda" else "CPU"
 
 
-def _get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
+def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
     return "8-bit" if load_in_8bit else str(dtype)

@@ -9,7 +9,7 @@ Based on: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu#
 Exact match tests: see $REPO/tests/test_linear8bitlt.py
 """
 import dataclasses
-import warnings
+import logging
 from typing import Optional, Tuple
 
 import bitsandbytes.functional as F
@@ -155,7 +155,7 @@ class CustomMatMul8bitLt(MatMul8bitLt):
         # Cast A to fp16
         if A.dtype != torch.float16:
-            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+            logging.debug(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
 
         # 1. Quantize A
         if len(A.shape) == 3:
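
Note: since the cast notice now goes through the root logger at DEBUG level, it is silent by default instead of being emitted on every matmul. It can still be surfaced when debugging, e.g.:

import logging

# Raise root-logger verbosity before running the model
logging.basicConfig(level=logging.DEBUG)
# A non-fp16 input to CustomMatMul8bitLt then prints:
# DEBUG:root:MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization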
