From 6f39e88a2ca9659a740f417386f857ac5ee74201 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Sat, 8 Apr 2023 22:34:55 +0100
Subject: [PATCH] Add AsyncIteratorCallbackHandler (#2329)

---
 langchain/callbacks/__init__.py        |  4 ++
 langchain/callbacks/streaming_aiter.py | 66 ++++++++++++++++++++++++++
 2 files changed, 70 insertions(+)
 create mode 100644 langchain/callbacks/streaming_aiter.py

diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py
index 1bcf6d29..a3764c92 100644
--- a/langchain/callbacks/__init__.py
+++ b/langchain/callbacks/__init__.py
@@ -5,6 +5,7 @@ from typing import Generator, Optional
 
 from langchain.callbacks.aim_callback import AimCallbackHandler
 from langchain.callbacks.base import (
+    AsyncCallbackManager,
     BaseCallbackHandler,
     BaseCallbackManager,
     CallbackManager,
@@ -13,6 +14,7 @@ from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
 from langchain.callbacks.openai_info import OpenAICallbackHandler
 from langchain.callbacks.shared import SharedCallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler
+from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from langchain.callbacks.tracers import SharedLangChainTracer
 from langchain.callbacks.wandb_callback import WandbCallbackHandler
 
@@ -69,12 +71,14 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
 
 __all__ = [
     "CallbackManager",
+    "AsyncCallbackManager",
     "OpenAICallbackHandler",
     "SharedCallbackManager",
     "StdOutCallbackHandler",
     "AimCallbackHandler",
     "WandbCallbackHandler",
     "ClearMLCallbackHandler",
+    "AsyncIteratorCallbackHandler",
     "get_openai_callback",
     "set_tracing_callback_manager",
     "set_default_callback_manager",
diff --git a/langchain/callbacks/streaming_aiter.py b/langchain/callbacks/streaming_aiter.py
new file mode 100644
index 00000000..45a4459b
--- /dev/null
+++ b/langchain/callbacks/streaming_aiter.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
+
+from langchain.callbacks.base import AsyncCallbackHandler
+from langchain.schema import LLMResult
+
+# TODO If used by two LLM runs in parallel this won't work as expected
+
+
+class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
+    """Callback handler that returns an async iterator."""
+
+    queue: asyncio.Queue[str]
+
+    done: asyncio.Event
+
+    @property
+    def always_verbose(self) -> bool:
+        return True
+
+    def __init__(self) -> None:
+        self.queue = asyncio.Queue()
+        self.done = asyncio.Event()
+
+    async def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        # If two calls are made in a row, this resets the state
+        self.done.clear()
+
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        self.queue.put_nowait(token)
+
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        self.done.set()
+
+    async def on_llm_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        self.done.set()
+
+    # TODO implement the other methods
+
+    async def aiter(self) -> AsyncIterator[str]:
+        while not self.queue.empty() or not self.done.is_set():
+            # Wait for the next token in the queue,
+            # but stop waiting if the done event is set
+            done, _ = await asyncio.wait(
+                [
+                    asyncio.ensure_future(self.queue.get()),
+                    asyncio.ensure_future(self.done.wait()),
+                ],
+                return_when=asyncio.FIRST_COMPLETED,
+            )
+
+            # Extract the value of the first completed task
+            token_or_done = cast(Union[str, Literal[True]], done.pop().result())
+
+            # If the extracted value is the boolean True, the done event was set
+            if token_or_done is True:
+                break
+
+            # Otherwise, the extracted value is a token, which we yield
+            yield token_or_done