From 35c9e6ab407003e0c1f16fcf6d4c73f6637db731 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Mon, 8 May 2023 08:44:26 -0700 Subject: [PATCH] Pass Callbacks through load_tools (#4298) - Update the load_tools method to properly accept `callbacks` arguments. - Add a deprecation warning when `callback_manager` is passed - Add two unit tests to check the deprecation warning is raised and to confirm the callback is passed through. Closes issue #4096 --- langchain/agents/load_tools.py | 35 ++++++++++++++++++++------- tests/unit_tests/agents/test_tools.py | 30 ++++++++++++++++++++++- 2 files changed, 55 insertions(+), 10 deletions(-) diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index bdae3fd132..e740841134 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -7,6 +7,7 @@ from mypy_extensions import Arg, KwArg from langchain.agents.tools import Tool from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import Callbacks from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs from langchain.chains.api.base import APIChain from langchain.chains.llm_math.base import LLMMathChain @@ -279,10 +280,26 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st } +def _handle_callbacks( + callback_manager: Optional[BaseCallbackManager], callbacks: Callbacks +) -> Callbacks: + if callback_manager is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + if callbacks is not None: + raise ValueError( + "Cannot specify both callback_manager and callbacks arguments." + ) + return callback_manager + return callbacks + + def load_tools( tool_names: List[str], llm: Optional[BaseLanguageModel] = None, - callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, **kwargs: Any, ) -> List[BaseTool]: """Load tools based on their name. @@ -290,13 +307,16 @@ def load_tools( Args: tool_names: name of tools to load. llm: Optional language model, may be needed to initialize certain tools. - callback_manager: Optional callback manager. If not provided, default global callback manager will be used. + callbacks: Optional callback manager or list of callback handlers. + If not provided, default global callback manager will be used. Returns: List of tools. """ tools = [] - + callbacks = _handle_callbacks( + callback_manager=kwargs.get("callback_manager"), callbacks=callbacks + ) for name in tool_names: if name == "requests": warnings.warn( @@ -316,8 +336,6 @@ def load_tools( if llm is None: raise ValueError(f"Tool {name} requires an LLM to be provided") tool = _LLM_TOOLS[name](llm) - if callback_manager is not None: - tool.callback_manager = callback_manager tools.append(tool) elif name in _EXTRA_LLM_TOOLS: if llm is None: @@ -331,18 +349,17 @@ def load_tools( ) sub_kwargs = {k: kwargs[k] for k in extra_keys} tool = _get_llm_tool_func(llm=llm, **sub_kwargs) - if callback_manager is not None: - tool.callback_manager = callback_manager tools.append(tool) elif name in _EXTRA_OPTIONAL_TOOLS: _get_tool_func, extra_keys = _EXTRA_OPTIONAL_TOOLS[name] sub_kwargs = {k: kwargs[k] for k in extra_keys if k in kwargs} tool = _get_tool_func(**sub_kwargs) - if callback_manager is not None: - tool.callback_manager = callback_manager tools.append(tool) else: raise ValueError(f"Got unknown tool {name}") + if callbacks is not None: + for tool in tools: + tool.callbacks = callbacks return tools diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index a8557cc24e..dd1be32f93 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -1,9 +1,11 @@ """Test tool utils.""" +import unittest from typing import Any, Type -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest +from langchain.agents import load_tools from langchain.agents.agent import Agent from langchain.agents.chat.base import ChatAgent from langchain.agents.conversational.base import ConversationalAgent @@ -12,6 +14,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool, tool +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @pytest.mark.parametrize( @@ -62,3 +65,28 @@ def test_tool_no_args_specified_assumes_str() -> None: assert some_tool.run({"tool_input": "foobar"}) == "foobar" with pytest.raises(ValueError, match="Too many arguments to single-input tool"): some_tool.run({"tool_input": "foobar", "other_input": "bar"}) + + +def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None: + """Test load_tools raises a deprecation for old callback manager kwarg.""" + callback_manager = MagicMock() + with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"): + tools = load_tools(["requests_get"], callback_manager=callback_manager) + assert len(tools) == 1 + assert tools[0].callbacks == callback_manager + + +def test_load_tools_with_callbacks_is_called() -> None: + """Test callbacks are called when provided to load_tools fn.""" + callbacks = [FakeCallbackHandler()] + tools = load_tools(["requests_get"], callbacks=callbacks) # type: ignore + assert len(tools) == 1 + # Patch the requests.get() method to return a mock response + with unittest.mock.patch( + "langchain.requests.TextRequestsWrapper.get", + return_value=Mock(text="Hello world!"), + ): + result = tools[0].run("https://www.google.com") + assert result.text == "Hello world!" + assert callbacks[0].tool_starts == 1 + assert callbacks[0].tool_ends == 1