forked from Archives/langchain
improve AsyncCallbackManager
(#2410)
This commit is contained in:
parent
af7f20fa42
commit
4d730a9bbc
@ -2,7 +2,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
@ -328,6 +328,25 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Run on agent end."""
|
"""Run on agent end."""
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_event_for_handler(
|
||||||
|
handler: BaseCallbackHandler,
|
||||||
|
event_name: str,
|
||||||
|
ignore_condition_name: Optional[str],
|
||||||
|
verbose: bool,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
|
||||||
|
if verbose or handler.always_verbose:
|
||||||
|
event = getattr(handler, event_name)
|
||||||
|
if asyncio.iscoroutinefunction(event):
|
||||||
|
await event(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, functools.partial(event, *args, **kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncCallbackManager(BaseCallbackManager):
|
class AsyncCallbackManager(BaseCallbackManager):
|
||||||
"""Async callback manager that can be used to handle callbacks from LangChain."""
|
"""Async callback manager that can be used to handle callbacks from LangChain."""
|
||||||
|
|
||||||
@ -340,6 +359,24 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
"""Initialize callback manager."""
|
"""Initialize callback manager."""
|
||||||
self.handlers: List[BaseCallbackHandler] = handlers
|
self.handlers: List[BaseCallbackHandler] = handlers
|
||||||
|
|
||||||
|
async def _handle_event(
|
||||||
|
self,
|
||||||
|
event_name: str,
|
||||||
|
ignore_condition_name: Optional[str],
|
||||||
|
verbose: bool,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Generic event handler for AsyncCallbackManager."""
|
||||||
|
await asyncio.gather(
|
||||||
|
*(
|
||||||
|
_handle_event_for_handler(
|
||||||
|
handler, event_name, ignore_condition_name, verbose, *args, **kwargs
|
||||||
|
)
|
||||||
|
for handler in self.handlers
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def on_llm_start(
|
async def on_llm_start(
|
||||||
self,
|
self,
|
||||||
serialized: Dict[str, Any],
|
serialized: Dict[str, Any],
|
||||||
@ -348,50 +385,25 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM starts running."""
|
"""Run when LLM starts running."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_llm:
|
"on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_llm_start):
|
|
||||||
await handler.on_llm_start(serialized, prompts, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(
|
|
||||||
handler.on_llm_start, serialized, prompts, **kwargs
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_llm_new_token(
|
async def on_llm_new_token(
|
||||||
self, token: str, verbose: bool = False, **kwargs: Any
|
self, token: str, verbose: bool = False, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_llm:
|
"on_llm_new_token", "ignore_llm", verbose, token, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_llm_new_token):
|
|
||||||
await handler.on_llm_new_token(token, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(
|
|
||||||
handler.on_llm_new_token, token, **kwargs
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_llm_end(
|
async def on_llm_end(
|
||||||
self, response: LLMResult, verbose: bool = False, **kwargs: Any
|
self, response: LLMResult, verbose: bool = False, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM ends running."""
|
"""Run when LLM ends running."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_llm:
|
"on_llm_end", "ignore_llm", verbose, response, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_llm_end):
|
|
||||||
await handler.on_llm_end(response, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(handler.on_llm_end, response, **kwargs),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_llm_error(
|
async def on_llm_error(
|
||||||
self,
|
self,
|
||||||
@ -400,16 +412,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
for handler in self.handlers:
|
await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs)
|
||||||
if not handler.ignore_llm:
|
|
||||||
if verbose or handler.always_verbose:
|
|
||||||
if asyncio.iscoroutinefunction(handler.on_llm_error):
|
|
||||||
await handler.on_llm_error(error, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(handler.on_llm_error, error, **kwargs),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_chain_start(
|
async def on_chain_start(
|
||||||
self,
|
self,
|
||||||
@ -419,33 +422,17 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain starts running."""
|
"""Run when chain starts running."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_chain:
|
"on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_chain_start):
|
|
||||||
await handler.on_chain_start(serialized, inputs, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(
|
|
||||||
handler.on_chain_start, serialized, inputs, **kwargs
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_chain_end(
|
async def on_chain_end(
|
||||||
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
|
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain ends running."""
|
"""Run when chain ends running."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_chain:
|
"on_chain_end", "ignore_chain", verbose, outputs, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_chain_end):
|
|
||||||
await handler.on_chain_end(outputs, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(handler.on_chain_end, outputs, **kwargs),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_chain_error(
|
async def on_chain_error(
|
||||||
self,
|
self,
|
||||||
@ -454,16 +441,9 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_chain:
|
"on_chain_error", "ignore_chain", verbose, error, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_chain_error):
|
|
||||||
await handler.on_chain_error(error, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(handler.on_chain_error, error, **kwargs),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_tool_start(
|
async def on_tool_start(
|
||||||
self,
|
self,
|
||||||
@ -473,33 +453,17 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_agent:
|
"on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_tool_start):
|
|
||||||
await handler.on_tool_start(serialized, input_str, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(
|
|
||||||
handler.on_tool_start, serialized, input_str, **kwargs
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_tool_end(
|
async def on_tool_end(
|
||||||
self, output: str, verbose: bool = False, **kwargs: Any
|
self, output: str, verbose: bool = False, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_agent:
|
"on_tool_end", "ignore_agent", verbose, output, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_tool_end):
|
|
||||||
await handler.on_tool_end(output, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(handler.on_tool_end, output, **kwargs),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_tool_error(
|
async def on_tool_error(
|
||||||
self,
|
self,
|
||||||
@ -508,61 +472,29 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_agent:
|
"on_tool_error", "ignore_agent", verbose, error, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_tool_error):
|
|
||||||
await handler.on_tool_error(error, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(handler.on_tool_error, error, **kwargs),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
|
async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
|
||||||
"""Run when text is printed."""
|
"""Run when text is printed."""
|
||||||
for handler in self.handlers:
|
await self._handle_event("on_text", None, verbose, text, **kwargs)
|
||||||
if verbose or handler.always_verbose:
|
|
||||||
if asyncio.iscoroutinefunction(handler.on_text):
|
|
||||||
await handler.on_text(text, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None, functools.partial(handler.on_text, text, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_agent_action(
|
async def on_agent_action(
|
||||||
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
self, action: AgentAction, verbose: bool = False, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run on agent action."""
|
"""Run on agent action."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_agent:
|
"on_agent_action", "ignore_agent", verbose, action, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_agent_action):
|
|
||||||
await handler.on_agent_action(action, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(
|
|
||||||
handler.on_agent_action, action, **kwargs
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_agent_finish(
|
async def on_agent_finish(
|
||||||
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when agent finishes."""
|
"""Run when agent finishes."""
|
||||||
for handler in self.handlers:
|
await self._handle_event(
|
||||||
if not handler.ignore_agent:
|
"on_agent_finish", "ignore_agent", verbose, finish, **kwargs
|
||||||
if verbose or handler.always_verbose:
|
)
|
||||||
if asyncio.iscoroutinefunction(handler.on_agent_finish):
|
|
||||||
await handler.on_agent_finish(finish, **kwargs)
|
|
||||||
else:
|
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
functools.partial(
|
|
||||||
handler.on_agent_finish, finish, **kwargs
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
"""Add a handler to the callback manager."""
|
"""Add a handler to the callback manager."""
|
||||||
|
@ -176,5 +176,6 @@ async def test_async_callback_manager_sync_handler() -> None:
|
|||||||
"""Test the AsyncCallbackManager."""
|
"""Test the AsyncCallbackManager."""
|
||||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||||
handler2 = FakeAsyncCallbackHandler()
|
handler2 = FakeAsyncCallbackHandler()
|
||||||
manager = AsyncCallbackManager([handler1, handler2])
|
handler3 = FakeAsyncCallbackHandler(always_verbose_=True)
|
||||||
await _test_callback_manager_async(manager, handler1, handler2)
|
manager = AsyncCallbackManager([handler1, handler2, handler3])
|
||||||
|
await _test_callback_manager_async(manager, handler1, handler2, handler3)
|
||||||
|
Loading…
Reference in New Issue
Block a user