improve AsyncCallbackManager (#2410)

This commit is contained in:
Ankush Gola 2023-04-05 09:31:42 +02:00 committed by GitHub
parent af7f20fa42
commit 4d730a9bbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 76 additions and 143 deletions

View File

@ -2,7 +2,7 @@
import asyncio import asyncio
import functools import functools
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Optional, Union
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -328,6 +328,25 @@ class AsyncCallbackHandler(BaseCallbackHandler):
"""Run on agent end.""" """Run on agent end."""
async def _handle_event_for_handler(
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
verbose: bool,
*args: Any,
**kwargs: Any
) -> None:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
if verbose or handler.always_verbose:
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(event, *args, **kwargs)
)
class AsyncCallbackManager(BaseCallbackManager): class AsyncCallbackManager(BaseCallbackManager):
"""Async callback manager that can be used to handle callbacks from LangChain.""" """Async callback manager that can be used to handle callbacks from LangChain."""
@ -340,6 +359,24 @@ class AsyncCallbackManager(BaseCallbackManager):
"""Initialize callback manager.""" """Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers self.handlers: List[BaseCallbackHandler] = handlers
async def _handle_event(
self,
event_name: str,
ignore_condition_name: Optional[str],
verbose: bool,
*args: Any,
**kwargs: Any
) -> None:
"""Generic event handler for AsyncCallbackManager."""
await asyncio.gather(
*(
_handle_event_for_handler(
handler, event_name, ignore_condition_name, verbose, *args, **kwargs
)
for handler in self.handlers
)
)
async def on_llm_start( async def on_llm_start(
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
@ -348,50 +385,25 @@ class AsyncCallbackManager(BaseCallbackManager):
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
"""Run when LLM starts running.""" """Run when LLM starts running."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_llm: "on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_llm_start):
await handler.on_llm_start(serialized, prompts, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_llm_start, serialized, prompts, **kwargs
),
)
async def on_llm_new_token( async def on_llm_new_token(
self, token: str, verbose: bool = False, **kwargs: Any self, token: str, verbose: bool = False, **kwargs: Any
) -> None: ) -> None:
"""Run on new LLM token. Only available when streaming is enabled.""" """Run on new LLM token. Only available when streaming is enabled."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_llm: "on_llm_new_token", "ignore_llm", verbose, token, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_llm_new_token):
await handler.on_llm_new_token(token, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_llm_new_token, token, **kwargs
),
)
async def on_llm_end( async def on_llm_end(
self, response: LLMResult, verbose: bool = False, **kwargs: Any self, response: LLMResult, verbose: bool = False, **kwargs: Any
) -> None: ) -> None:
"""Run when LLM ends running.""" """Run when LLM ends running."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_llm: "on_llm_end", "ignore_llm", verbose, response, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_llm_end):
await handler.on_llm_end(response, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_llm_end, response, **kwargs),
)
async def on_llm_error( async def on_llm_error(
self, self,
@ -400,16 +412,7 @@ class AsyncCallbackManager(BaseCallbackManager):
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
for handler in self.handlers: await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs)
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_error):
await handler.on_llm_error(error, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_llm_error, error, **kwargs),
)
async def on_chain_start( async def on_chain_start(
self, self,
@ -419,33 +422,17 @@ class AsyncCallbackManager(BaseCallbackManager):
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
"""Run when chain starts running.""" """Run when chain starts running."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_chain: "on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_chain_start):
await handler.on_chain_start(serialized, inputs, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_chain_start, serialized, inputs, **kwargs
),
)
async def on_chain_end( async def on_chain_end(
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
) -> None: ) -> None:
"""Run when chain ends running.""" """Run when chain ends running."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_chain: "on_chain_end", "ignore_chain", verbose, outputs, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_chain_end):
await handler.on_chain_end(outputs, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_chain_end, outputs, **kwargs),
)
async def on_chain_error( async def on_chain_error(
self, self,
@ -454,16 +441,9 @@ class AsyncCallbackManager(BaseCallbackManager):
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
"""Run when chain errors.""" """Run when chain errors."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_chain: "on_chain_error", "ignore_chain", verbose, error, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_chain_error):
await handler.on_chain_error(error, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_chain_error, error, **kwargs),
)
async def on_tool_start( async def on_tool_start(
self, self,
@ -473,33 +453,17 @@ class AsyncCallbackManager(BaseCallbackManager):
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when tool starts running."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_agent: "on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_tool_start):
await handler.on_tool_start(serialized, input_str, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_tool_start, serialized, input_str, **kwargs
),
)
async def on_tool_end( async def on_tool_end(
self, output: str, verbose: bool = False, **kwargs: Any self, output: str, verbose: bool = False, **kwargs: Any
) -> None: ) -> None:
"""Run when tool ends running.""" """Run when tool ends running."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_agent: "on_tool_end", "ignore_agent", verbose, output, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_tool_end):
await handler.on_tool_end(output, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_tool_end, output, **kwargs),
)
async def on_tool_error( async def on_tool_error(
self, self,
@ -508,61 +472,29 @@ class AsyncCallbackManager(BaseCallbackManager):
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
"""Run when tool errors.""" """Run when tool errors."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_agent: "on_tool_error", "ignore_agent", verbose, error, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_tool_error):
await handler.on_tool_error(error, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_tool_error, error, **kwargs),
)
async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None: async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when text is printed.""" """Run when text is printed."""
for handler in self.handlers: await self._handle_event("on_text", None, verbose, text, **kwargs)
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_text):
await handler.on_text(text, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(handler.on_text, text, **kwargs)
)
async def on_agent_action( async def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None: ) -> None:
"""Run on agent action.""" """Run on agent action."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_agent: "on_agent_action", "ignore_agent", verbose, action, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_agent_action):
await handler.on_agent_action(action, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_agent_action, action, **kwargs
),
)
async def on_agent_finish( async def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None: ) -> None:
"""Run when agent finishes.""" """Run when agent finishes."""
for handler in self.handlers: await self._handle_event(
if not handler.ignore_agent: "on_agent_finish", "ignore_agent", verbose, finish, **kwargs
if verbose or handler.always_verbose: )
if asyncio.iscoroutinefunction(handler.on_agent_finish):
await handler.on_agent_finish(finish, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_agent_finish, finish, **kwargs
),
)
def add_handler(self, handler: BaseCallbackHandler) -> None: def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager.""" """Add a handler to the callback manager."""

View File

@ -176,5 +176,6 @@ async def test_async_callback_manager_sync_handler() -> None:
"""Test the AsyncCallbackManager.""" """Test the AsyncCallbackManager."""
handler1 = FakeCallbackHandler(always_verbose_=True) handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeAsyncCallbackHandler() handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2]) handler3 = FakeAsyncCallbackHandler(always_verbose_=True)
await _test_callback_manager_async(manager, handler1, handler2) manager = AsyncCallbackManager([handler1, handler2, handler3])
await _test_callback_manager_async(manager, handler1, handler2, handler3)