2023-01-29 02:26:24 +00:00
|
|
|
"""Test tool utils."""
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from langchain.agents.tools import Tool, tool
|
|
|
|
|
|
|
|
|
|
|
|
def test_unnamed_decorator() -> None:
    """Check that a bare ``@tool`` decorator turns a function into a ``Tool``."""

    @tool
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    # The decorated object must be a Tool that carries the function's name.
    assert isinstance(search_api, Tool)
    assert search_api.name == "search_api"

    # No ``return_direct`` was requested, so the flag must be falsy.
    assert not search_api.return_direct

    # Calling the tool is expected to yield the wrapped function's result.
    outcome = search_api("test")
    assert outcome == "API result"
|
|
|
|
|
|
|
|
|
|
|
|
def test_named_tool_decorator() -> None:
    """Check that ``@tool("search")`` produces a ``Tool`` with that explicit name."""

    # Build the decorator first, then apply it; same as ``@tool("search")``.
    named_search = tool("search")

    @named_search
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)

    # The name passed to the decorator is expected, not the function's own name.
    assert search_api.name == "search"

    # ``return_direct`` was not supplied, so it must be falsy.
    assert not search_api.return_direct
|
|
|
|
|
|
|
|
|
|
|
|
def test_named_tool_decorator_return_direct() -> None:
    """Check ``@tool`` when both a name and ``return_direct=True`` are given."""

    # Build the decorator first, then apply it to the function below.
    direct_search = tool("search", return_direct=True)

    @direct_search
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)

    # Both decorator arguments must be reflected on the resulting tool.
    assert search_api.name == "search"
    assert search_api.return_direct
|
|
|
|
|
|
|
|
|
|
|
|
def test_unnamed_tool_decorator_return_direct() -> None:
    """Check ``@tool(return_direct=True)`` when no explicit name is given."""

    # Build the decorator first, then apply it to the function below.
    direct_only = tool(return_direct=True)

    @direct_only
    def search_api(query: str) -> str:
        """Search the API for the query."""
        return "API result"

    assert isinstance(search_api, Tool)

    # With no name supplied, the function's own name is expected.
    assert search_api.name == "search_api"
    assert search_api.return_direct
|
|
|
|
|
|
|
|
|
|
|
|
def test_missing_docstring() -> None:
    """Check that ``tool`` rejects a function that has no docstring."""

    def search_api(query: str) -> str:
        return "API result"

    # Applying ``tool`` to a docstring-less function must raise AssertionError.
    with pytest.raises(AssertionError):
        tool(search_api)
|
2023-02-18 21:40:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_create_tool_posistional_args() -> None:
    """Check that ``Tool`` can be built from positional arguments alone."""

    def _echo(x: str) -> str:
        return x

    # Positional order exercised here: name, callable, description.
    positional_tool = Tool("test_name", _echo, "test_description")

    echoed = positional_tool("foo")
    assert echoed == "foo"
    assert positional_tool.name == "test_name"
    assert positional_tool.description == "test_description"
|
|
|
|
|
|
|
|
|
|
|
|
def test_create_tool_keyword_args() -> None:
    """Check that ``Tool`` can be built from keyword arguments."""

    def _echo(x: str) -> str:
        return x

    keyword_tool = Tool(
        name="test_name",
        func=_echo,
        description="test_description",
    )

    echoed = keyword_tool("foo")
    assert echoed == "foo"
    assert keyword_tool.name == "test_name"
    assert keyword_tool.description == "test_description"
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_async_tool() -> None:
    """Check that a ``Tool`` accepts a coroutine alongside its sync function."""

    async def _echo_async(x: str) -> str:
        return x

    async_tool = Tool(
        name="test_name",
        func=lambda x: x,
        description="test_description",
        coroutine=_echo_async,
    )

    # The synchronous call path must still work.
    sync_result = async_tool("foo")
    assert sync_result == "foo"
    assert async_tool.name == "test_name"
    assert async_tool.description == "test_description"

    # The coroutine must be stored on the tool and usable through ``arun``.
    assert async_tool.coroutine is not None
    async_result = await async_tool.arun("foo")
    assert async_result == "foo"
|