diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py
index b35b02b..bce6712 100644
--- a/src/petals/client/routing/sequence_info.py
+++ b/src/petals/client/routing/sequence_info.py
@@ -73,12 +73,15 @@ class RemoteSequenceInfo:
         active_spans = {}
         for block_index, info in enumerate(block_infos):
             if info is not None:
-                for peer_id, server in info.servers.items():
-                    if server.state != ServerState.ONLINE:
+                for peer_id, server_info in info.servers.items():
+                    if server_info.state != ServerState.ONLINE:
                         continue
                     if peer_id not in active_spans:
                         active_spans[peer_id] = RemoteSpanInfo(
-                            peer_id=peer_id, start=block_index, end=block_index + 1, throughput=server.throughput
+                            peer_id=peer_id,
+                            start=block_index,
+                            end=block_index + 1,
+                            server_info=server_info,
                         )
                     else:  # peer_id in active_spans
                         active_spans[peer_id].end = block_index + 1
diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py
index fc505cc..19b475b 100644
--- a/src/petals/client/routing/sequence_manager.py
+++ b/src/petals/client/routing/sequence_manager.py
@@ -151,7 +151,7 @@ class RemoteSequenceManager:
             raise MissingBlocksError(current_index)
 
         if mode == "max_throughput":
-            span_weights = np.array([span.throughput for span in candidate_spans], dtype=np.float64)
+            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
         elif mode == "min_latency":
             span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
         else:
diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py
index 8d7d50b..e3a3e03 100644
--- a/src/petals/data_structures.py
+++ b/src/petals/data_structures.py
@@ -19,10 +19,17 @@ class ServerState(Enum):
     ONLINE = 2
 
 
+RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
-    throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+    throughput: RPS
+
+    network_rps: Optional[RPS] = None
+    forward_rps: Optional[RPS] = None
+    inference_rps: Optional[RPS] = None
 
     adapters: Sequence[str] = ()
     version: Optional[str] = None
@@ -60,7 +67,7 @@ class RemoteSpanInfo:
     peer_id: PeerID
     start: int
    end: int
-    throughput: float
+    server_info: ServerInfo
 
     @property
     def length(self):
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index f09724f..aea57c7 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -206,7 +206,7 @@ class Server:
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
-            throughput = get_server_throughput(
+            throughput_info = get_server_throughput(
                 converted_model_name_or_path,
                 self.block_config,
                 device,
@@ -217,14 +217,16 @@
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
+        else:
+            throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
             state=ServerState.JOINING,
-            throughput=throughput,
             adapters=tuple(adapters),
             version=petals.__version__,
             torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
+            **throughput_info,
         )
 
         self.balance_quality = balance_quality
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index 20625e6..d92355e 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -43,13 +43,13 @@ def get_server_throughput(
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
-) -> float:
+) -> Dict[str, float]:
     dtype = resolve_block_dtype(config, dtype)
 
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v3.json")
+    cache_path = Path(cache_dir, "throughput_v4.json")
 
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
@@ -93,10 +93,12 @@ def get_server_throughput(
     # Assuming the start block index is distributed uniformly, the average number of blocks used per request is
     # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
     average_blocks_used = (num_blocks + 1) / 2
-    throughput = throughput_info["compute_rps"] / average_blocks_used
+    throughput = throughput_info["forward_rps"] / average_blocks_used
     throughput = min(throughput, throughput_info.get("network_rps", math.inf))
+    throughput_info["throughput"] = throughput
     logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
-    return throughput
+
+    return throughput_info
 
 
 def measure_throughput_info(
@@ -114,15 +116,31 @@
     )
 
     throughput_info = {
-        "compute_rps": measure_compute_rps(
-            config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices
-        )
+        "inference_rps": measure_compute_rps(
+            config,
+            device,
+            dtype,
+            quant_type=quant_type,
+            tensor_parallel_devices=tensor_parallel_devices,
+            n_tokens=1,
+            n_steps=100,
+            inference=True,
+        ),
+        "forward_rps": measure_compute_rps(
+            config,
+            device,
+            dtype,
+            quant_type=quant_type,
+            tensor_parallel_devices=tensor_parallel_devices,
+            n_tokens=1024,
+            n_steps=10,
+            inference=False,
+        ),
     }
 
     try:
         throughput_info["network_rps"] = measure_network_rps(config)
     except Exception as e:
-        logger.warning(f"Failed to measure network throughput: {repr(e)}")
-        logger.warning("Proceeding with the compute throughput only")
+        logger.info(f"Network throughput is not available: {e}")
 
     return throughput_info
@@ -135,6 +153,8 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Opt
         process.terminate()
         raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
     network_info = pipe_recv.recv()
+    if "exception" in network_info:
+        raise RuntimeError(f"speedtest failed: {network_info['exception']}")
 
     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
     network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
@@ -150,12 +170,15 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Opt
 
 
 def _measure_bits_per_second(pipe_send: mp.Pipe):
-    s = speedtest.Speedtest()
-    s.get_servers()
-    s.get_best_server()
-    s.download()
-    s.upload()
-    pipe_send.send(s.results.dict())
+    try:
+        s = speedtest.Speedtest()
+        s.get_servers()
+        s.get_best_server()
+        s.download()
+        s.upload()
+        pipe_send.send(s.results.dict())
+    except Exception as e:
+        pipe_send.send({"exception": repr(e)})
 
 
 def measure_compute_rps(
@@ -165,8 +188,9 @@
     *,
     quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
-    n_tokens: int = 16,
-    n_steps: int = 500,
+    n_tokens: int,
+    n_steps: int,
+    inference: bool,
 ) -> float:
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
@@ -180,7 +204,7 @@ def measure_compute_rps(
             dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
 
             start_time = time.perf_counter()
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache)
+            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
             if step >= 1:  # Skip the 1st step to exclude the initialization time
                 elapsed += time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
@@ -191,8 +215,8 @@
         devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
 
     logger.info(
-        f"Forward pass throughput: {device_rps:.1f} RPS per block "
-        f"({devices_repr}, {get_dtype_name(dtype, quant_type)})"
+        f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} RPS per block "
+        f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})"
     )
     return device_rps
@@ -202,7 +226,7 @@ def get_device_name(device: torch.device) -> str:
 
 
 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
-    name = str(dtype)
+    name = str(dtype).replace("torch.", "")
     if quant_type != QuantType.NONE:
         name += f", quantized to {quant_type.name.lower()}"
     return name
diff --git a/src/petals/utils/ping.py b/src/petals/utils/ping.py
index de675a5..d5fd129 100644
--- a/src/petals/utils/ping.py
+++ b/src/petals/utils/ping.py
@@ -16,7 +16,7 @@ async def ping(
     _dht: hivemind.DHT,
     node: hivemind.dht.DHTNode,
     *,
-    wait_timeout: float = 1,
+    wait_timeout: float = 5,
 ) -> float:
     try:
         ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)
diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py
index e5b450c..64c9c6a 100644
--- a/tests/test_aux_functions.py
+++ b/tests/test_aux_functions.py
@@ -24,8 +24,10 @@ def test_bnb_not_imported_when_unnecessary():
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize("inference", [False, True])
+@pytest.mark.parametrize("n_tokens", [1, 16])
 @pytest.mark.parametrize("tensor_parallel", [False, True])
-def test_compute_throughput(tensor_parallel: bool):
+def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
     config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
     tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
@@ -34,6 +36,8 @@
         dtype=torch.bfloat16,
         quant_type=QuantType.NONE,
         tensor_parallel_devices=tensor_parallel_devices,
-        n_steps=10,
+        n_tokens=n_tokens,
+        n_steps=5,
+        inference=inference,
     )
     assert isinstance(compute_rps, float) and compute_rps > 0
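
A note on the math the patch relies on in get_server_throughput: a request that enters a server at a uniformly random block and runs through the rest touches E[Uniform{1, ..., num_blocks}] = (num_blocks + 1) / 2 blocks on average, so dividing forward_rps by that expectation gives a per-request rate, and the shared network link caps it. A minimal standalone sketch of that aggregation step (function name and numbers are illustrative, not from the patch):

    import math

    def aggregate_throughput(throughput_info: dict, num_blocks: int) -> float:
        # Each request uses (num_blocks + 1) / 2 blocks on average (uniform start index)
        average_blocks_used = (num_blocks + 1) / 2
        throughput = throughput_info["forward_rps"] / average_blocks_used
        # The network is shared by all hosted blocks, so it caps the aggregate rate
        return min(throughput, throughput_info.get("network_rps", math.inf))

    # Example: forward_rps=2000, network_rps=400, num_blocks=7 -> min(2000 / 4, 400) = 400
    print(aggregate_throughput({"forward_rps": 2000.0, "network_rps": 400.0}, num_blocks=7))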
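Likewise, the network_rps figure divides the measured link speed by the per-request payload: each request moves one hidden state of config.hidden_size elements, usually at 16 bits per element, and the slower link direction is the bottleneck. A hedged sketch with hypothetical numbers (the helper is ours, not a petals API):

    def network_rps(hidden_size: int, download_bps: float, upload_bps: float) -> float:
        # Clients usually send 16-bit tensors, so one request costs hidden_size * 16 bits
        bits_per_request = hidden_size * 16
        return min(download_bps, upload_bps) / bits_per_request

    # Example: hidden_size=4096 on a 100 Mbit/s link -> 1e8 / 65536 ≈ 1526 requests/s
    print(network_rps(4096, 1e8, 1e8))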
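The speedtest changes apply a general pattern for flaky measurements in a child process: exceptions cannot propagate across process boundaries, so the worker catches everything and forwards repr(e) through the pipe, and the parent re-raises it after waiting with a timeout. A self-contained sketch of the same pattern, assuming a placeholder worker instead of the real speedtest call:

    import multiprocessing as mp

    def _worker(pipe_send) -> None:
        try:
            pipe_send.send({"download": 1e8, "upload": 5e7})  # placeholder measurement
        except Exception as e:
            pipe_send.send({"exception": repr(e)})  # exceptions can't cross processes

    if __name__ == "__main__":
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        process = mp.Process(target=_worker, args=(pipe_send,))
        process.start()
        if not pipe_recv.poll(timeout=60):  # analogous to the join-with-timeout in the patch
            process.terminate()
            raise RuntimeError("worker did not finish in 60 seconds")
        result = pipe_recv.recv()
        if "exception" in result:
            raise RuntimeError(f"worker failed: {result['exception']}")
        print(result)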