You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/tests/unit_tests/agents/test_tools.py

195 lines
5.7 KiB
Python

"""Test tool utils."""
from datetime import datetime
from typing import Optional, Type, Union
import pytest
from pydantic import BaseModel
from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool
def test_unnamed_decorator() -> None:
    """Check that a bare ``@tool`` decorator turns a function into a ``Tool``."""

    @tool
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    # Without decorator arguments the tool takes the function's own name.
    assert isinstance(search_api, Tool)
    assert search_api.name == "search_api"
    assert not search_api.return_direct

    # Calling the tool object directly forwards to the wrapped function.
    output = search_api("test")
    assert output == "API result"
# Argument schema shared by the structured-tool tests below: two required
# fields (``arg1``, ``arg2``) and one optional dict that defaults to ``None``.
class _MockSchema(BaseModel):
    arg1: int
    arg2: bool
    arg3: Optional[dict] = None
# Minimal ``BaseTool`` subclass whose arguments are described by ``_MockSchema``.
class _MockStructuredTool(BaseTool):
    name = "structured_api"
    args_schema: Type[BaseModel] = _MockSchema
    description = "A Structured Tool"

    def _run(
        self,
        arg1: int,
        arg2: bool,
        arg3: Optional[dict] = None,
    ) -> str:
        """Return the three arguments rendered as one space-separated string."""
        rendered = f"{arg1} {arg2} {arg3}"
        return rendered

    async def _arun(
        self,
        arg1: int,
        arg2: bool,
        arg3: Optional[dict] = None,
    ) -> str:
        """Async entry point; this mock does not implement it."""
        raise NotImplementedError
def test_structured_args() -> None:
    """Run a schema-backed tool with a dict input and check the formatted output."""
    tool_instance = _MockStructuredTool()
    assert isinstance(tool_instance, BaseTool)
    assert tool_instance.name == "structured_api"

    # The dict keys correspond to the parameters of ``_MockStructuredTool._run``.
    tool_input = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
    assert tool_instance.run(tool_input) == "1 True {'foo': 'bar'}"
def test_structured_args_decorator() -> None:
    """Check that ``@tool`` on a multi-argument function accepts a dict input."""

    @tool
    def structured_tool_input(
        arg1: int,
        arg2: Union[float, datetime],
        opt_arg: Optional[dict] = None,
    ) -> str:
        """Return the arguments directly."""
        return f"{arg1}, {arg2}, {opt_arg}"

    assert isinstance(structured_tool_input, Tool)
    assert structured_tool_input.name == "structured_tool_input"

    # Supply every parameter, including the optional one, by name.
    payload = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
    outcome = structured_tool_input.run(payload)
    assert outcome == "1, 0.001, {'foo': 'bar'}"
def test_empty_args_decorator() -> None:
    """Check that ``@tool`` works on a function that takes no arguments."""

    @tool
    def empty_tool_input() -> str:
        """Return a constant."""
        return "the empty result"

    assert isinstance(empty_tool_input, Tool)
    assert empty_tool_input.name == "empty_tool_input"

    # An empty dict is the input for a tool with no parameters.
    no_args_result = empty_tool_input.run({})
    assert no_args_result == "the empty result"
def test_named_tool_decorator() -> None:
    """Check that a name passed to ``@tool(...)`` overrides the function name."""

    @tool("search")
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)
    # The explicit name wins over ``search_api``; return_direct stays falsy.
    assert search_api.name == "search"
    assert not search_api.return_direct
def test_named_tool_decorator_return_direct() -> None:
    """Check ``@tool`` given both an explicit name and ``return_direct=True``."""

    @tool("search", return_direct=True)
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)
    # Both decorator arguments must be reflected on the resulting tool.
    assert search_api.name == "search"
    assert search_api.return_direct
def test_unnamed_tool_decorator_return_direct() -> None:
    """Check ``@tool(return_direct=True)`` when no explicit name is given."""

    @tool(return_direct=True)
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)
    # The name falls back to the function name while return_direct is set.
    assert search_api.name == "search_api"
    assert search_api.return_direct
def test_tool_with_kwargs() -> None:
    """Test that dict input maps onto keyword arguments, honoring defaults."""

    @tool(return_direct=True)
    def search_api(
        arg_1: float,
        ping: str = "hi",
    ) -> str:
        """Search the API for the query."""
        return f"arg_1={arg_1}, ping={ping}"

    assert isinstance(search_api, Tool)

    # Case 1: every argument is supplied explicitly.
    result = search_api.run(tool_input={"arg_1": 3.2, "ping": "pong"})
    assert result == "arg_1=3.2, ping=pong"

    # Case 2: ``ping`` is omitted, so the function's default ("hi") is used.
    result = search_api.run(tool_input={"arg_1": 3.2})
    assert result == "arg_1=3.2, ping=hi"
def test_missing_docstring() -> None:
    """Test error is raised when docstring is missing."""
    # Decorating a function that has no docstring is expected to fail; this
    # test pins the exception type to AssertionError.
    with pytest.raises(AssertionError):

        @tool
        def search_api(query: str) -> str:
            return "API result"
def test_create_tool_positional_args() -> None:
    """Build a ``Tool`` from positional name, func, and description arguments."""
    echo_tool = Tool("test_name", lambda x: x, "test_description")

    # The identity lambda returns its input unchanged.
    assert echo_tool("foo") == "foo"
    assert echo_tool.name == "test_name"
    assert echo_tool.description == "test_description"
def test_create_tool_keyword_args() -> None:
    """Build a ``Tool`` from keyword name, func, and description arguments."""
    echo_tool = Tool(
        name="test_name",
        func=lambda x: x,
        description="test_description",
    )

    # The identity lambda returns its input unchanged.
    assert echo_tool("foo") == "foo"
    assert echo_tool.name == "test_name"
    assert echo_tool.description == "test_description"
@pytest.mark.asyncio
async def test_create_async_tool() -> None:
    """Build a ``Tool`` with both a sync func and a coroutine and exercise each."""

    async def _async_echo(x: str) -> str:
        return x

    test_tool = Tool(
        name="test_name",
        func=lambda x: x,
        description="test_description",
        coroutine=_async_echo,
    )

    # Synchronous path: calling the tool goes through ``func``.
    assert test_tool("foo") == "foo"
    assert test_tool.name == "test_name"
    assert test_tool.description == "test_description"

    # Asynchronous path: the coroutine is stored and ``arun`` returns its result.
    assert test_tool.coroutine is not None
    async_result = await test_tool.arun("foo")
    assert async_result == "foo"