langchain/tests/unit_tests/callbacks/test_callback_manager.py
2023-04-05 09:31:42 +02:00

182 lines
6.5 KiB
Python

"""Test CallbackManager."""
from typing import Tuple
import pytest
from langchain.callbacks.base import (
AsyncCallbackManager,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import (
BaseFakeCallbackHandler,
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)
def _test_callback_manager(
manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [])
manager.on_llm_end(LLMResult(generations=[]))
manager.on_llm_error(Exception())
manager.on_chain_start({"name": "foo"}, {})
manager.on_chain_end({})
manager.on_chain_error(Exception())
manager.on_tool_start({}, "")
manager.on_tool_end("")
manager.on_tool_error(Exception())
manager.on_agent_finish(AgentFinish(log="", return_values={}))
_check_num_calls(handlers)
async def _test_callback_manager_async(
manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
await manager.on_llm_start({}, [])
await manager.on_llm_end(LLMResult(generations=[]))
await manager.on_llm_error(Exception())
await manager.on_chain_start({"name": "foo"}, {})
await manager.on_chain_end({})
await manager.on_chain_error(Exception())
await manager.on_tool_start({}, "")
await manager.on_tool_end("")
await manager.on_tool_error(Exception())
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
_check_num_calls(handlers)
def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
for handler in handlers:
if handler.always_verbose:
assert handler.starts == 3
assert handler.ends == 4
assert handler.errors == 3
else:
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0
def _test_callback_manager_pass_in_verbose(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [], verbose=True)
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
manager.on_llm_error(Exception(), verbose=True)
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
manager.on_chain_end({}, verbose=True)
manager.on_chain_error(Exception(), verbose=True)
manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
for handler in handlers:
assert handler.starts == 3
assert handler.ends == 4
assert handler.errors == 3
def test_callback_manager() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=False)
manager = CallbackManager([handler1, handler2])
_test_callback_manager(manager, handler1, handler2)
def test_callback_manager_pass_in_verbose() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager = CallbackManager([handler1, handler2])
_test_callback_manager_pass_in_verbose(manager, handler1, handler2)
def test_ignore_llm() -> None:
"""Test ignore llm param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_llm_start({}, [], verbose=True)
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
manager.on_llm_error(Exception(), verbose=True)
assert handler1.starts == 0
assert handler1.ends == 0
assert handler1.errors == 0
assert handler2.starts == 1
assert handler2.ends == 1
assert handler2.errors == 1
def test_ignore_chain() -> None:
"""Test ignore chain param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
manager.on_chain_end({}, verbose=True)
manager.on_chain_error(Exception(), verbose=True)
assert handler1.starts == 0
assert handler1.ends == 0
assert handler1.errors == 0
assert handler2.starts == 1
assert handler2.ends == 1
assert handler2.errors == 1
def test_ignore_agent() -> None:
"""Test ignore agent param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
assert handler1.starts == 0
assert handler1.ends == 0
assert handler1.errors == 0
assert handler2.starts == 1
assert handler2.ends == 2
assert handler2.errors == 1
def test_shared_callback_manager() -> None:
"""Test the SharedCallbackManager."""
manager1 = SharedCallbackManager()
manager2 = SharedCallbackManager()
assert manager1 is manager2
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeCallbackHandler()
manager1.add_handler(handler1)
manager2.add_handler(handler2)
_test_callback_manager(manager1, handler1, handler2)
@pytest.mark.asyncio
async def test_async_callback_manager() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeAsyncCallbackHandler(always_verbose_=True)
handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2])
await _test_callback_manager_async(manager, handler1, handler2)
@pytest.mark.asyncio
async def test_async_callback_manager_sync_handler() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeAsyncCallbackHandler()
handler3 = FakeAsyncCallbackHandler(always_verbose_=True)
manager = AsyncCallbackManager([handler1, handler2, handler3])
await _test_callback_manager_async(manager, handler1, handler2, handler3)