From caa8e4742efaafd15b86e19589a7f289cd42b45d Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Tue, 14 Feb 2023 15:06:14 -0800 Subject: [PATCH] Enable streaming for OpenAI LLM (#986) * Support a callback `on_llm_new_token` that users can implement when `OpenAI.streaming` is set to `True` --- .../chat_vector_db.ipynb | 134 ++++++- docs/modules/llms/getting_started.ipynb | 12 +- docs/modules/llms/how_to_guides.rst | 1 + docs/modules/llms/streaming_llm.ipynb | 140 ++++++++ langchain/agents/agent.py | 157 ++++++--- langchain/callbacks/base.py | 327 +++++++++++++++++- langchain/callbacks/openai_info.py | 6 +- langchain/callbacks/shared.py | 5 + langchain/callbacks/stdout.py | 4 + langchain/callbacks/streaming_stdout.py | 60 ++++ langchain/callbacks/streamlit.py | 4 + langchain/callbacks/tracers/base.py | 4 + langchain/chains/base.py | 27 +- langchain/chains/chat_vector_db/base.py | 17 + langchain/chains/combine_documents/base.py | 14 + .../chains/combine_documents/map_reduce.py | 23 ++ .../chains/combine_documents/map_rerank.py | 23 +- langchain/chains/combine_documents/refine.py | 59 +++- langchain/chains/combine_documents/stuff.py | 8 + langchain/chains/llm.py | 40 ++- .../chains/question_answering/__init__.py | 56 ++- langchain/llms/base.py | 48 ++- langchain/llms/openai.py | 88 ++++- tests/integration_tests/llms/test_openai.py | 59 ++++ .../callbacks/fake_callback_handler.py | 86 ++++- .../callbacks/test_callback_manager.py | 56 ++- 26 files changed, 1307 insertions(+), 151 deletions(-) create mode 100644 docs/modules/llms/streaming_llm.ipynb create mode 100644 langchain/callbacks/streaming_stdout.py diff --git a/docs/modules/chains/combine_docs_examples/chat_vector_db.ipynb b/docs/modules/chains/combine_docs_examples/chat_vector_db.ipynb index a1494497..1953c51f 100644 --- a/docs/modules/chains/combine_docs_examples/chat_vector_db.ipynb +++ b/docs/modules/chains/combine_docs_examples/chat_vector_db.ipynb @@ -14,7 +14,9 @@ "cell_type": "code", "execution_count": 1, "id": "70c4e529", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "from langchain.embeddings.openai import OpenAIEmbeddings\n", @@ -36,7 +38,9 @@ "cell_type": "code", "execution_count": 2, "id": "01c46e92", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "from langchain.document_loaders import TextLoader\n", @@ -56,7 +60,9 @@ "cell_type": "code", "execution_count": 3, "id": "433363a5", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# loaders = [....]\n", @@ -75,9 +81,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "a8930cf7", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", @@ -106,9 +114,11 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "7b4110f3", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore)" @@ -126,7 +136,9 @@ "cell_type": "code", "execution_count": 6, "id": "7fe3e730", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "chat_history = []\n", @@ -136,9 +148,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "bfff9cc8", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "data": { @@ -146,7 +160,7 @@ "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\"" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -165,9 +179,11 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "00b4cf00", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "chat_history = [(query, result[\"answer\"])]\n", @@ -177,9 +193,11 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "f01828d1", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "data": { @@ -187,7 +205,7 @@ "' Justice Stephen Breyer'" ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -196,10 +214,90 @@ "result['answer']" ] }, + { + "cell_type": "markdown", + "id": "2324cdc6-98bf-4708-b8cd-02a98b1e5b67", + "metadata": {}, + "source": [ + "## Chat Vector DB with streaming to `stdout`\n", + "\n", + "Output from the chain will be streamed to `stdout` token by token in this example." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2efacec3-2690-4b05-8de3-a32fd2ac3911", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.chains.llm import LLMChain\n", + "from langchain.callbacks.base import CallbackManager\n", + "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", + "from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n", + "from langchain.chains.question_answering import load_qa_chain\n", + "\n", + "# Construct a ChatVectorDBChain with a streaming llm for combine docs\n", + "# and a separate, non-streaming llm for question generation\n", + "llm = OpenAI(temperature=0)\n", + "streaming_llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "\n", + "question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n", + "doc_chain = load_qa_chain(streaming_llm, chain_type=\"stuff\", prompt=QA_PROMPT)\n", + "\n", + "qa = ChatVectorDBChain(vectorstore=vectorstore, combine_docs_chain=doc_chain, question_generator=question_generator)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "fd6d43f4-7428-44a4-81bc-26fe88a98762", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans." + ] + } + ], + "source": [ + "chat_history = []\n", + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "result = qa({\"question\": query, \"chat_history\": chat_history})" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "5ab38978-f3e8-4fa7-808c-c79dec48379a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Justice Stephen Breyer" + ] + } + ], + "source": [ + "chat_history = [(query, result[\"answer\"])]\n", + "query = \"Did he mention who she suceeded\"\n", + "result = qa({\"question\": query, \"chat_history\": chat_history})" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "d0f869c6", + "id": "a7ea93ff-1899-4171-9c24-85df20ae1a3d", "metadata": {}, "outputs": [], "source": [] @@ -221,7 +319,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/docs/modules/llms/getting_started.ipynb b/docs/modules/llms/getting_started.ipynb index 043fb3ec..e686f30f 100644 --- a/docs/modules/llms/getting_started.ipynb +++ b/docs/modules/llms/getting_started.ipynb @@ -18,7 +18,9 @@ "cell_type": "code", "execution_count": 1, "id": "df924055", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "from langchain.llms import OpenAI" @@ -207,14 +209,6 @@ "source": [ "llm.get_num_tokens(\"what a joke\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b004ffdd", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/modules/llms/how_to_guides.rst b/docs/modules/llms/how_to_guides.rst index 07dcd609..6985519a 100644 --- a/docs/modules/llms/how_to_guides.rst +++ b/docs/modules/llms/how_to_guides.rst @@ -8,6 +8,7 @@ They are split into two categories: 1. `Generic Functionality <./generic_how_to.html>`_: Covering generic functionality all LLMs should have. 2. `Integrations <./integrations.html>`_: Covering integrations with various LLM providers. 3. `Asynchronous <./async_llm.html>`_: Covering asynchronous functionality. +4. `Streaming <./streaming_llm.html>`_: Covering streaming functionality. .. toctree:: :maxdepth: 1 diff --git a/docs/modules/llms/streaming_llm.ipynb b/docs/modules/llms/streaming_llm.ipynb new file mode 100644 index 00000000..f1b5f2c0 --- /dev/null +++ b/docs/modules/llms/streaming_llm.ipynb @@ -0,0 +1,140 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6eaf7e66-f49c-42da-8d11-22ea13bef718", + "metadata": {}, + "source": [ + "# Streaming with LLMs\n", + "\n", + "LangChain provides streaming support for LLMs. Currently, we only support streaming for the `OpenAI` LLM implementation, but streaming support for other LLM implementations is on the roadmap. To utilize streaming, use a [`CallbackHandler`](https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/base.py) that implements `on_llm_new_token`. In this example, we are using [`StreamingStdOutCallbackHandler`]()." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Verse 1\n", + "I'm sippin' on sparkling water,\n", + "It's so refreshing and light,\n", + "It's the perfect way to quench my thirst,\n", + "On a hot summer night.\n", + "\n", + "Chorus\n", + "Sparkling water, sparkling water,\n", + "It's the best way to stay hydrated,\n", + "It's so refreshing and light,\n", + "It's the perfect way to stay alive.\n", + "\n", + "Verse 2\n", + "I'm sippin' on sparkling water,\n", + "It's so bubbly and bright,\n", + "It's the perfect way to cool me down,\n", + "On a hot summer night.\n", + "\n", + "Chorus\n", + "Sparkling water, sparkling water,\n", + "It's the best way to stay hydrated,\n", + "It's so refreshing and light,\n", + "It's the perfect way to stay alive.\n", + "\n", + "Verse 3\n", + "I'm sippin' on sparkling water,\n", + "It's so crisp and clean,\n", + "It's the perfect way to keep me going,\n", + "On a hot summer day.\n", + "\n", + "Chorus\n", + "Sparkling water, sparkling water,\n", + "It's the best way to stay hydrated,\n", + "It's so refreshing and light,\n", + "It's the perfect way to stay alive." + ] + } + ], + "source": [ + "from langchain.llms import OpenAI\n", + "from langchain.callbacks.base import CallbackManager\n", + "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", + "\n", + "\n", + "llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "resp = llm(\"Write me a song about sparkling water.\")" + ] + }, + { + "cell_type": "markdown", + "id": "61fb6de7-c6c8-48d0-a48e-1204c027a23c", + "metadata": { + "tags": [] + }, + "source": [ + "We still have access to the end `LLMResult` if using `generate`. However, `token_usage` is not currently supported for streaming." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a35373f1-9ee6-4753-a343-5aee749b8527", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Q: What did the fish say when it hit the wall?\n", + "A: Dam!" + ] + }, + { + "data": { + "text/plain": [ + "LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!', generation_info={'finish_reason': 'stop', 'logprobs': None})]], llm_output={'token_usage': {}})" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm.generate([\"Tell me a joke.\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index da27db35..f75d04f7 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -375,6 +375,22 @@ class AgentExecutor(Chain, BaseModel): final_output["intermediate_steps"] = intermediate_steps return final_output + async def _areturn( + self, output: AgentFinish, intermediate_steps: list + ) -> Dict[str, Any]: + if self.callback_manager.is_async: + await self.callback_manager.on_agent_finish( + output, color="green", verbose=self.verbose + ) + else: + self.callback_manager.on_agent_finish( + output, color="green", verbose=self.verbose + ) + final_output = output.return_values + if self.return_intermediate_steps: + final_output["intermediate_steps"] = intermediate_steps + return final_output + def _take_next_step( self, name_to_tool_map: Dict[str, Tool], @@ -428,6 +444,90 @@ class AgentExecutor(Chain, BaseModel): return AgentFinish({self.agent.return_values[0]: observation}, "") return output, observation + async def _atake_next_step( + self, + name_to_tool_map: Dict[str, Tool], + color_mapping: Dict[str, str], + inputs: Dict[str, str], + intermediate_steps: List[Tuple[AgentAction, str]], + ) -> Union[AgentFinish, Tuple[AgentAction, str]]: + """Take a single step in the thought-action-observation loop. + + Override this to take control of how the agent makes and acts on choices. + """ + # Call the LLM to see what to do. + output = await self.agent.aplan(intermediate_steps, **inputs) + # If the tool chosen is the finishing tool, then we end and return. + if isinstance(output, AgentFinish): + return output + # Otherwise we lookup the tool + if output.tool in name_to_tool_map: + tool = name_to_tool_map[output.tool] + if self.callback_manager.is_async: + await self.callback_manager.on_tool_start( + {"name": str(tool.func)[:60] + "..."}, + output, + verbose=self.verbose, + ) + else: + self.callback_manager.on_tool_start( + {"name": str(tool.func)[:60] + "..."}, + output, + verbose=self.verbose, + ) + try: + # We then call the tool on the tool input to get an observation + observation = ( + await tool.coroutine(output.tool_input) + if tool.coroutine + # If the tool is not a coroutine, we run it in the executor + # to avoid blocking the event loop. + else await asyncio.get_event_loop().run_in_executor( + None, tool.func, output.tool_input + ) + ) + color = color_mapping[output.tool] + return_direct = tool.return_direct + except (KeyboardInterrupt, Exception) as e: + if self.callback_manager.is_async: + await self.callback_manager.on_tool_error(e, verbose=self.verbose) + else: + self.callback_manager.on_tool_error(e, verbose=self.verbose) + raise e + else: + if self.callback_manager.is_async: + await self.callback_manager.on_tool_start( + {"name": "N/A"}, output, verbose=self.verbose + ) + else: + self.callback_manager.on_tool_start( + {"name": "N/A"}, output, verbose=self.verbose + ) + observation = f"{output.tool} is not a valid tool, try another one." + color = None + return_direct = False + llm_prefix = "" if return_direct else self.agent.llm_prefix + if self.callback_manager.is_async: + await self.callback_manager.on_tool_end( + observation, + color=color, + observation_prefix=self.agent.observation_prefix, + llm_prefix=llm_prefix, + verbose=self.verbose, + ) + else: + self.callback_manager.on_tool_end( + observation, + color=color, + observation_prefix=self.agent.observation_prefix, + llm_prefix=llm_prefix, + verbose=self.verbose, + ) + if return_direct: + # Set the log to "" because we do not want to log it. + return AgentFinish({self.agent.return_values[0]: observation}, "") + return output, observation + def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: """Run text through and get agent response.""" # Make sure that every tool is synchronous (not a coroutine) @@ -486,58 +586,15 @@ class AgentExecutor(Chain, BaseModel): iterations = 0 # We now enter the agent loop (until it returns something). while self._should_continue(iterations): - # Call the LLM to see what to do. - output = await self.agent.aplan(intermediate_steps, **inputs) - # If the tool chosen is the finishing tool, then we end and return. - if isinstance(output, AgentFinish): - return self._return(output, intermediate_steps) - - # Otherwise we lookup the tool - if output.tool in name_to_tool_map: - tool = name_to_tool_map[output.tool] - self.callback_manager.on_tool_start( - {"name": str(tool.func)[:60] + "..."}, - output, - verbose=self.verbose, - ) - try: - # We then call the tool on the tool input to get an observation - observation = ( - await tool.coroutine(output.tool_input) - if tool.coroutine - # If the tool is not a coroutine, we run it in the executor - # to avoid blocking the event loop. - else await asyncio.get_event_loop().run_in_executor( - None, tool.func, output.tool_input - ) - ) - color = color_mapping[output.tool] - return_direct = tool.return_direct - except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_tool_error(e, verbose=self.verbose) - raise e - else: - self.callback_manager.on_tool_start( - {"name": "N/A"}, output, verbose=self.verbose - ) - observation = f"{output.tool} is not a valid tool, try another one." - color = None - return_direct = False - llm_prefix = "" if return_direct else self.agent.llm_prefix - self.callback_manager.on_tool_end( - observation, - color=color, - observation_prefix=self.agent.observation_prefix, - llm_prefix=llm_prefix, - verbose=self.verbose, + next_step_output = await self._atake_next_step( + name_to_tool_map, color_mapping, inputs, intermediate_steps ) - intermediate_steps.append((output, observation)) - if return_direct: - # Set the log to "" because we do not want to log it. - output = AgentFinish({self.agent.return_values[0]: observation}, "") - return self._return(output, intermediate_steps) + if isinstance(next_step_output, AgentFinish): + return await self._areturn(next_step_output, intermediate_steps) + + intermediate_steps.append(next_step_output) iterations += 1 output = self.agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs ) - return self._return(output, intermediate_steps) + return await self._areturn(output, intermediate_steps) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index aeb409f9..3c4df605 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -1,5 +1,6 @@ """Base callback handler that can be used to handle callbacks from langchain.""" - +import asyncio +import functools from abc import ABC, abstractmethod from typing import Any, Dict, List, Union @@ -32,63 +33,72 @@ class BaseCallbackHandler(ABC): @abstractmethod def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: + ) -> Any: """Run when LLM starts running.""" @abstractmethod - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + def on_llm_new_token(self, token: str, **kwargs: Any) -> Any: + """Run on new LLM token. Only available when streaming is enabled.""" + + @abstractmethod + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any: """Run when LLM ends running.""" @abstractmethod def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + ) -> Any: """Run when LLM errors.""" @abstractmethod def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: + ) -> Any: """Run when chain starts running.""" @abstractmethod - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: """Run when chain ends running.""" @abstractmethod def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + ) -> Any: """Run when chain errors.""" @abstractmethod def on_tool_start( self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any - ) -> None: + ) -> Any: """Run when tool starts running.""" @abstractmethod - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: str, **kwargs: Any) -> Any: """Run when tool ends running.""" @abstractmethod def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + ) -> Any: """Run when tool errors.""" @abstractmethod - def on_text(self, text: str, **kwargs: Any) -> None: + def on_text(self, text: str, **kwargs: Any) -> Any: """Run on arbitrary text.""" @abstractmethod - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: """Run on agent end.""" class BaseCallbackManager(BaseCallbackHandler, ABC): """Base callback manager that can be used to handle callbacks from LangChain.""" + @property + def is_async(self) -> bool: + """Whether the callback manager is async.""" + return False + @abstractmethod def add_handler(self, callback: BaseCallbackHandler) -> None: """Add a handler to the callback manager.""" @@ -126,6 +136,15 @@ class CallbackManager(BaseCallbackManager): if verbose or handler.always_verbose: handler.on_llm_start(serialized, prompts, **kwargs) + def on_llm_new_token( + self, token: str, verbose: bool = False, **kwargs: Any + ) -> None: + """Run when LLM generates a new token.""" + for handler in self.handlers: + if not handler.ignore_llm: + if verbose or handler.always_verbose: + handler.on_llm_new_token(token, **kwargs) + def on_llm_end( self, response: LLMResult, verbose: bool = False, **kwargs: Any ) -> None: @@ -239,3 +258,287 @@ class CallbackManager(BaseCallbackManager): def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: """Set handlers as the only handlers on the callback manager.""" self.handlers = handlers + + +class AsyncCallbackHandler(BaseCallbackHandler): + """Async callback handler that can be used to handle callbacks from langchain.""" + + async def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + + async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + + async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + + async def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when LLM errors.""" + + async def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + + async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + + async def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when chain errors.""" + + async def on_tool_start( + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + + async def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + + async def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when tool errors.""" + + async def on_text(self, text: str, **kwargs: Any) -> None: + """Run on arbitrary text.""" + + async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" + + +class AsyncCallbackManager(BaseCallbackManager): + """Async callback manager that can be used to handle callbacks from LangChain.""" + + @property + def is_async(self) -> bool: + """Return whether the handler is async.""" + return True + + def __init__(self, handlers: List[BaseCallbackHandler]) -> None: + """Initialize callback manager.""" + self.handlers: List[BaseCallbackHandler] = handlers + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + verbose: bool = False, + **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + for handler in self.handlers: + if not handler.ignore_llm: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_llm_start): + await handler.on_llm_start(serialized, prompts, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial( + handler.on_llm_start, serialized, prompts, **kwargs + ), + ) + + async def on_llm_new_token( + self, token: str, verbose: bool = False, **kwargs: Any + ) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + for handler in self.handlers: + if not handler.ignore_llm: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_llm_new_token): + await handler.on_llm_new_token(token, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial( + handler.on_llm_new_token, token, **kwargs + ), + ) + + async def on_llm_end( + self, response: LLMResult, verbose: bool = False, **kwargs: Any + ) -> None: + """Run when LLM ends running.""" + for handler in self.handlers: + if not handler.ignore_llm: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_llm_end): + await handler.on_llm_end(response, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial(handler.on_llm_end, response, **kwargs), + ) + + async def on_llm_error( + self, + error: Union[Exception, KeyboardInterrupt], + verbose: bool = False, + **kwargs: Any + ) -> None: + """Run when LLM errors.""" + for handler in self.handlers: + if not handler.ignore_llm: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_llm_error): + await handler.on_llm_error(error, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial(handler.on_llm_error, error, **kwargs), + ) + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + verbose: bool = False, + **kwargs: Any + ) -> None: + """Run when chain starts running.""" + for handler in self.handlers: + if not handler.ignore_chain: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_chain_start): + await handler.on_chain_start(serialized, inputs, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial( + handler.on_chain_start, serialized, inputs, **kwargs + ), + ) + + async def on_chain_end( + self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any + ) -> None: + """Run when chain ends running.""" + for handler in self.handlers: + if not handler.ignore_chain: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_chain_end): + await handler.on_chain_end(outputs, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial(handler.on_chain_end, outputs, **kwargs), + ) + + async def on_chain_error( + self, + error: Union[Exception, KeyboardInterrupt], + verbose: bool = False, + **kwargs: Any + ) -> None: + """Run when chain errors.""" + for handler in self.handlers: + if not handler.ignore_chain: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_chain_error): + await handler.on_chain_error(error, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial(handler.on_chain_error, error, **kwargs), + ) + + async def on_tool_start( + self, + serialized: Dict[str, Any], + action: AgentAction, + verbose: bool = False, + **kwargs: Any + ) -> None: + """Run when tool starts running.""" + for handler in self.handlers: + if not handler.ignore_agent: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_tool_start): + await handler.on_tool_start(serialized, action, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial( + handler.on_tool_start, serialized, action, **kwargs + ), + ) + + async def on_tool_end( + self, output: str, verbose: bool = False, **kwargs: Any + ) -> None: + """Run when tool ends running.""" + for handler in self.handlers: + if not handler.ignore_agent: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_tool_end): + await handler.on_tool_end(output, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial(handler.on_tool_end, output, **kwargs), + ) + + async def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + verbose: bool = False, + **kwargs: Any + ) -> None: + """Run when tool errors.""" + for handler in self.handlers: + if not handler.ignore_agent: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_tool_error): + await handler.on_tool_error(error, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial(handler.on_tool_error, error, **kwargs), + ) + + async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None: + """Run when text is printed.""" + for handler in self.handlers: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_text): + await handler.on_text(text, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, functools.partial(handler.on_text, text, **kwargs) + ) + + async def on_agent_finish( + self, finish: AgentFinish, verbose: bool = False, **kwargs: Any + ) -> None: + """Run when agent finishes.""" + for handler in self.handlers: + if not handler.ignore_agent: + if verbose or handler.always_verbose: + if asyncio.iscoroutinefunction(handler.on_agent_finish): + await handler.on_agent_finish(finish, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, + functools.partial( + handler.on_agent_finish, finish, **kwargs + ), + ) + + def add_handler(self, handler: BaseCallbackHandler) -> None: + """Add a handler to the callback manager.""" + self.handlers.append(handler) + + def remove_handler(self, handler: BaseCallbackHandler) -> None: + """Remove a handler from the callback manager.""" + self.handlers.remove(handler) + + def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: + """Set handlers as the only handlers on the callback manager.""" + self.handlers = handlers diff --git a/langchain/callbacks/openai_info.py b/langchain/callbacks/openai_info.py index 32b26591..43eeff28 100644 --- a/langchain/callbacks/openai_info.py +++ b/langchain/callbacks/openai_info.py @@ -21,8 +21,12 @@ class OpenAICallbackHandler(BaseCallbackHandler): """Print out the prompts.""" pass + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Print out the token.""" + pass + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" + """Collect token usage.""" if response.llm_output is not None: if "token_usage" in response.llm_output: token_usage = response.llm_output["token_usage"] diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py index 4b0772cd..d5dc311b 100644 --- a/langchain/callbacks/shared.py +++ b/langchain/callbacks/shared.py @@ -46,6 +46,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): with self._lock: self._callback_manager.on_llm_end(response, **kwargs) + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + with self._lock: + self._callback_manager.on_llm_new_token(token, **kwargs) + def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py index 2b0b860d..238704b1 100644 --- a/langchain/callbacks/stdout.py +++ b/langchain/callbacks/stdout.py @@ -23,6 +23,10 @@ class StdOutCallbackHandler(BaseCallbackHandler): """Do nothing.""" pass + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: diff --git a/langchain/callbacks/streaming_stdout.py b/langchain/callbacks/streaming_stdout.py new file mode 100644 index 00000000..988a44fb --- /dev/null +++ b/langchain/callbacks/streaming_stdout.py @@ -0,0 +1,60 @@ +"""Callback Handler streams to stdout on new llm token.""" +import sys +from typing import Any, Dict, List, Union + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, LLMResult + + +class StreamingStdOutCallbackHandler(BaseCallbackHandler): + """Callback handler for streaming. Only works with LLMs that support streaming.""" + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + sys.stdout.write(token) + sys.stdout.flush() + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when LLM errors.""" + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when chain errors.""" + + def on_tool_start( + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when tool errors.""" + + def on_text(self, text: str, **kwargs: Any) -> None: + """Run on arbitrary text.""" + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" diff --git a/langchain/callbacks/streamlit.py b/langchain/callbacks/streamlit.py index 1451aac7..8d606e7c 100644 --- a/langchain/callbacks/streamlit.py +++ b/langchain/callbacks/streamlit.py @@ -18,6 +18,10 @@ class StreamlitCallbackHandler(BaseCallbackHandler): for prompt in prompts: st.write(prompt) + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Do nothing.""" pass diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 76b11639..6fdae5fe 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -129,6 +129,10 @@ class BaseTracer(BaseCallbackHandler, ABC): ) self._start_trace(llm_run) + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Handle a new token for an LLM run.""" + pass + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """End a trace for an LLM run.""" if not self._stack or not isinstance(self._stack[-1], LLMRun): diff --git a/langchain/chains/base.py b/langchain/chains/base.py index bd535ef5..ef72bf62 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -158,17 +158,30 @@ class Chain(BaseModel, ABC): """ inputs = self.prep_inputs(inputs) - self.callback_manager.on_chain_start( - {"name": self.__class__.__name__}, - inputs, - verbose=self.verbose, - ) + if self.callback_manager.is_async: + await self.callback_manager.on_chain_start( + {"name": self.__class__.__name__}, + inputs, + verbose=self.verbose, + ) + else: + self.callback_manager.on_chain_start( + {"name": self.__class__.__name__}, + inputs, + verbose=self.verbose, + ) try: outputs = await self._acall(inputs) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_chain_error(e, verbose=self.verbose) + if self.callback_manager.is_async: + await self.callback_manager.on_chain_error(e, verbose=self.verbose) + else: + self.callback_manager.on_chain_error(e, verbose=self.verbose) raise e - self.callback_manager.on_chain_end(outputs, verbose=self.verbose) + if self.callback_manager.is_async: + await self.callback_manager.on_chain_end(outputs, verbose=self.verbose) + else: + self.callback_manager.on_chain_end(outputs, verbose=self.verbose) return self.prep_outputs(inputs, outputs, return_only_outputs) def prep_outputs( diff --git a/langchain/chains/chat_vector_db/base.py b/langchain/chains/chat_vector_db/base.py index abb2c067..3548eb0e 100644 --- a/langchain/chains/chat_vector_db/base.py +++ b/langchain/chains/chat_vector_db/base.py @@ -83,3 +83,20 @@ class ChatVectorDBChain(Chain, BaseModel): new_inputs["chat_history"] = chat_history_str answer, _ = self.combine_docs_chain.combine_docs(docs, **new_inputs) return {self.output_key: answer} + + async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: + question = inputs["question"] + chat_history_str = _get_chat_history(inputs["chat_history"]) + if chat_history_str: + new_question = await self.question_generator.arun( + question=question, chat_history=chat_history_str + ) + else: + new_question = question + # TODO: This blocks the event loop, but it's not clear how to avoid it. + docs = self.vectorstore.similarity_search(new_question, k=4) + new_inputs = inputs.copy() + new_inputs["question"] = new_question + new_inputs["chat_history"] = chat_history_str + answer, _ = await self.combine_docs_chain.acombine_docs(docs, **new_inputs) + return {self.output_key: answer} diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py index 40684e5b..dde16d45 100644 --- a/langchain/chains/combine_documents/base.py +++ b/langchain/chains/combine_documents/base.py @@ -43,6 +43,12 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC): def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: """Combine documents into a single string.""" + @abstractmethod + async def acombine_docs( + self, docs: List[Document], **kwargs: Any + ) -> Tuple[str, dict]: + """Combine documents into a single string asynchronously.""" + def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: docs = inputs[self.input_key] # Other keys are assumed to be needed for LLM prediction @@ -51,6 +57,14 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC): extra_return_dict[self.output_key] = output return extra_return_dict + async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: + docs = inputs[self.input_key] + # Other keys are assumed to be needed for LLM prediction + other_keys = {k: v for k, v in inputs.items() if k != self.input_key} + output, extra_return_dict = await self.acombine_docs(docs, **other_keys) + extra_return_dict[self.output_key] = output + return extra_return_dict + class AnalyzeDocumentChain(Chain, BaseModel): """Chain that splits documents, then analyzes it in pieces.""" diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index 1e4fade7..9f6d4678 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -140,6 +140,29 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): # FYI - this is parallelized and so it is fast. [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] ) + return self._process_results(results, docs, token_max, **kwargs) + + async def acombine_docs( + self, docs: List[Document], **kwargs: Any + ) -> Tuple[str, dict]: + """Combine documents in a map reduce manner. + + Combine by mapping first chain over all documents, then reducing the results. + This reducing can be done recursively if needed (if there are many documents). + """ + results = await self.llm_chain.aapply( + # FYI - this is parallelized and so it is fast. + [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + ) + return self._process_results(results, docs, **kwargs) + + def _process_results( + self, + results: List[Dict], + docs: List[Document], + token_max: int = 3000, + **kwargs: Any, + ) -> Tuple[str, dict]: question_result_key = self.llm_chain.output_key result_docs = [ Document(page_content=r[question_result_key], metadata=docs[i].metadata) diff --git a/langchain/chains/combine_documents/map_rerank.py b/langchain/chains/combine_documents/map_rerank.py index f97fc032..71855650 100644 --- a/langchain/chains/combine_documents/map_rerank.py +++ b/langchain/chains/combine_documents/map_rerank.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from pydantic import BaseModel, Extra, root_validator @@ -98,8 +98,27 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel): # FYI - this is parallelized and so it is fast. [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] ) - typed_results = cast(List[dict], results) + return self._process_results(docs, results) + + async def acombine_docs( + self, docs: List[Document], **kwargs: Any + ) -> Tuple[str, dict]: + """Combine documents in a map rerank manner. + + Combine by mapping first chain over all documents, then reranking the results. + """ + results = await self.llm_chain.aapply_and_parse( + # FYI - this is parallelized and so it is fast. + [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + ) + return self._process_results(docs, results) + def _process_results( + self, + docs: List[Document], + results: Sequence[Union[str, List[str], Dict[str, str]]], + ) -> Tuple[str, dict]: + typed_results = cast(List[dict], results) sorted_res = sorted( zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key]) ) diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py index 57d3d025..e20ab147 100644 --- a/langchain/chains/combine_documents/refine.py +++ b/langchain/chains/combine_documents/refine.py @@ -84,36 +84,59 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel): def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: """Combine by mapping first chain over all, then stuffing into final chain.""" - base_info = {"page_content": docs[0].page_content} - base_info.update(docs[0].metadata) - document_info = {k: base_info[k] for k in self.document_prompt.input_variables} - base_inputs: dict = { - self.document_variable_name: self.document_prompt.format(**document_info) - } - inputs = {**base_inputs, **kwargs} + inputs = self._construct_initial_inputs(docs, **kwargs) res = self.initial_llm_chain.predict(**inputs) refine_steps = [res] for doc in docs[1:]: - base_info = {"page_content": doc.page_content} - base_info.update(doc.metadata) - document_info = { - k: base_info[k] for k in self.document_prompt.input_variables - } - base_inputs = { - self.document_variable_name: self.document_prompt.format( - **document_info - ), - self.initial_response_name: res, - } + base_inputs = self._construct_refine_inputs(doc, res) inputs = {**base_inputs, **kwargs} res = self.refine_llm_chain.predict(**inputs) refine_steps.append(res) + return self._construct_result(refine_steps, res) + + async def acombine_docs( + self, docs: List[Document], **kwargs: Any + ) -> Tuple[str, dict]: + """Combine by mapping first chain over all, then stuffing into final chain.""" + inputs = self._construct_initial_inputs(docs, **kwargs) + res = await self.initial_llm_chain.apredict(**inputs) + refine_steps = [res] + for doc in docs[1:]: + base_inputs = self._construct_refine_inputs(doc, res) + inputs = {**base_inputs, **kwargs} + res = await self.refine_llm_chain.apredict(**inputs) + refine_steps.append(res) + return self._construct_result(refine_steps, res) + + def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]: if self.return_intermediate_steps: extra_return_dict = {"intermediate_steps": refine_steps} else: extra_return_dict = {} return res, extra_return_dict + def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]: + base_info = {"page_content": doc.page_content} + base_info.update(doc.metadata) + document_info = {k: base_info[k] for k in self.document_prompt.input_variables} + base_inputs = { + self.document_variable_name: self.document_prompt.format(**document_info), + self.initial_response_name: res, + } + return base_inputs + + def _construct_initial_inputs( + self, docs: List[Document], **kwargs: Any + ) -> Dict[str, Any]: + base_info = {"page_content": docs[0].page_content} + base_info.update(docs[0].metadata) + document_info = {k: base_info[k] for k in self.document_prompt.input_variables} + base_inputs: dict = { + self.document_variable_name: self.document_prompt.format(**document_info) + } + inputs = {**base_inputs, **kwargs} + return inputs + @property def _chain_type(self) -> str: return "refine_documents_chain" diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index d5cfb993..eee1d886 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -84,6 +84,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel): # Call predict on the LLM. return self.llm_chain.predict(**inputs), {} + async def acombine_docs( + self, docs: List[Document], **kwargs: Any + ) -> Tuple[str, dict]: + """Stuff all documents into one prompt and pass to LLM.""" + inputs = self._get_inputs(docs, **kwargs) + # Call predict on the LLM. + return await self.llm_chain.apredict(**inputs), {} + @property def _chain_type(self) -> str: return "stuff_documents_chain" diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 18eee49f..790782e7 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -61,7 +61,7 @@ class LLMChain(Chain, BaseModel): async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult: """Generate LLM result from inputs.""" - prompts, stop = self.prep_prompts(input_list) + prompts, stop = await self.aprep_prompts(input_list) response = await self.llm.agenerate(prompts, stop=stop) return response @@ -86,6 +86,32 @@ class LLMChain(Chain, BaseModel): prompts.append(prompt) return prompts, stop + async def aprep_prompts( + self, input_list: List[Dict[str, Any]] + ) -> Tuple[List[str], Optional[List[str]]]: + """Prepare prompts from inputs.""" + stop = None + if "stop" in input_list[0]: + stop = input_list[0]["stop"] + prompts = [] + for inputs in input_list: + selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} + prompt = self.prompt.format(**selected_inputs) + _colored_text = get_colored_text(prompt, "green") + _text = "Prompt after formatting:\n" + _colored_text + if self.callback_manager.is_async: + await self.callback_manager.on_text( + _text, end="\n", verbose=self.verbose + ) + else: + self.callback_manager.on_text(_text, end="\n", verbose=self.verbose) + if "stop" in inputs and inputs["stop"] != stop: + raise ValueError( + "If `stop` is present in any inputs, should be present in all." + ) + prompts.append(prompt) + return prompts, stop + def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" response = self.generate(input_list) @@ -156,6 +182,11 @@ class LLMChain(Chain, BaseModel): ) -> Sequence[Union[str, List[str], Dict[str, str]]]: """Call apply and then parse the results.""" result = self.apply(input_list) + return self._parse_result(result) + + def _parse_result( + self, result: List[Dict[str, str]] + ) -> Sequence[Union[str, List[str], Dict[str, str]]]: if self.prompt.output_parser is not None: new_result = [] for res in result: @@ -165,6 +196,13 @@ class LLMChain(Chain, BaseModel): else: return result + async def aapply_and_parse( + self, input_list: List[Dict[str, Any]] + ) -> Sequence[Union[str, List[str], Dict[str, str]]]: + """Call apply and then parse the results.""" + result = await self.aapply(input_list) + return self._parse_result(result) + @property def _chain_type(self) -> str: return "llm_chain" diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index a2b64445..101a6156 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -1,6 +1,7 @@ """Load question answering chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain.callbacks.base import BaseCallbackManager from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain @@ -31,14 +32,19 @@ def _load_map_rerank_chain( document_variable_name: str = "context", rank_key: str = "score", answer_key: str = "answer", + callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, ) -> MapRerankDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) + llm_chain = LLMChain( + llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager + ) return MapRerankDocumentsChain( llm_chain=llm_chain, rank_key=rank_key, answer_key=answer_key, document_variable_name=document_variable_name, + verbose=verbose, + callback_manager=callback_manager, **kwargs, ) @@ -48,14 +54,18 @@ def _load_stuff_chain( prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "context", verbose: Optional[bool] = None, + callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, ) -> StuffDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) + llm_chain = LLMChain( + llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager + ) # TODO: document prompt return StuffDocumentsChain( llm_chain=llm_chain, document_variable_name=document_variable_name, verbose=verbose, + callback_manager=callback_manager, **kwargs, ) @@ -70,16 +80,28 @@ def _load_map_reduce_chain( reduce_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None, verbose: Optional[bool] = None, + callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, ) -> MapReduceDocumentsChain: - map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) + map_chain = LLMChain( + llm=llm, + prompt=question_prompt, + verbose=verbose, + callback_manager=callback_manager, + ) _reduce_llm = reduce_llm or llm - reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) + reduce_chain = LLMChain( + llm=_reduce_llm, + prompt=combine_prompt, + verbose=verbose, + callback_manager=callback_manager, + ) # TODO: document prompt combine_document_chain = StuffDocumentsChain( llm_chain=reduce_chain, document_variable_name=combine_document_variable_name, verbose=verbose, + callback_manager=callback_manager, ) if collapse_prompt is None: collapse_chain = None @@ -95,8 +117,11 @@ def _load_map_reduce_chain( llm=_collapse_llm, prompt=collapse_prompt, verbose=verbose, + callback_manager=callback_manager, ), document_variable_name=combine_document_variable_name, + verbose=verbose, + callback_manager=callback_manager, ) return MapReduceDocumentsChain( llm_chain=map_chain, @@ -104,6 +129,7 @@ def _load_map_reduce_chain( document_variable_name=map_reduce_document_variable_name, collapse_document_chain=collapse_chain, verbose=verbose, + callback_manager=callback_manager, **kwargs, ) @@ -116,17 +142,29 @@ def _load_refine_chain( initial_response_name: str = "existing_answer", refine_llm: Optional[BaseLLM] = None, verbose: Optional[bool] = None, + callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, ) -> RefineDocumentsChain: - initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) + initial_chain = LLMChain( + llm=llm, + prompt=question_prompt, + verbose=verbose, + callback_manager=callback_manager, + ) _refine_llm = refine_llm or llm - refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) + refine_chain = LLMChain( + llm=_refine_llm, + prompt=refine_prompt, + verbose=verbose, + callback_manager=callback_manager, + ) return RefineDocumentsChain( initial_llm_chain=initial_chain, refine_llm_chain=refine_chain, document_variable_name=document_variable_name, initial_response_name=initial_response_name, verbose=verbose, + callback_manager=callback_manager, **kwargs, ) @@ -135,6 +173,7 @@ def load_qa_chain( llm: BaseLLM, chain_type: str = "stuff", verbose: Optional[bool] = None, + callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, ) -> BaseCombineDocumentsChain: """Load question answering chain. @@ -145,6 +184,7 @@ def load_qa_chain( "map_reduce", and "refine". verbose: Whether chains should be run in verbose mode or not. Note that this applies to all chains that make up the final chain. + callback_manager: Callback manager to use for the chain. Returns: A chain to use for question answering. @@ -160,4 +200,6 @@ def load_qa_chain( f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}" ) - return loader_mapping[chain_type](llm, verbose=verbose, **kwargs) + return loader_mapping[chain_type]( + llm, verbose=verbose, callback_manager=callback_manager, **kwargs + ) diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 954ee3c1..43db1556 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -165,15 +165,26 @@ class BaseLLM(BaseModel, ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompts, verbose=self.verbose - ) + if self.callback_manager.is_async: + await self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, prompts, verbose=self.verbose + ) + else: + self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, prompts, verbose=self.verbose + ) try: output = await self._agenerate(prompts, stop=stop) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + if self.callback_manager.is_async: + await self.callback_manager.on_llm_error(e, verbose=self.verbose) + else: + self.callback_manager.on_llm_error(e, verbose=self.verbose) raise e - self.callback_manager.on_llm_end(output, verbose=self.verbose) + if self.callback_manager.is_async: + await self.callback_manager.on_llm_end(output, verbose=self.verbose) + else: + self.callback_manager.on_llm_end(output, verbose=self.verbose) return output params = self.dict() params["stop"] = stop @@ -184,15 +195,32 @@ class BaseLLM(BaseModel, ABC): missing_prompts, ) = get_prompts(params, prompts) if len(missing_prompts) > 0: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose - ) + if self.callback_manager.is_async: + await self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, + missing_prompts, + verbose=self.verbose, + ) + else: + self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, + missing_prompts, + verbose=self.verbose, + ) try: new_results = await self._agenerate(missing_prompts, stop=stop) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + if self.callback_manager.is_async: + await self.callback_manager.on_llm_error(e, verbose=self.verbose) + else: + self.callback_manager.on_llm_error(e, verbose=self.verbose) raise e - self.callback_manager.on_llm_end(new_results, verbose=self.verbose) + if self.callback_manager.is_async: + await self.callback_manager.on_llm_end( + new_results, verbose=self.verbose + ) + else: + self.callback_manager.on_llm_end(new_results, verbose=self.verbose) llm_output = update_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 6bb73122..12728b88 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -42,6 +42,27 @@ def update_token_usage( token_usage[_key] += response["usage"][_key] +def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None: + """Update response from the stream response.""" + response["choices"][0]["text"] += stream_response["choices"][0]["text"] + response["choices"][0]["finish_reason"] = stream_response["choices"][0][ + "finish_reason" + ] + response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"] + + +def _streaming_response_template() -> Dict[str, Any]: + return { + "choices": [ + { + "text": "", + "finish_reason": None, + "logprobs": None, + } + ] + } + + class BaseOpenAI(BaseLLM, BaseModel): """Wrapper around OpenAI large language models. @@ -88,6 +109,8 @@ class BaseOpenAI(BaseLLM, BaseModel): """Adjust the probability of specific tokens being generated.""" max_retries: int = 6 """Maximum number of retries to make when generating.""" + streaming: bool = False + """Whether to stream the results or not.""" class Config: """Configuration for this pydantic object.""" @@ -129,6 +152,10 @@ class BaseOpenAI(BaseLLM, BaseModel): "Could not import openai python package. " "Please it install it with `pip install openai`." ) + if values["streaming"] and values["n"] > 1: + raise ValueError("Cannot stream results when n > 1.") + if values["streaming"] and values["best_of"] > 1: + raise ValueError("Cannot stream results when best_of > 1.") return values @property @@ -215,9 +242,25 @@ class BaseOpenAI(BaseLLM, BaseModel): # Includes prompt, completion, and total tokens used. _keys = {"completion_tokens", "prompt_tokens", "total_tokens"} for _prompts in sub_prompts: - response = self.completion_with_retry(prompt=_prompts, **params) - choices.extend(response["choices"]) - update_token_usage(_keys, response, token_usage) + if self.streaming: + if len(_prompts) > 1: + raise ValueError("Cannot stream results with multiple prompts.") + params["stream"] = True + response = _streaming_response_template() + for stream_resp in self.completion_with_retry( + prompt=_prompts, **params + ): + self.callback_manager.on_llm_new_token( + stream_resp["choices"][0]["text"], verbose=self.verbose + ) + _update_response(response, stream_resp) + choices.extend(response["choices"]) + else: + response = self.completion_with_retry(prompt=_prompts, **params) + choices.extend(response["choices"]) + if not self.streaming: + # Can't update token usage if streaming + update_token_usage(_keys, response, token_usage) return self.create_llm_result(choices, prompts, token_usage) async def _agenerate( @@ -232,10 +275,30 @@ class BaseOpenAI(BaseLLM, BaseModel): # Includes prompt, completion, and total tokens used. _keys = {"completion_tokens", "prompt_tokens", "total_tokens"} for _prompts in sub_prompts: - # Use OpenAI's async api https://github.com/openai/openai-python#async-api - response = await self.acompletion_with_retry(prompt=_prompts, **params) - choices.extend(response["choices"]) - update_token_usage(_keys, response, token_usage) + if self.streaming: + if len(_prompts) > 1: + raise ValueError("Cannot stream results with multiple prompts.") + params["stream"] = True + response = _streaming_response_template() + async for stream_resp in await self.acompletion_with_retry( + prompt=_prompts, **params + ): + if self.callback_manager.is_async: + await self.callback_manager.on_llm_new_token( + stream_resp["choices"][0]["text"], verbose=self.verbose + ) + else: + self.callback_manager.on_llm_new_token( + stream_resp["choices"][0]["text"], verbose=self.verbose + ) + _update_response(response, stream_resp) + choices.extend(response["choices"]) + else: + response = await self.acompletion_with_retry(prompt=_prompts, **params) + choices.extend(response["choices"]) + if not self.streaming: + # Can't update token usage if streaming + update_token_usage(_keys, response, token_usage) return self.create_llm_result(choices, prompts, token_usage) def get_sub_prompts( @@ -304,6 +367,13 @@ class BaseOpenAI(BaseLLM, BaseModel): for token in generator: yield token """ + params = self.prep_streaming_params(stop) + generator = self.client.create(prompt=prompt, **params) + + return generator + + def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]: + """Prepare the params for streaming.""" params = self._invocation_params if params["best_of"] != 1: raise ValueError("OpenAI only supports best_of == 1 for streaming") @@ -312,9 +382,7 @@ class BaseOpenAI(BaseLLM, BaseModel): raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop params["stream"] = True - generator = self.client.create(prompt=prompt, **params) - - return generator + return params @property def _invocation_params(self) -> Dict[str, Any]: diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py index 232e1dd9..d0bc9f9d 100644 --- a/tests/integration_tests/llms/test_openai.py +++ b/tests/integration_tests/llms/test_openai.py @@ -5,9 +5,11 @@ from typing import Generator import pytest +from langchain.callbacks.base import CallbackManager from langchain.llms.loading import load_llm from langchain.llms.openai import OpenAI from langchain.schema import LLMResult +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler def test_openai_call() -> None: @@ -77,9 +79,66 @@ def test_openai_streaming_error() -> None: llm.stream("I'm Pickle Rick") +def test_openai_streaming_best_of_error() -> None: + """Test validation for streaming fails if best_of is not 1.""" + with pytest.raises(ValueError): + OpenAI(best_of=2, streaming=True) + + +def test_openai_streaming_n_error() -> None: + """Test validation for streaming fails if n is not 1.""" + with pytest.raises(ValueError): + OpenAI(n=2, streaming=True) + + +def test_openai_streaming_multiple_prompts_error() -> None: + """Test validation for streaming fails if multiple prompts are given.""" + with pytest.raises(ValueError): + OpenAI(streaming=True).generate(["I'm Pickle Rick", "I'm Pickle Rick"]) + + +def test_openai_streaming_call() -> None: + """Test valid call to openai.""" + llm = OpenAI(max_tokens=10, streaming=True) + output = llm("Say foo:") + assert isinstance(output, str) + + +def test_openai_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + llm = OpenAI( + max_tokens=10, + streaming=True, + temperature=0, + callback_manager=callback_manager, + verbose=True, + ) + llm("Write me a sentence with 100 words.") + assert callback_handler.llm_streams == 10 + + @pytest.mark.asyncio async def test_openai_async_generate() -> None: """Test async generation.""" llm = OpenAI(max_tokens=10) output = await llm.agenerate(["Hello, how are you?"]) assert isinstance(output, LLMResult) + + +@pytest.mark.asyncio +async def test_openai_async_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + llm = OpenAI( + max_tokens=10, + streaming=True, + temperature=0, + callback_manager=callback_manager, + verbose=True, + ) + result = await llm.agenerate(["Write me a sentence with 100 words."]) + assert callback_handler.llm_streams == 10 + assert isinstance(result, LLMResult) diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 7dd0fc01..3132faa9 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -3,12 +3,12 @@ from typing import Any, Dict, List, Union from pydantic import BaseModel -from langchain.callbacks.base import BaseCallbackHandler +from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult -class FakeCallbackHandler(BaseModel, BaseCallbackHandler): - """Fake callback handler for testing.""" +class BaseFakeCallbackHandler(BaseModel): + """Base fake callback handler for testing.""" starts: int = 0 ends: int = 0 @@ -44,10 +44,15 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler): chain_ends: int = 0 llm_starts: int = 0 llm_ends: int = 0 + llm_streams: int = 0 tool_starts: int = 0 tool_ends: int = 0 agent_ends: int = 0 + +class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler): + """Fake callback handler for testing.""" + def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: @@ -55,6 +60,10 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler): self.llm_starts += 1 self.starts += 1 + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + self.llm_streams += 1 + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Run when LLM ends running.""" self.llm_ends += 1 @@ -110,3 +119,74 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler): """Run when agent ends running.""" self.agent_ends += 1 self.ends += 1 + + +class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler): + """Fake async callback handler for testing.""" + + async def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + self.llm_starts += 1 + self.starts += 1 + + async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + self.llm_streams += 1 + + async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + self.llm_ends += 1 + self.ends += 1 + + async def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when LLM errors.""" + self.errors += 1 + + async def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + self.chain_starts += 1 + self.starts += 1 + + async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + self.chain_ends += 1 + self.ends += 1 + + async def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when chain errors.""" + self.errors += 1 + + async def on_tool_start( + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + self.tool_starts += 1 + self.starts += 1 + + async def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + self.tool_ends += 1 + self.ends += 1 + + async def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Run when tool errors.""" + self.errors += 1 + + async def on_text(self, text: str, **kwargs: Any) -> None: + """Run when agent is ending.""" + self.text += 1 + + async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run when agent ends running.""" + self.agent_ends += 1 + self.ends += 1 diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index acca2171..081be3dd 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -1,13 +1,24 @@ """Test CallbackManager.""" +from typing import Tuple -from langchain.callbacks.base import BaseCallbackManager, CallbackManager +import pytest + +from langchain.callbacks.base import ( + AsyncCallbackManager, + BaseCallbackManager, + CallbackManager, +) from langchain.callbacks.shared import SharedCallbackManager from langchain.schema import AgentAction, AgentFinish, LLMResult -from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +from tests.unit_tests.callbacks.fake_callback_handler import ( + BaseFakeCallbackHandler, + FakeAsyncCallbackHandler, + FakeCallbackHandler, +) def _test_callback_manager( - manager: BaseCallbackManager, *handlers: FakeCallbackHandler + manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler ) -> None: """Test the CallbackManager.""" manager.on_llm_start({}, []) @@ -20,6 +31,27 @@ def _test_callback_manager( manager.on_tool_end("") manager.on_tool_error(Exception()) manager.on_agent_finish(AgentFinish(log="", return_values={})) + _check_num_calls(handlers) + + +async def _test_callback_manager_async( + manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler +) -> None: + """Test the CallbackManager.""" + await manager.on_llm_start({}, []) + await manager.on_llm_end(LLMResult(generations=[])) + await manager.on_llm_error(Exception()) + await manager.on_chain_start({"name": "foo"}, {}) + await manager.on_chain_end({}) + await manager.on_chain_error(Exception()) + await manager.on_tool_start({}, AgentAction("", "", "")) + await manager.on_tool_end("") + await manager.on_tool_error(Exception()) + await manager.on_agent_finish(AgentFinish(log="", return_values={})) + _check_num_calls(handlers) + + +def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None: for handler in handlers: if handler.always_verbose: assert handler.starts == 3 @@ -128,3 +160,21 @@ def test_shared_callback_manager() -> None: manager1.add_handler(handler1) manager2.add_handler(handler2) _test_callback_manager(manager1, handler1, handler2) + + +@pytest.mark.asyncio +async def test_async_callback_manager() -> None: + """Test the AsyncCallbackManager.""" + handler1 = FakeAsyncCallbackHandler(always_verbose_=True) + handler2 = FakeAsyncCallbackHandler() + manager = AsyncCallbackManager([handler1, handler2]) + await _test_callback_manager_async(manager, handler1, handler2) + + +@pytest.mark.asyncio +async def test_async_callback_manager_sync_handler() -> None: + """Test the AsyncCallbackManager.""" + handler1 = FakeCallbackHandler(always_verbose_=True) + handler2 = FakeAsyncCallbackHandler() + manager = AsyncCallbackManager([handler1, handler2]) + await _test_callback_manager_async(manager, handler1, handler2)