Report inference, forward, and network RPS separately (#358)

Inference RPS may be very different from forward RPS. E.g., bitsandbytes (bnb) currently uses a completely different algorithm for NF4 inference. We report detailed RPS info that can then be used for shortest-path routing for inference.
Alexander Borzunov authored 10 months ago, committed by GitHub
parent 9517dd1e3d
commit 11f0d992d7

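For context on the routing mentioned in the commit message above, here is a purely illustrative sketch (not part of this commit) of what the richer reports enable: once servers announce a per-block inference_rps, a client can estimate how much latency a candidate span adds to each autoregressive step and prefer the cheapest path. All helper names and numbers below are made up.

# Illustration only: estimate the cost one candidate span adds to a single autoregressive
# step, given the per-block inference RPS that servers now report. Hypothetical helper,
# not code from this commit.
def span_step_cost(num_blocks_in_span: int, inference_rps: float, network_overhead: float = 0.0) -> float:
    # Each of the span's blocks processes one token per step at `inference_rps` steps/s,
    # so the span contributes num_blocks / inference_rps seconds of compute plus a per-hop overhead.
    return num_blocks_in_span / inference_rps + network_overhead

# Two made-up candidates covering the same 10 blocks:
fast_gpu = span_step_cost(10, inference_rps=980.0, network_overhead=0.02)   # ~0.030 s
slow_cpu = span_step_cost(10, inference_rps=120.0, network_overhead=0.005)  # ~0.088 s
print("route via fast_gpu" if fast_gpu < slow_cpu else "route via slow_cpu")
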
@@ -73,12 +73,15 @@ class RemoteSequenceInfo:
         active_spans = {}
         for block_index, info in enumerate(block_infos):
             if info is not None:
-                for peer_id, server in info.servers.items():
-                    if server.state != ServerState.ONLINE:
+                for peer_id, server_info in info.servers.items():
+                    if server_info.state != ServerState.ONLINE:
                         continue
                     if peer_id not in active_spans:
                         active_spans[peer_id] = RemoteSpanInfo(
-                            peer_id=peer_id, start=block_index, end=block_index + 1, throughput=server.throughput
+                            peer_id=peer_id,
+                            start=block_index,
+                            end=block_index + 1,
+                            server_info=server_info,
                         )
                     else:  # peer_id in active_spans
                         active_spans[peer_id].end = block_index + 1

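For intuition, a simplified standalone sketch of the span-merging loop above, using plain dicts instead of petals' RemoteSpanInfo/ServerInfo: consecutive blocks served by the same ONLINE peer collapse into a single span that now carries the full server info rather than a bare throughput number.

# Simplified illustration of the compute_spans loop above (not petals' actual classes).
# Each entry in block_infos maps peer_id -> (state, server_info) for one block.
block_infos = [
    {"peerA": ("ONLINE", {"throughput": 100.0})},                                          # block 0
    {"peerA": ("ONLINE", {"throughput": 100.0}), "peerB": ("JOINING", {"throughput": 50.0})},  # block 1
    {"peerA": ("ONLINE", {"throughput": 100.0})},                                          # block 2
]

active_spans = {}
for block_index, info in enumerate(block_infos):
    for peer_id, (state, server_info) in info.items():
        if state != "ONLINE":
            continue  # only ONLINE servers become part of a span
        if peer_id not in active_spans:
            active_spans[peer_id] = {"start": block_index, "end": block_index + 1, "server_info": server_info}
        else:
            active_spans[peer_id]["end"] = block_index + 1

print(active_spans)  # {'peerA': {'start': 0, 'end': 3, 'server_info': {'throughput': 100.0}}}
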
@@ -151,7 +151,7 @@ class RemoteSequenceManager:
                 raise MissingBlocksError(current_index)
             if mode == "max_throughput":
-                span_weights = np.array([span.throughput for span in candidate_spans], dtype=np.float64)
+                span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
             elif mode == "min_latency":
                 span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
             else:

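How these weights are consumed is outside this hunk; as an assumption-laden illustration, a throughput-proportional random pick over the candidate spans could look like this:

import numpy as np

# Hypothetical illustration only: the actual selection code is not part of this diff.
span_weights = np.array([120.0, 30.0, 50.0], dtype=np.float64)  # e.g. server throughputs in RPS
probabilities = span_weights / span_weights.sum()
chosen_index = np.random.choice(len(span_weights), p=probabilities)
print(f"Routing through candidate span #{chosen_index}")
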
@@ -19,10 +19,17 @@ class ServerState(Enum):
     ONLINE = 2


+RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
-    throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+    throughput: RPS
+
+    network_rps: Optional[RPS] = None
+    forward_rps: Optional[RPS] = None
+    inference_rps: Optional[RPS] = None

     adapters: Sequence[str] = ()
     version: Optional[str] = None
@@ -60,7 +67,7 @@ class RemoteSpanInfo:
     peer_id: PeerID
     start: int
     end: int
-    throughput: float
+    server_info: ServerInfo

     @property
     def length(self):

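A minimal usage sketch of the new dataclass fields (module path and all values are assumptions; pydantic validates every RPS field as a non-negative, finite, strictly-typed float):

from petals.data_structures import ServerInfo, ServerState  # module path assumed

# Hypothetical values illustrating the new optional per-type fields.
info = ServerInfo(
    state=ServerState.ONLINE,
    throughput=115.0,        # aggregate RPS used by default routing
    inference_rps=980.0,     # single-token autoregressive steps per second, per block
    forward_rps=16500.0,     # tokens per second for large-batch forward passes, per block
    network_rps=740.0,       # requests per second the network link can sustain
)

# Negative or non-finite values are rejected at construction time, e.g.:
# ServerInfo(state=ServerState.ONLINE, throughput=-1.0)  -> pydantic.ValidationError
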
@@ -206,7 +206,7 @@ class Server:
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
-            throughput = get_server_throughput(
+            throughput_info = get_server_throughput(
                 converted_model_name_or_path,
                 self.block_config,
                 device,
@@ -217,14 +217,16 @@ class Server:
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
+        else:
+            throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
             state=ServerState.JOINING,
-            throughput=throughput,
             adapters=tuple(adapters),
             version=petals.__version__,
             torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
+            **throughput_info,
         )
         self.balance_quality = balance_quality

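The dict built above is splatted straight into ServerInfo; a sketch of its two possible shapes (numbers made up):

# The two shapes of throughput_info that reach the **throughput_info splat above.
# With "auto"/"eval", the measured dict is used; "network_rps" is missing if the speedtest failed.
measured = {"throughput": 115.0, "inference_rps": 980.0, "forward_rps": 16500.0, "network_rps": 740.0}
# With a fixed float passed by the operator, only the aggregate value is set:
fixed = {"throughput": 115.0}
# Either way, ServerInfo(state=..., **throughput_info) receives these keys as keyword arguments,
# matching the fields added to the dataclass above.
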
@@ -43,13 +43,13 @@ def get_server_throughput(
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
-) -> float:
+) -> Dict[str, float]:
     dtype = resolve_block_dtype(config, dtype)

     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v3.json")
+    cache_path = Path(cache_dir, "throughput_v4.json")

     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
@@ -93,10 +93,12 @@ def get_server_throughput(
     # Assuming the start block index is distributed uniformly, the average number of blocks used per request is
     # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
     average_blocks_used = (num_blocks + 1) / 2
-    throughput = throughput_info["compute_rps"] / average_blocks_used
+    throughput = throughput_info["forward_rps"] / average_blocks_used
+    throughput = min(throughput, throughput_info.get("network_rps", math.inf))
+    throughput_info["throughput"] = throughput
     logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
-    return throughput
+    return throughput_info


 def measure_throughput_info(
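A worked example of this aggregation, with made-up measurements:

# Made-up measurements for a server hosting num_blocks = 31 blocks.
forward_rps = 16500.0                              # batched forward tokens/s per block
network_rps = 740.0                                # requests/s the network link sustains
average_blocks_used = (31 + 1) / 2                 # = 16.0, i.e. E[Uniform{1..31}]
compute_bound = forward_rps / average_blocks_used  # = 1031.25 RPS over the hosted span
throughput = min(compute_bound, network_rps)       # = 740.0 -> the link is the bottleneck here
print(f"Reporting throughput: {throughput:.1f} RPS for 31 blocks")
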
@@ -114,15 +116,31 @@ def measure_throughput_info(
     )

     throughput_info = {
-        "compute_rps": measure_compute_rps(
-            config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices
-        )
+        "inference_rps": measure_compute_rps(
+            config,
+            device,
+            dtype,
+            quant_type=quant_type,
+            tensor_parallel_devices=tensor_parallel_devices,
+            n_tokens=1,
+            n_steps=100,
+            inference=True,
+        ),
+        "forward_rps": measure_compute_rps(
+            config,
+            device,
+            dtype,
+            quant_type=quant_type,
+            tensor_parallel_devices=tensor_parallel_devices,
+            n_tokens=1024,
+            n_steps=10,
+            inference=False,
+        ),
     }
     try:
         throughput_info["network_rps"] = measure_network_rps(config)
     except Exception as e:
-        logger.warning(f"Failed to measure network throughput: {repr(e)}")
-        logger.warning("Proceeding with the compute throughput only")
+        logger.info(f"Network throughput is not available: {e}")
     return throughput_info
@@ -135,6 +153,8 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Op
         process.terminate()
         raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
     network_info = pipe_recv.recv()
+    if "exception" in network_info:
+        raise RuntimeError(f"speedtest failed: {network_info['exception']}")

     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
     network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
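A worked example of the conversion from link speed to network RPS, with made-up numbers:

# Made-up numbers: hidden_size = 8192 and a symmetric 500 Mbit/s link.
# speedtest reports download/upload speeds in bits per second.
hidden_size = 8192
bits_per_request = hidden_size * 16                     # 131072 bits: one 16-bit hidden state per token
download, upload = 500e6, 500e6                         # 500 Mbit/s each way
network_rps = min(download, upload) / bits_per_request  # ~3815 requests per second
print(f"Network RPS: {network_rps:.0f}")
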
@@ -150,12 +170,15 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Op


 def _measure_bits_per_second(pipe_send: mp.Pipe):
-    s = speedtest.Speedtest()
-    s.get_servers()
-    s.get_best_server()
-    s.download()
-    s.upload()
-    pipe_send.send(s.results.dict())
+    try:
+        s = speedtest.Speedtest()
+        s.get_servers()
+        s.get_best_server()
+        s.download()
+        s.upload()
+        pipe_send.send(s.results.dict())
+    except Exception as e:
+        pipe_send.send({"exception": repr(e)})


 def measure_compute_rps(
@@ -165,8 +188,9 @@ def measure_compute_rps(
     *,
     quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
-    n_tokens: int = 16,
-    n_steps: int = 500,
+    n_tokens: int,
+    n_steps: int,
+    inference: bool,
 ) -> float:
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
@@ -180,7 +204,7 @@ def measure_compute_rps(
         dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)

         start_time = time.perf_counter()
-        _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache)
+        _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
         if step >= 1:  # Skip the 1st step to exclude the initialization time
             elapsed += time.perf_counter() - start_time
     device_rps = n_steps * n_tokens / elapsed
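A worked example of the timing formula, with made-up numbers:

# Made-up timing to illustrate the formula above. In inference mode each timed step feeds
# a single token (n_tokens=1) and reuses the attention cache, so RPS means autoregressive
# steps per second; in forward mode 1024 tokens go through per step.
n_steps, n_tokens = 100, 1
elapsed = 0.125                            # seconds accumulated over the 100 timed steps
device_rps = n_steps * n_tokens / elapsed  # = 800.0 single-token steps per second per block
print(f"{device_rps:.1f} RPS per block")
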
@@ -191,8 +215,8 @@ def measure_compute_rps(
     devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())

     logger.info(
-        f"Forward pass throughput: {device_rps:.1f} RPS per block "
-        f"({devices_repr}, {get_dtype_name(dtype, quant_type)})"
+        f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} RPS per block "
+        f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})"
     )
     return device_rps
@@ -202,7 +226,7 @@ def get_device_name(device: torch.device) -> str:


 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
-    name = str(dtype)
+    name = str(dtype).replace("torch.", "")
     if quant_type != QuantType.NONE:
         name += f", quantized to {quant_type.name.lower()}"
     return name

@@ -16,7 +16,7 @@ async def ping(
     _dht: hivemind.DHT,
     node: hivemind.dht.DHTNode,
     *,
-    wait_timeout: float = 1,
+    wait_timeout: float = 5,
 ) -> float:
     try:
         ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)

@@ -24,8 +24,10 @@ def test_bnb_not_imported_when_unnecessary():


 @pytest.mark.forked
+@pytest.mark.parametrize("inference", [False, True])
+@pytest.mark.parametrize("n_tokens", [1, 16])
 @pytest.mark.parametrize("tensor_parallel", [False, True])
-def test_compute_throughput(tensor_parallel: bool):
+def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
     config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
     tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
@@ -34,6 +36,8 @@ def test_compute_throughput(tensor_parallel: bool):
         dtype=torch.bfloat16,
         quant_type=QuantType.NONE,
         tensor_parallel_devices=tensor_parallel_devices,
-        n_steps=10,
+        n_tokens=n_tokens,
+        n_steps=5,
+        inference=inference,
     )
     assert isinstance(compute_rps, float) and compute_rps > 0
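Assuming a standard pytest setup (the test file path is not shown in this diff), the expanded matrix can be run with something like:

import pytest

# Convenience runner: selects the 2 (inference) x 2 (n_tokens) x 2 (tensor_parallel) = 8
# parametrized cases of test_compute_throughput. Run from the repository root.
raise SystemExit(pytest.main(["-q", "-k", "test_compute_throughput"]))
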
