langchain/libs/community/tests/unit_tests/agents/test_tools.py

101 lines
3.5 KiB
Python
Raw Normal View History

"""Test tool utils."""
import unittest
from typing import Any, Type
from unittest.mock import MagicMock, Mock
import pytest
from langchain.agents.agent import Agent
from langchain.agents.chat.base import ChatAgent
from langchain.agents.conversational.base import ConversationalAgent
from langchain.agents.conversational_chat.base import ConversationalChatAgent
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
multiple: langchain 0.2 in master (#21191) 0.2rc migrations - [x] Move memory - [x] Move remaining retrievers - [x] graph_qa chains - [x] some dependency from evaluation code potentially on math utils - [x] Move openapi chain from `langchain.chains.api.openapi` to `langchain_community.chains.openapi` - [x] Migrate `langchain.chains.ernie_functions` to `langchain_community.chains.ernie_functions` - [x] migrate `langchain/chains/llm_requests.py` to `langchain_community.chains.llm_requests` - [x] Moving `langchain_community.cross_enoders.base:BaseCrossEncoder` -> `langchain_community.retrievers.document_compressors.cross_encoder:BaseCrossEncoder` (namespace not ideal, but it needs to be moved to `langchain` to avoid circular deps) - [x] unit tests langchain -- add pytest.mark.community to some unit tests that will stay in langchain - [x] unit tests community -- move unit tests that depend on community to community - [x] mv integration tests that depend on community to community - [x] mypy checks Other todo - [x] Make deprecation warnings not noisy (need to use warn deprecated and check that things are implemented properly) - [x] Update deprecation messages with timeline for code removal (likely we actually won't be removing things until 0.4 release) -- will give people more time to transition their code. - [ ] Add information to deprecation warning to show users how to migrate their code base using langchain-cli - [ ] Remove any unnecessary requirements in langchain (e.g., is SQLALchemy required?) --------- Co-authored-by: Erick Friis <erick@langchain.dev>
2024-05-08 20:46:52 +00:00
from langchain_core.tools import Tool, ToolException, tool
from langchain_community.agent_toolkits.load_tools import load_tools
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@pytest.mark.parametrize(
    "agent_cls",
    [
        ZeroShotAgent,
        ChatAgent,
        ConversationalChatAgent,
        ConversationalAgent,
        ReActDocstoreAgent,
        ReActTextWorldAgent,
        SelfAskWithSearchAgent,
    ],
)
def test_single_input_agent_raises_error_on_structured_tool(
    agent_cls: Type[Agent],
) -> None:
    """Test that older agents raise errors on older tools.

    Each parametrized agent class is built from a mocked LLM and a tool whose
    function takes two arguments; construction must raise a ``ValueError``
    whose message names both the agent class and the tool.
    """
    import re

    @tool
    def the_tool(foo: str, bar: str) -> str:
        """Return the concat of foo and bar."""
        return foo + bar

    # ``match`` is applied with ``re.search``; escape the literal message so the
    # trailing "." is matched literally instead of as "any character".
    expected_message = re.escape(
        f"{agent_cls.__name__} does not support multi-input tool the_tool."
    )
    with pytest.raises(ValueError, match=expected_message):
        agent_cls.from_llm_and_tools(MagicMock(), [the_tool])  # type: ignore
def test_tool_no_args_specified_assumes_str() -> None:
    """A Tool wrapping a ``*args, **kwargs`` function behaves as single-input.

    Such a tool advertises one string ``tool_input`` argument, accepts either a
    bare string or a ``{"tool_input": ...}`` dict, and rejects a dict carrying
    any additional keys with a ``ToolException``.
    """

    def echo_first(*positional: Any, **named: Any) -> str:
        """Return the first positional argument; the signature is deliberately vague."""
        return positional[0]

    single_input_tool = Tool(
        func=echo_first,
        name="chain_run",
        description="Run the chain",
    )
    assert single_input_tool.args == {"tool_input": {"type": "string"}}

    # Both the plain-string and the dict form must reach the function unchanged.
    for payload in ("foobar", {"tool_input": "foobar"}):
        assert single_input_tool.run(payload) == "foobar"

    overloaded_input = {"tool_input": "foobar", "other_input": "bar"}
    expected_error = "Too many arguments to single-input tool"
    with pytest.raises(ToolException, match=expected_error):
        single_input_tool.run(overloaded_input)
def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None:
    """Passing the old ``callback_manager`` kwarg to load_tools is deprecated.

    The call must emit a ``DeprecationWarning``, and the supplied manager must
    still end up as the ``callbacks`` of the single loaded tool.
    """
    legacy_manager = MagicMock()

    with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"):
        loaded_tools = load_tools(
            ["requests_get"],
            allow_dangerous_tools=True,
            callback_manager=legacy_manager,
        )

    assert len(loaded_tools) == 1
    first_tool = loaded_tools[0]
    assert first_tool.callbacks == legacy_manager
def test_load_tools_with_callbacks_is_called() -> None:
    """Test callbacks are called when provided to load_tools fn.

    Loads the ``requests_get`` tool with a fake callback handler, stubs the
    wrapper's ``get`` so no real HTTP request is issued, runs the tool once, and
    checks the handler recorded exactly one tool start and one tool end.
    """
    callbacks = [FakeCallbackHandler()]
    tools = load_tools(
        ["requests_get"],  # type: ignore
        callbacks=callbacks,  # type: ignore
        allow_dangerous_tools=True,
    )
    assert len(tools) == 1
    # Stub TextRequestsWrapper.get at its langchain_community definition site
    # rather than through the ``langchain.requests`` path, which belongs to a
    # different package than the code under test.
    with unittest.mock.patch(
        "langchain_community.utilities.requests.TextRequestsWrapper.get",
        return_value=Mock(text="Hello world!"),
    ):
        result = tools[0].run("https://www.google.com")
        assert result.text == "Hello world!"
    assert callbacks[0].tool_starts == 1
    assert callbacks[0].tool_ends == 1