Add AsyncIteratorCallbackHandler (#2329)

fix-readthedocs
Nuno Campos 1 year ago committed by GitHub
parent 6e4e7d2637
commit 6f39e88a2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,6 +5,7 @@ from typing import Generator, Optional
from langchain.callbacks.aim_callback import AimCallbackHandler from langchain.callbacks.aim_callback import AimCallbackHandler
from langchain.callbacks.base import ( from langchain.callbacks.base import (
AsyncCallbackManager,
BaseCallbackHandler, BaseCallbackHandler,
BaseCallbackManager, BaseCallbackManager,
CallbackManager, CallbackManager,
@ -13,6 +14,7 @@ from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.shared import SharedCallbackManager from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.tracers import SharedLangChainTracer from langchain.callbacks.tracers import SharedLangChainTracer
from langchain.callbacks.wandb_callback import WandbCallbackHandler from langchain.callbacks.wandb_callback import WandbCallbackHandler
@ -69,12 +71,14 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
__all__ = [ __all__ = [
"CallbackManager", "CallbackManager",
"AsyncCallbackManager",
"OpenAICallbackHandler", "OpenAICallbackHandler",
"SharedCallbackManager", "SharedCallbackManager",
"StdOutCallbackHandler", "StdOutCallbackHandler",
"AimCallbackHandler", "AimCallbackHandler",
"WandbCallbackHandler", "WandbCallbackHandler",
"ClearMLCallbackHandler", "ClearMLCallbackHandler",
"AsyncIteratorCallbackHandler",
"get_openai_callback", "get_openai_callback",
"set_tracing_callback_manager", "set_tracing_callback_manager",
"set_default_callback_manager", "set_default_callback_manager",

@ -0,0 +1,66 @@
from __future__ import annotations
import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import LLMResult
# TODO If used by two LLM runs in parallel this won't work as expected
class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
"""Callback handler that returns an async iterator."""
queue: asyncio.Queue[str]
done: asyncio.Event
@property
def always_verbose(self) -> bool:
return True
def __init__(self) -> None:
self.queue = asyncio.Queue()
self.done = asyncio.Event()
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.queue.put_nowait(token)
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self.done.set()
async def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
self.done.set()
# TODO implement the other methods
async def aiter(self) -> AsyncIterator[str]:
while not self.queue.empty() or not self.done.is_set():
# Wait for the next token in the queue,
# but stop waiting if the done event is set
done, _ = await asyncio.wait(
[
asyncio.ensure_future(self.queue.get()),
asyncio.ensure_future(self.done.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)
# Extract the value of the first completed task
token_or_done = cast(Union[str, Literal[True]], done.pop().result())
# If the extracted value is the boolean True, the done event was set
if token_or_done is True:
break
# Otherwise, the extracted value is a token, which we yield
yield token_or_done
Loading…
Cancel
Save