From ee115dd44e7d58a10b6fa9af340f8156487beb27 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Sat, 3 Dec 2022 12:21:14 +0000 Subject: [PATCH] Show PeerID in TimeoutError, don't show tracebacks --- src/petals/client/inference_session.py | 23 ++++++++++++++------ src/petals/client/remote_forward_backward.py | 12 ++++++++-- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 310840e..b177b55 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -38,6 +38,7 @@ class _ServerInferenceSession: def __init__( self, uid: ModuleUID, + stub: StubBase, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator, @@ -46,7 +47,7 @@ class _ServerInferenceSession: max_length: int, **metadata, ): - self.uid, self.rpc_info = uid, rpc_info + self.uid, self.stub, self.rpc_info = uid, stub, rpc_info self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter @@ -61,11 +62,15 @@ class _ServerInferenceSession: ) -> _ServerInferenceSession: """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker""" inputs_queue = asyncio.Queue() - outputs_stream = await asyncio.wait_for( - stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), - timeout, - ) - return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata) + try: + outputs_stream = await asyncio.wait_for( + stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), + timeout, + ) + except asyncio.TimeoutError as e: + e.args = (f"Timeout on rpc_inference.open(remote_peer=...{stub._peer[-6:]})",) + raise + return cls(uid, stub, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata) @staticmethod async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator: @@ -125,7 +130,11 @@ class _ServerInferenceSession: """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker""" await self._inputs_queue.put(inputs_serialized) self.stepped = True - return await asyncio.wait_for(anext(self._outputs_stream), self.timeout) + try: + return await asyncio.wait_for(anext(self._outputs_stream), self.timeout) + except asyncio.TimeoutError as e: + e.args = (f"Timeout on rpc_inference.step(remote_peer=...{self.stub._peer[-6:]})",) + raise def close(self): """Finish a given inference session, close the underlying connection""" diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index 542ad9c..ba57c9b 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -109,7 +109,11 @@ async def run_remote_forward( # call RPC on remote server size = sum(t.element_size() * t.nelement() for t in inputs) forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary - deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) + try: + deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) + except asyncio.TimeoutError as e: + e.args = (f"Timeout on rpc_forward(remote_peer=...{stub._peer[-6:]})",) + raise return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) @@ -151,5 +155,9 @@ async def run_remote_backward( size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary - deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) + try: + deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) + except asyncio.TimeoutError as e: + e.args = (f"Timeout on rpc_backward(remote_peer=...{stub._peer[-6:]})",) + raise return deserialized_grad_inputs