|
|
|
@ -34,6 +34,12 @@ def get_host_throughput(
|
|
|
|
|
cache_path: str = DEFAULT_CACHE_PATH,
|
|
|
|
|
lock_path: str = DEFAULT_LOCK_PATH,
|
|
|
|
|
) -> float:
|
|
|
|
|
# Resolve default dtypes
|
|
|
|
|
if dtype == "auto" or dtype is None:
|
|
|
|
|
dtype = config.torch_dtype
|
|
|
|
|
if dtype == "auto" or dtype is None:
|
|
|
|
|
dtype = torch.float32
|
|
|
|
|
|
|
|
|
|
# We use the system-wide lock since only one process at a time can measure the host throughput
|
|
|
|
|
os.makedirs(lock_path.parent, exist_ok=True)
|
|
|
|
|
with open(lock_path, "wb") as lock_fd:
|
|
|
|
@ -42,8 +48,8 @@ def get_host_throughput(
|
|
|
|
|
# The OS will release the lock when lock_fd is closed or the process is killed
|
|
|
|
|
|
|
|
|
|
cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
|
|
|
|
|
cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
|
|
|
|
|
cache_key += f"_dtype_{_get_dtype_name(dtype, load_in_8bit)}"
|
|
|
|
|
cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
|
|
|
|
|
cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
|
|
|
|
|
|
|
|
|
|
cache = {}
|
|
|
|
|
try:
|
|
|
|
@ -71,7 +77,7 @@ def get_host_throughput(
|
|
|
|
|
def measure_throughput_info(
|
|
|
|
|
config: BloomConfig,
|
|
|
|
|
device: torch.device,
|
|
|
|
|
dtype: Union[str, torch.dtype],
|
|
|
|
|
dtype: torch.dtype,
|
|
|
|
|
*,
|
|
|
|
|
load_in_8bit: bool,
|
|
|
|
|
) -> float:
|
|
|
|
@ -107,7 +113,7 @@ def measure_network_rps(config: BloomConfig) -> float:
|
|
|
|
|
def measure_compute_rps(
|
|
|
|
|
config: BloomConfig,
|
|
|
|
|
device: torch.device,
|
|
|
|
|
dtype: Union[str, torch.dtype],
|
|
|
|
|
dtype: torch.dtype,
|
|
|
|
|
*,
|
|
|
|
|
load_in_8bit: bool,
|
|
|
|
|
n_tokens: int = 16,
|
|
|
|
@ -115,10 +121,7 @@ def measure_compute_rps(
|
|
|
|
|
layer_index: int = 0,
|
|
|
|
|
) -> float:
|
|
|
|
|
with torch.inference_mode():
|
|
|
|
|
block = BloomBlock(config, layer_index)
|
|
|
|
|
if dtype != "auto":
|
|
|
|
|
block = block.to(dtype)
|
|
|
|
|
input_dtype = block.input_layernorm.weight.dtype
|
|
|
|
|
block = BloomBlock(config, layer_index).to(dtype)
|
|
|
|
|
if load_in_8bit:
|
|
|
|
|
block = replace_8bit_linear(block)
|
|
|
|
|
block = block.to(device)
|
|
|
|
@ -126,8 +129,8 @@ def measure_compute_rps(
|
|
|
|
|
cache = None
|
|
|
|
|
elapsed = 0
|
|
|
|
|
for step in range(n_steps + 1):
|
|
|
|
|
dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=input_dtype)
|
|
|
|
|
alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=input_dtype)
|
|
|
|
|
dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
|
|
|
|
|
alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
start_time = time.perf_counter()
|
|
|
|
|
_, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
|
|
|
|
@ -136,15 +139,15 @@ def measure_compute_rps(
|
|
|
|
|
device_rps = n_steps * n_tokens / elapsed
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Forward pass throughput ({_get_device_name(device)}, {_get_dtype_name(dtype, load_in_8bit)}): "
|
|
|
|
|
f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): "
|
|
|
|
|
f"{device_rps:.1f} RPS"
|
|
|
|
|
)
|
|
|
|
|
return device_rps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_device_name(device: torch.device) -> str:
|
|
|
|
|
def get_device_name(device: torch.device) -> str:
    """Return a human-readable name for the device used in logs and cache keys.

    Args:
        device: The device to describe. A ``torch.device`` is expected; a
            device string such as ``"cuda"`` or ``"cuda:0"`` is accepted too.

    Returns:
        ``"<CUDA device name> GPU"`` for CUDA devices, ``"CPU"`` for anything else.
    """
    # A torch.device never compares equal to a plain string, so `device == "cuda"`
    # was always False and every device was reported as "CPU". Normalise the
    # argument and check its type instead (this also covers indexed devices).
    device = torch.device(device)
    if device.type == "cuda":
        return f"{torch.cuda.get_device_name(device)} GPU"
    return "CPU"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
|
|
|
|
|
def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
    """Return the label describing the numeric format of the model weights.

    Args:
        dtype: The torch dtype the block is cast to.
        load_in_8bit: Whether 8-bit loading is requested; takes precedence over ``dtype``.

    Returns:
        ``"8-bit"`` when ``load_in_8bit`` is truthy, otherwise ``str(dtype)``.
    """
    if load_in_8bit:
        return "8-bit"
    return str(dtype)
|
|
|
|
|