diff --git a/langchain/callbacks/streaming_aiter.py b/langchain/callbacks/streaming_aiter.py index 45a4459b..47372dd9 100644 --- a/langchain/callbacks/streaming_aiter.py +++ b/langchain/callbacks/streaming_aiter.py @@ -47,14 +47,19 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler): while not self.queue.empty() or not self.done.is_set(): # Wait for the next token in the queue, # but stop waiting if the done event is set - done, _ = await asyncio.wait( + done, other = await asyncio.wait( [ + # NOTE: If you add other tasks here, update the code below, + # which assumes each set has exactly one task each asyncio.ensure_future(self.queue.get()), asyncio.ensure_future(self.done.wait()), ], return_when=asyncio.FIRST_COMPLETED, ) + # Cancel the other task + other.pop().cancel() + # Extract the value of the first completed task token_or_done = cast(Union[str, Literal[True]], done.pop().result())