mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
b0859c9b18
Handle the new retriever events in a way that (I think) is entirely backwards compatible? Needs more testing for some of the chain changes and all. This creates an entire new run type, however. We could also just treat this as an event within a chain run presumably (same with memory) Adds a subclass initializer that upgrades old retriever implementations to the new schema, along with tests to ensure they work. First commit doesn't upgrade any of our retriever implementations (to show that we can pass the tests along with additional ones testing the upgrade logic). Second commit upgrades the known universe of retrievers in langchain. - [X] Add callback handling methods for retriever start/end/error (open to renaming to 'retrieval' if you want that) - [X] Update BaseRetriever schema to support callbacks - [X] Tests for upgrading old "v1" retrievers for backwards compatibility - [X] Update existing retriever implementations to implement the new interface - [X] Update calls within chains to .{a}get_relevant_documents to pass the child callback manager - [X] Update the notebooks/docs to reflect the new interface - [X] Test notebooks thoroughly Not handled: - Memory pass-throughs: retrieval memory doesn't have a parent callback manager passed through the method --------- Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
266 lines
9.9 KiB
Python
266 lines
9.9 KiB
Python
"""Test CallbackManager."""
|
|
from typing import List, Tuple
|
|
|
|
import pytest
|
|
|
|
from langchain.callbacks.base import BaseCallbackHandler
|
|
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
|
from langchain.callbacks.stdout import StdOutCallbackHandler
|
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
|
from tests.unit_tests.callbacks.fake_callback_handler import (
|
|
BaseFakeCallbackHandler,
|
|
FakeAsyncCallbackHandler,
|
|
FakeCallbackHandler,
|
|
)
|
|
|
|
|
|
def _test_callback_manager(
    manager: CallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
    """Drive an LLM, a chain and a tool run through ``manager`` and check counts.

    Each run type is started on ``manager``; the run manager(s) it returns are
    then sent end, error and text events (plus a new-token event for the LLM
    run and agent action/finish events for the chain run). Finally the
    counters of every handler in ``handlers`` are verified with
    ``_check_num_calls``.
    """
    # LLM run: ``on_llm_start`` hands back an iterable of run managers.
    for llm_run in manager.on_llm_start({}, ["prompt"]):
        llm_run.on_llm_end(LLMResult(generations=[]))
        llm_run.on_llm_error(Exception())
        llm_run.on_llm_new_token("foo")
        llm_run.on_text("foo")

    # Chain run, including the agent events emitted from within it.
    chain_run = manager.on_chain_start({"name": "foo"}, {})
    chain_run.on_chain_end({})
    chain_run.on_chain_error(Exception())
    chain_run.on_agent_action(
        AgentAction(tool_input="foo", log="", tool="")
    )
    chain_run.on_agent_finish(
        AgentFinish(log="", return_values={})
    )
    chain_run.on_text("foo")

    # Tool run.
    tool_run = manager.on_tool_start({}, "")
    tool_run.on_tool_end("")
    tool_run.on_tool_error(Exception())
    tool_run.on_text("foo")

    _check_num_calls(handlers)
|
|
|
|
|
|
async def _test_callback_manager_async(
    manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
    """Drive an LLM, a chain and a tool run through the async ``manager``.

    Awaited counterpart of ``_test_callback_manager``: the same sequence of
    start/end/error/text events is issued, every call being awaited, and the
    counters of every handler in ``handlers`` are then verified with
    ``_check_num_calls``.
    """
    # LLM run: the awaited ``on_llm_start`` yields an iterable of run managers.
    llm_runs = await manager.on_llm_start({}, ["prompt"])
    for llm_run in llm_runs:
        await llm_run.on_llm_end(LLMResult(generations=[]))
        await llm_run.on_llm_error(Exception())
        await llm_run.on_llm_new_token("foo")
        await llm_run.on_text("foo")

    # Chain run, including the agent events emitted from within it.
    chain_run = await manager.on_chain_start({"name": "foo"}, {})
    await chain_run.on_chain_end({})
    await chain_run.on_chain_error(Exception())
    await chain_run.on_agent_action(AgentAction(tool_input="foo", log="", tool=""))
    await chain_run.on_agent_finish(
        AgentFinish(log="", return_values={})
    )
    await chain_run.on_text("foo")

    # Tool run.
    tool_run = await manager.on_tool_start({}, "")
    await tool_run.on_tool_end("")
    await tool_run.on_tool_error(Exception())
    await tool_run.on_text("foo")

    _check_num_calls(handlers)
|
|
|
|
|
|
def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
    """Assert that each handler recorded the expected number of events.

    The expected totals are those checked after the event sequence issued by
    ``_test_callback_manager`` / ``_test_callback_manager_async``.
    """
    # Counter attribute -> expected value, checked in this order per handler.
    expected_counts = {
        # Aggregate counters.
        "starts": 4,
        "ends": 4,
        "errors": 3,
        "text": 3,
        # LLM counters.
        "llm_starts": 1,
        "llm_ends": 1,
        "llm_streams": 1,
        # Chain counters.
        "chain_starts": 1,
        "chain_ends": 1,
        # Tool counters.
        "tool_starts": 1,
        "tool_ends": 1,
    }
    for handler in handlers:
        for counter, expected in expected_counts.items():
            assert getattr(handler, counter) == expected
|
|
|
|
|
|
def test_callback_manager() -> None:
    """Run the shared event sequence through a sync ``CallbackManager``."""
    first, second = FakeCallbackHandler(), FakeCallbackHandler()
    # Both handlers are registered, so both must see every event.
    _test_callback_manager(
        CallbackManager(handlers=[first, second]),
        first,
        second,
    )
|
|
|
|
|
|
def test_ignore_llm() -> None:
    """Verify that ``ignore_llm_=True`` silences LLM events for a handler."""
    muted = FakeCallbackHandler(ignore_llm_=True)
    listening = FakeCallbackHandler()
    manager = CallbackManager(handlers=[muted, listening])

    for llm_run in manager.on_llm_start({}, ["prompt"]):
        llm_run.on_llm_end(LLMResult(generations=[]))
        llm_run.on_llm_error(Exception())

    # The ignoring handler records nothing; the other one sees each event once.
    assert (muted.starts, muted.ends, muted.errors) == (0, 0, 0)
    assert (listening.starts, listening.ends, listening.errors) == (1, 1, 1)
|
|
|
|
|
|
def test_ignore_chain() -> None:
    """Verify that ``ignore_chain_=True`` silences chain events for a handler."""
    muted = FakeCallbackHandler(ignore_chain_=True)
    listening = FakeCallbackHandler()
    manager = CallbackManager(handlers=[muted, listening])

    chain_run = manager.on_chain_start({"name": "foo"}, {})
    chain_run.on_chain_end({})
    chain_run.on_chain_error(Exception())

    # The ignoring handler records nothing; the other one sees each event once.
    assert (muted.starts, muted.ends, muted.errors) == (0, 0, 0)
    assert (listening.starts, listening.ends, listening.errors) == (1, 1, 1)
|
|
|
|
|
|
def test_ignore_agent() -> None:
    """Verify that ``ignore_agent_=True`` silences tool events for a handler."""
    muted = FakeCallbackHandler(ignore_agent_=True)
    listening = FakeCallbackHandler()
    manager = CallbackManager(handlers=[muted, listening])

    # The agent flag is exercised here through the tool start/end/error events.
    tool_run = manager.on_tool_start({}, "")
    tool_run.on_tool_end("")
    tool_run.on_tool_error(Exception())

    # The ignoring handler records nothing; the other one sees each event once.
    assert (muted.starts, muted.ends, muted.errors) == (0, 0, 0)
    assert (listening.starts, listening.ends, listening.errors) == (1, 1, 1)
|
|
|
|
|
|
def test_ignore_retriever() -> None:
    """Verify that ``ignore_retriever_=True`` silences retriever events."""
    muted = FakeCallbackHandler(ignore_retriever_=True)
    listening = FakeCallbackHandler()
    manager = CallbackManager(handlers=[muted, listening])

    retriever_run = manager.on_retriever_start("")
    retriever_run.on_retriever_end([])
    retriever_run.on_retriever_error(Exception())

    # The ignoring handler records nothing; the other one sees each event once.
    assert (muted.starts, muted.ends, muted.errors) == (0, 0, 0)
    assert (listening.starts, listening.ends, listening.errors) == (1, 1, 1)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_callback_manager() -> None:
    """Run the shared event sequence through an ``AsyncCallbackManager``."""
    first, second = FakeAsyncCallbackHandler(), FakeAsyncCallbackHandler()
    # Both async handlers are registered, so both must see every event.
    await _test_callback_manager_async(
        AsyncCallbackManager(handlers=[first, second]),
        first,
        second,
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_callback_manager_sync_handler() -> None:
    """Run an ``AsyncCallbackManager`` with one sync and two async handlers."""
    sync_handler = FakeCallbackHandler()
    async_handler_a = FakeAsyncCallbackHandler()
    async_handler_b = FakeAsyncCallbackHandler()

    # The sync handler comes first, followed by the two async handlers.
    manager = AsyncCallbackManager(
        handlers=[sync_handler, async_handler_a, async_handler_b]
    )
    await _test_callback_manager_async(
        manager, sync_handler, async_handler_a, async_handler_b
    )
|
|
|
|
|
|
def test_callback_manager_inheritance() -> None:
    """Check local vs. inheritable handler tracking and what children receive."""
    first, second, third, fourth = (FakeCallbackHandler() for _ in range(4))

    # Handlers passed to the constructor are expected to be local only.
    ctor_manager = CallbackManager(handlers=[first, second])
    assert ctor_manager.handlers == [first, second]
    assert ctor_manager.inheritable_handlers == []

    # An empty manager starts with nothing in either list.
    manager = CallbackManager(handlers=[])
    assert manager.handlers == []
    assert manager.inheritable_handlers == []

    # set_handlers without ``inherit``: handlers land in both lists.
    manager.set_handlers([first, second])
    assert manager.handlers == [first, second]
    assert manager.inheritable_handlers == [first, second]

    # set_handlers with inherit=False: both lists are replaced, none inherited.
    manager.set_handlers([third, fourth], inherit=False)
    assert manager.handlers == [third, fourth]
    assert manager.inheritable_handlers == []

    # add_handler without ``inherit``: appended to both lists.
    manager.add_handler(first)
    assert manager.handlers == [third, fourth, first]
    assert manager.inheritable_handlers == [first]

    # add_handler with inherit=False: appended to the local list only.
    manager.add_handler(second, inherit=False)
    assert manager.handlers == [third, fourth, first, second]
    assert manager.inheritable_handlers == [first]

    # A child of a chain run is expected to carry only the inheritable handler.
    chain_run = manager.on_chain_start({"name": "foo"}, {})
    chain_child = chain_run.get_child()
    assert chain_child.handlers == [first]
    assert chain_child.inheritable_handlers == [first]

    # The same holds for a tool run started from that child ...
    tool_run = chain_child.on_tool_start({}, "")
    assert tool_run.handlers == [first]
    assert tool_run.inheritable_handlers == [first]

    # ... and for the child obtained from the tool run.
    tool_child = tool_run.get_child()
    assert tool_child.handlers == [first]
    assert tool_child.inheritable_handlers == [first]
|
|
|
|
|
|
def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check ``configure`` on both the sync and the async callback manager."""
    # Switch both tracing env vars off -- presumably so ``configure`` adds no
    # tracing handlers and the handler counts below stay exact.
    for env_var in ("LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"):
        monkeypatch.setenv(env_var, "false")

    first, second, third, fourth = (FakeCallbackHandler() for _ in range(4))
    inheritable: List[BaseCallbackHandler] = [first, second]
    local: List[BaseCallbackHandler] = [third, fourth]

    # Sync manager, local callbacks given as a plain list, verbose enabled.
    sync_manager = CallbackManager.configure(
        inheritable_callbacks=inheritable,
        local_callbacks=local,
        verbose=True,
    )

    # Four explicit handlers plus one StdOutCallbackHandler in last position.
    assert len(sync_manager.handlers) == 5
    assert len(sync_manager.inheritable_handlers) == 2
    assert sync_manager.inheritable_handlers == inheritable
    assert sync_manager.handlers[:4] == inheritable + local
    assert isinstance(sync_manager.handlers[4], StdOutCallbackHandler)
    assert isinstance(sync_manager, CallbackManager)

    # Async manager, local callbacks given as a manager instance, verbose off.
    local_as_manager = AsyncCallbackManager(handlers=[third, fourth])
    async_manager = AsyncCallbackManager.configure(
        inheritable_callbacks=inheritable,
        local_callbacks=local_as_manager,
        verbose=False,
    )

    # Only the four explicit handlers are expected this time.
    assert len(async_manager.handlers) == 4
    assert len(async_manager.inheritable_handlers) == 2
    assert async_manager.inheritable_handlers == inheritable
    assert async_manager.handlers == inheritable + [third, fourth]
    assert isinstance(async_manager, AsyncCallbackManager)
|