From 4d730a9bbcd384da27e2e7c5a6a6efa5eac55838 Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Wed, 5 Apr 2023 09:31:42 +0200 Subject: [PATCH] improve `AsyncCallbackManager` (#2410) --- langchain/callbacks/base.py | 214 ++++++------------ .../callbacks/test_callback_manager.py | 5 +- 2 files changed, 76 insertions(+), 143 deletions(-) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 6a425f18..5b47b82a 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -2,7 +2,7 @@ import asyncio import functools from abc import ABC, abstractmethod -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -328,6 +328,25 @@ class AsyncCallbackHandler(BaseCallbackHandler): """Run on agent end.""" +async def _handle_event_for_handler( + handler: BaseCallbackHandler, + event_name: str, + ignore_condition_name: Optional[str], + verbose: bool, + *args: Any, + **kwargs: Any +) -> None: + if ignore_condition_name is None or not getattr(handler, ignore_condition_name): + if verbose or handler.always_verbose: + event = getattr(handler, event_name) + if asyncio.iscoroutinefunction(event): + await event(*args, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, functools.partial(event, *args, **kwargs) + ) + + class AsyncCallbackManager(BaseCallbackManager): """Async callback manager that can be used to handle callbacks from LangChain.""" @@ -340,6 +359,24 @@ class AsyncCallbackManager(BaseCallbackManager): """Initialize callback manager.""" self.handlers: List[BaseCallbackHandler] = handlers + async def _handle_event( + self, + event_name: str, + ignore_condition_name: Optional[str], + verbose: bool, + *args: Any, + **kwargs: Any + ) -> None: + """Generic event handler for AsyncCallbackManager.""" + await asyncio.gather( + *( + _handle_event_for_handler( + handler, event_name, ignore_condition_name, verbose, *args, **kwargs + ) + for handler in self.handlers + ) + ) + async def on_llm_start( self, serialized: Dict[str, Any], @@ -348,50 +385,25 @@ class AsyncCallbackManager(BaseCallbackManager): **kwargs: Any ) -> None: """Run when LLM starts running.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_llm_start): - await handler.on_llm_start(serialized, prompts, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial( - handler.on_llm_start, serialized, prompts, **kwargs - ), - ) + await self._handle_event( + "on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs + ) async def on_llm_new_token( self, token: str, verbose: bool = False, **kwargs: Any ) -> None: """Run on new LLM token. Only available when streaming is enabled.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_llm_new_token): - await handler.on_llm_new_token(token, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial( - handler.on_llm_new_token, token, **kwargs - ), - ) + await self._handle_event( + "on_llm_new_token", "ignore_llm", verbose, token, **kwargs + ) async def on_llm_end( self, response: LLMResult, verbose: bool = False, **kwargs: Any ) -> None: """Run when LLM ends running.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_llm_end): - await handler.on_llm_end(response, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial(handler.on_llm_end, response, **kwargs), - ) + await self._handle_event( + "on_llm_end", "ignore_llm", verbose, response, **kwargs + ) async def on_llm_error( self, @@ -400,16 +412,7 @@ class AsyncCallbackManager(BaseCallbackManager): **kwargs: Any ) -> None: """Run when LLM errors.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_llm_error): - await handler.on_llm_error(error, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial(handler.on_llm_error, error, **kwargs), - ) + await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs) async def on_chain_start( self, @@ -419,33 +422,17 @@ class AsyncCallbackManager(BaseCallbackManager): **kwargs: Any ) -> None: """Run when chain starts running.""" - for handler in self.handlers: - if not handler.ignore_chain: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_chain_start): - await handler.on_chain_start(serialized, inputs, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial( - handler.on_chain_start, serialized, inputs, **kwargs - ), - ) + await self._handle_event( + "on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs + ) async def on_chain_end( self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any ) -> None: """Run when chain ends running.""" - for handler in self.handlers: - if not handler.ignore_chain: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_chain_end): - await handler.on_chain_end(outputs, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial(handler.on_chain_end, outputs, **kwargs), - ) + await self._handle_event( + "on_chain_end", "ignore_chain", verbose, outputs, **kwargs + ) async def on_chain_error( self, @@ -454,16 +441,9 @@ class AsyncCallbackManager(BaseCallbackManager): **kwargs: Any ) -> None: """Run when chain errors.""" - for handler in self.handlers: - if not handler.ignore_chain: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_chain_error): - await handler.on_chain_error(error, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial(handler.on_chain_error, error, **kwargs), - ) + await self._handle_event( + "on_chain_error", "ignore_chain", verbose, error, **kwargs + ) async def on_tool_start( self, @@ -473,33 +453,17 @@ class AsyncCallbackManager(BaseCallbackManager): **kwargs: Any ) -> None: """Run when tool starts running.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_tool_start): - await handler.on_tool_start(serialized, input_str, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial( - handler.on_tool_start, serialized, input_str, **kwargs - ), - ) + await self._handle_event( + "on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs + ) async def on_tool_end( self, output: str, verbose: bool = False, **kwargs: Any ) -> None: """Run when tool ends running.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_tool_end): - await handler.on_tool_end(output, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial(handler.on_tool_end, output, **kwargs), - ) + await self._handle_event( + "on_tool_end", "ignore_agent", verbose, output, **kwargs + ) async def on_tool_error( self, @@ -508,61 +472,29 @@ class AsyncCallbackManager(BaseCallbackManager): **kwargs: Any ) -> None: """Run when tool errors.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_tool_error): - await handler.on_tool_error(error, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial(handler.on_tool_error, error, **kwargs), - ) + await self._handle_event( + "on_tool_error", "ignore_agent", verbose, error, **kwargs + ) async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None: """Run when text is printed.""" - for handler in self.handlers: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_text): - await handler.on_text(text, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, functools.partial(handler.on_text, text, **kwargs) - ) + await self._handle_event("on_text", None, verbose, text, **kwargs) async def on_agent_action( self, action: AgentAction, verbose: bool = False, **kwargs: Any ) -> None: """Run on agent action.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_agent_action): - await handler.on_agent_action(action, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial( - handler.on_agent_action, action, **kwargs - ), - ) + await self._handle_event( + "on_agent_action", "ignore_agent", verbose, action, **kwargs + ) async def on_agent_finish( self, finish: AgentFinish, verbose: bool = False, **kwargs: Any ) -> None: """Run when agent finishes.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - if asyncio.iscoroutinefunction(handler.on_agent_finish): - await handler.on_agent_finish(finish, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, - functools.partial( - handler.on_agent_finish, finish, **kwargs - ), - ) + await self._handle_event( + "on_agent_finish", "ignore_agent", verbose, finish, **kwargs + ) def add_handler(self, handler: BaseCallbackHandler) -> None: """Add a handler to the callback manager.""" diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 7819798c..0f61fdd3 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -176,5 +176,6 @@ async def test_async_callback_manager_sync_handler() -> None: """Test the AsyncCallbackManager.""" handler1 = FakeCallbackHandler(always_verbose_=True) handler2 = FakeAsyncCallbackHandler() - manager = AsyncCallbackManager([handler1, handler2]) - await _test_callback_manager_async(manager, handler1, handler2) + handler3 = FakeAsyncCallbackHandler(always_verbose_=True) + manager = AsyncCallbackManager([handler1, handler2, handler3]) + await _test_callback_manager_async(manager, handler1, handler2, handler3)