Show PeerID in TimeoutError, don't show tracebacks

pull/128/head
Aleksandr Borzunov 1 year ago
parent 9dbf5e2e6f
commit ee115dd44e

@ -38,6 +38,7 @@ class _ServerInferenceSession:
def __init__(
self,
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
inputs_queue: asyncio.Queue,
outputs_aiter: AsyncIterator,
@ -46,7 +47,7 @@ class _ServerInferenceSession:
max_length: int,
**metadata,
):
self.uid, self.rpc_info = uid, rpc_info
self.uid, self.stub, self.rpc_info = uid, stub, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@ -61,11 +62,15 @@ class _ServerInferenceSession:
) -> _ServerInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
timeout,
)
return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
try:
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
timeout,
)
except asyncio.TimeoutError as e:
e.args = (f"Timeout on rpc_inference.open(remote_peer=...{stub._peer[-6:]})",)
raise
return cls(uid, stub, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@ -125,7 +130,11 @@ class _ServerInferenceSession:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
try:
return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
except asyncio.TimeoutError as e:
e.args = (f"Timeout on rpc_inference.step(remote_peer=...{self.stub._peer[-6:]})",)
raise
def close(self):
"""Finish a given inference session, close the underlying connection"""

@ -109,7 +109,11 @@ async def run_remote_forward(
# call RPC on remote server
size = sum(t.element_size() * t.nelement() for t in inputs)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
try:
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
except asyncio.TimeoutError as e:
e.args = (f"Timeout on rpc_forward(remote_peer=...{stub._peer[-6:]})",)
raise
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
@ -151,5 +155,9 @@ async def run_remote_backward(
size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary
deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
try:
deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
except asyncio.TimeoutError as e:
e.args = (f"Timeout on rpc_backward(remote_peer=...{stub._peer[-6:]})",)
raise
return deserialized_grad_inputs

Loading…
Cancel
Save