forked from Archives/langchain
90ef705ced
- Remove dynamic model creation in the `args()` property. _Only infer for the decorator (and add an argument to NOT infer if someone wishes to pass only a string)._
- Update the validation example to make it less likely to be misinterpreted as a "safe" way to run a REPL.

There is one example of "Multi-argument tools" in custom_tools.ipynb from yesterday, but we could add more. The output parsing for the base MRKL agent has not yet been adapted to handle structured args.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
195 lines
5.7 KiB
Python
195 lines
5.7 KiB
Python
"""Test tool utils."""
|
|
from datetime import datetime
|
|
from typing import Optional, Type, Union
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from langchain.agents.tools import Tool, tool
|
|
from langchain.tools.base import BaseTool
|
|
|
|
|
|
def test_unnamed_decorator() -> None:
    """Check that a bare ``@tool`` decorator turns a function into a ``Tool``.

    The resulting tool is expected to take its name from the decorated
    function and to leave ``return_direct`` falsy.
    """

    @tool
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    # After decoration the name is bound to a Tool, not a plain function.
    assert isinstance(search_api, Tool)
    assert search_api.name == "search_api"
    assert not search_api.return_direct

    # Calling the tool directly should yield the wrapped function's result.
    output = search_api("test")
    assert output == "API result"
|
|
|
|
|
|
# Argument schema used by the structured-tool tests in this module.
# Documented with comments rather than a docstring so the generated
# pydantic schema is left untouched.
class _MockSchema(BaseModel):
    # Required integer argument.
    arg1: int
    # Required boolean argument.
    arg2: bool
    # Optional dictionary argument; defaults to None when omitted.
    arg3: Optional[dict] = None
|
|
|
|
|
|
# Minimal BaseTool subclass whose arguments are described by _MockSchema.
# The synchronous path echoes its arguments; the async path is unsupported.
class _MockStructuredTool(BaseTool):
    name = "structured_api"
    args_schema: Type[BaseModel] = _MockSchema
    description = "A Structured Tool"

    def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        # Render the three arguments space-separated so tests can compare
        # against a plain string.
        rendered = f"{arg1} {arg2} {arg3}"
        return rendered

    async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
        # Async execution is intentionally not provided by this mock.
        raise NotImplementedError
|
|
|
|
|
|
def test_structured_args() -> None:
    """Check that a ``BaseTool`` subclass can be run with a dict of arguments."""
    mock_tool = _MockStructuredTool()

    assert isinstance(mock_tool, BaseTool)
    assert mock_tool.name == "structured_api"

    # The dict keys match the parameter names of ``_MockStructuredTool._run``.
    tool_input = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
    outcome = mock_tool.run(tool_input)
    assert outcome == "1 True {'foo': 'bar'}"
|
|
|
|
|
|
def test_structured_args_decorator() -> None:
    """Check that ``@tool`` handles a multi-argument function run with a dict."""

    @tool
    def structured_tool_input(
        arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
    ) -> str:
        """Return the arguments directly."""
        return f"{arg1}, {arg2}, {opt_arg}"

    assert isinstance(structured_tool_input, Tool)
    assert structured_tool_input.name == "structured_tool_input"

    # Each key corresponds to one parameter of the decorated function.
    tool_input = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
    outcome = structured_tool_input.run(tool_input)
    assert outcome == "1, 0.001, {'foo': 'bar'}"
|
|
|
|
|
|
def test_empty_args_decorator() -> None:
    """Check that ``@tool`` works on a function that takes no arguments."""

    @tool
    def empty_tool_input() -> str:
        """Return a constant."""
        return "the empty result"

    assert isinstance(empty_tool_input, Tool)
    assert empty_tool_input.name == "empty_tool_input"

    # An empty dict stands in for "no arguments".
    outcome = empty_tool_input.run({})
    assert outcome == "the empty result"
|
|
|
|
|
|
def test_named_tool_decorator() -> None:
    """Check that a name passed to ``@tool`` overrides the function name."""

    @tool("search")
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)

    # The explicit name wins over the function's own name.
    tool_name = search_api.name
    assert tool_name == "search"
    assert not search_api.return_direct
|
|
|
|
|
|
def test_named_tool_decorator_return_direct() -> None:
    """Check ``@tool`` when given both a name and ``return_direct=True``."""

    @tool("search", return_direct=True)
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)

    # Both decorator arguments should be reflected on the resulting tool.
    tool_name = search_api.name
    assert tool_name == "search"
    assert search_api.return_direct
|
|
|
|
|
|
def test_unnamed_tool_decorator_return_direct() -> None:
    """Check ``@tool(return_direct=True)`` when no explicit name is given."""

    @tool(return_direct=True)
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)

    # With no name supplied, the function's own name is expected.
    tool_name = search_api.name
    assert tool_name == "search_api"
    assert search_api.return_direct
|
|
|
|
|
|
def test_tool_with_kwargs() -> None:
    """Test that a decorated tool accepts keyword arguments with defaults.

    The tool is run twice through ``run(tool_input=...)``: once with every
    argument supplied, and once with the defaulted ``ping`` argument left
    out so that its default value is used.
    """

    @tool(return_direct=True)
    def search_api(
        arg_1: float,
        ping: str = "hi",
    ) -> str:
        """Search the API for the query."""
        return f"arg_1={arg_1}, ping={ping}"

    assert isinstance(search_api, Tool)

    # All arguments given explicitly.
    result = search_api.run(
        tool_input={
            "arg_1": 3.2,
            "ping": "pong",
        }
    )
    assert result == "arg_1=3.2, ping=pong"

    # ``ping`` omitted: the function's default ("hi") should be used.
    result = search_api.run(
        tool_input={
            "arg_1": 3.2,
        }
    )
    assert result == "arg_1=3.2, ping=hi"
|
|
|
|
|
|
def test_missing_docstring() -> None:
    """Test error is raised when docstring is missing."""
    # Decorating a function that has no docstring is expected to raise an
    # AssertionError at decoration time, i.e. inside the ``with`` block.
    with pytest.raises(AssertionError):

        @tool
        def search_api(query: str) -> str:
            return "API result"
|
|
|
|
|
|
def test_create_tool_positional_args() -> None:
    """Check that ``Tool`` can be built from positional arguments."""
    # Positional order used here: name, callable, description.
    positional_tool = Tool("test_name", lambda x: x, "test_description")

    echoed = positional_tool("foo")
    assert echoed == "foo"
    assert positional_tool.name == "test_name"
    assert positional_tool.description == "test_description"
|
|
|
|
|
|
def test_create_tool_keyword_args() -> None:
    """Check that ``Tool`` can be built from keyword arguments."""
    keyword_tool = Tool(
        name="test_name",
        func=lambda x: x,
        description="test_description",
    )

    echoed = keyword_tool("foo")
    assert echoed == "foo"
    assert keyword_tool.name == "test_name"
    assert keyword_tool.description == "test_description"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_async_tool() -> None:
    """Check that ``Tool`` accepts a coroutine alongside a sync function."""

    async def _test_func(x: str) -> str:
        return x

    async_tool = Tool(
        name="test_name",
        func=lambda x: x,
        description="test_description",
        coroutine=_test_func,
    )

    # Synchronous call path.
    sync_output = async_tool("foo")
    assert sync_output == "foo"
    assert async_tool.name == "test_name"
    assert async_tool.description == "test_description"

    # Asynchronous call path: the coroutine must be kept and be awaitable
    # through ``arun``.
    assert async_tool.coroutine is not None
    async_output = await async_tool.arun("foo")
    assert async_output == "foo"
|