"""Test tool utils.""" import unittest from typing import Any, Type from unittest.mock import MagicMock, Mock import pytest from langchain.agents.agent import Agent from langchain.agents.chat.base import ChatAgent from langchain.agents.conversational.base import ConversationalAgent from langchain.agents.conversational_chat.base import ConversationalChatAgent from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain_core.tools import Tool, ToolException, tool from langchain_community.agent_toolkits.load_tools import load_tools from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @pytest.mark.parametrize( "agent_cls", [ ZeroShotAgent, ChatAgent, ConversationalChatAgent, ConversationalAgent, ReActDocstoreAgent, ReActTextWorldAgent, SelfAskWithSearchAgent, ], ) def test_single_input_agent_raises_error_on_structured_tool( agent_cls: Type[Agent], ) -> None: """Test that older agents raise errors on older tools.""" @tool def the_tool(foo: str, bar: str) -> str: """Return the concat of foo and bar.""" return foo + bar with pytest.raises( ValueError, match=f"{agent_cls.__name__} does not support" # type: ignore f" multi-input tool the_tool.", ): agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore def test_tool_no_args_specified_assumes_str() -> None: """Older tools could assume *args and **kwargs were passed in.""" def ambiguous_function(*args: Any, **kwargs: Any) -> str: """An ambiguously defined function.""" return args[0] some_tool = Tool( name="chain_run", description="Run the chain", func=ambiguous_function, ) expected_args = {"tool_input": {"type": "string"}} assert some_tool.args == expected_args assert some_tool.run("foobar") == "foobar" assert some_tool.run({"tool_input": "foobar"}) == "foobar" with pytest.raises(ToolException, match="Too many arguments to single-input tool"): some_tool.run({"tool_input": "foobar", "other_input": "bar"}) def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None: """Test load_tools raises a deprecation for old callback manager kwarg.""" callback_manager = MagicMock() with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"): tools = load_tools( ["requests_get"], callback_manager=callback_manager, allow_dangerous_tools=True, ) assert len(tools) == 1 assert tools[0].callbacks == callback_manager def test_load_tools_with_callbacks_is_called() -> None: """Test callbacks are called when provided to load_tools fn.""" callbacks = [FakeCallbackHandler()] tools = load_tools( ["requests_get"], # type: ignore callbacks=callbacks, # type: ignore allow_dangerous_tools=True, ) assert len(tools) == 1 # Patch the requests.get() method to return a mock response with unittest.mock.patch( "langchain.requests.TextRequestsWrapper.get", return_value=Mock(text="Hello world!"), ): result = tools[0].run("https://www.google.com") assert result.text == "Hello world!" assert callbacks[0].tool_starts == 1 assert callbacks[0].tool_ends == 1