|
|
|
@ -57,6 +57,7 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
|
|
# using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
|
|
|
|
|
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
|
|
|
|
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
|
|
|
|
|
self.stepped = False
|
|
|
|
|
self.closed = False
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@ -102,6 +103,7 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
|
|
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
|
|
|
|
|
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
|
|
await self._inputs_queue.put(inputs_serialized)
|
|
|
|
|
self.stepped = True
|
|
|
|
|
return await anext(self._outputs_stream)
|
|
|
|
|
|
|
|
|
|
def close(self):
|
|
|
|
@ -116,11 +118,12 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
|
|
"""Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
|
|
if self._outputs_stream is None:
|
|
|
|
|
return # already closed
|
|
|
|
|
await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
|
|
|
|
|
try:
|
|
|
|
|
await anext(self._outputs_stream)
|
|
|
|
|
except StopAsyncIteration:
|
|
|
|
|
pass
|
|
|
|
|
if self.stepped:
|
|
|
|
|
await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
|
|
|
|
|
try:
|
|
|
|
|
await anext(self._outputs_stream)
|
|
|
|
|
except StopAsyncIteration:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
|
self.close()
|
|
|
|
|