fix reference

client
justheuristic 2 years ago
parent a4bdce32c1
commit 217f109723

@ -57,6 +57,7 @@ class RemoteTransformerBlockInferenceSession:
# using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.stepped = False
self.closed = False
@classmethod
@ -102,6 +103,7 @@ class RemoteTransformerBlockInferenceSession:
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await anext(self._outputs_stream)
def close(self):
@ -116,11 +118,12 @@ class RemoteTransformerBlockInferenceSession:
"""Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
if self._outputs_stream is None:
return # already closed
await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
try:
await anext(self._outputs_stream)
except StopAsyncIteration:
pass
if self.stepped:
await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
try:
await anext(self._outputs_stream)
except StopAsyncIteration:
pass
def __del__(self):
self.close()

@ -75,7 +75,6 @@ class RemoteSequential(nn.Sequential):
return RemoteSequentialInferenceSession(self.remote_sequence_info)
class RemoteSequentialInferenceSession:
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""

Loading…
Cancel
Save