|
|
|
@ -38,6 +38,7 @@ class _ServerInferenceSession:
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
uid: ModuleUID,
|
|
|
|
|
stub: StubBase,
|
|
|
|
|
rpc_info: RPCInfo,
|
|
|
|
|
inputs_queue: asyncio.Queue,
|
|
|
|
|
outputs_aiter: AsyncIterator,
|
|
|
|
@ -46,7 +47,7 @@ class _ServerInferenceSession:
|
|
|
|
|
max_length: int,
|
|
|
|
|
**metadata,
|
|
|
|
|
):
|
|
|
|
|
self.uid, self.rpc_info = uid, rpc_info
|
|
|
|
|
self.uid, self.stub, self.rpc_info = uid, stub, rpc_info
|
|
|
|
|
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
|
|
|
|
|
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
|
|
|
|
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
|
|
|
|
@ -61,11 +62,15 @@ class _ServerInferenceSession:
|
|
|
|
|
) -> _ServerInferenceSession:
|
|
|
|
|
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
|
|
inputs_queue = asyncio.Queue()
|
|
|
|
|
outputs_stream = await asyncio.wait_for(
|
|
|
|
|
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
|
|
|
|
|
timeout,
|
|
|
|
|
)
|
|
|
|
|
return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
|
|
|
|
|
try:
|
|
|
|
|
outputs_stream = await asyncio.wait_for(
|
|
|
|
|
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
|
|
|
|
|
timeout,
|
|
|
|
|
)
|
|
|
|
|
except asyncio.TimeoutError as e:
|
|
|
|
|
e.args = (f"Timeout on rpc_inference.open(remote_peer=...{stub._peer[-6:]})",)
|
|
|
|
|
raise
|
|
|
|
|
return cls(uid, stub, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
|
|
|
|
@ -125,7 +130,11 @@ class _ServerInferenceSession:
|
|
|
|
|
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
|
|
await self._inputs_queue.put(inputs_serialized)
|
|
|
|
|
self.stepped = True
|
|
|
|
|
return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
|
|
|
|
|
try:
|
|
|
|
|
return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
|
|
|
|
|
except asyncio.TimeoutError as e:
|
|
|
|
|
e.args = (f"Timeout on rpc_inference.step(remote_peer=...{self.stub._peer[-6:]})",)
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def close(self):
|
|
|
|
|
"""Finish a given inference session, close the underlying connection"""
|
|
|
|
|