diff --git a/docs/modules/agents/examples/custom_tools.ipynb b/docs/modules/agents/examples/custom_tools.ipynb index ae814782..0c7339e6 100644 --- a/docs/modules/agents/examples/custom_tools.ipynb +++ b/docs/modules/agents/examples/custom_tools.ipynb @@ -10,15 +10,17 @@ "When constructing your own agent, you will need to provide it with a list of Tools that it can use. A Tool is defined as below.\n", "\n", "```python\n", - "class Tool(NamedTuple):\n", + "@dataclass \n", + "class Tool:\n", " \"\"\"Interface for tools.\"\"\"\n", "\n", " name: str\n", " func: Callable[[str], str]\n", " description: Optional[str] = None\n", + " return_direct: bool = True\n", "```\n", "\n", - "The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all." + "The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all. You can create these tools directly, but we also provide a decorator to easily convert any function into a tool." ] }, { @@ -151,6 +153,94 @@ "agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")" ] }, + { + "cell_type": "markdown", + "id": "824eaf74", + "metadata": {}, + "source": [ + "## Using the `tool` decorator\n", + "\n", + "To make it easier to define custom tools, a `@tool` decorator is provided. This decorator can be used to quickly create a `Tool` from a simple function. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. Additionally, the decorator will use the function's docstring as the tool's description." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8f15307d", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents import tool\n", + "\n", + "@tool\n", + "def search_api(query: str) -> str:\n", + " \"\"\"Searches the API for the query.\"\"\"\n", + " return \"Results\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0a23b91b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tool(name='search_api', func=, description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "search_api" + ] + }, + { + "cell_type": "markdown", + "id": "cc6ee8c1", + "metadata": {}, + "source": [ + "You can also provide arguments like the tool name and whether to return directly." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "28cdf04d", + "metadata": {}, + "outputs": [], + "source": [ + "@tool(\"search\", return_direct=True)\n", + "def search_api(query: str) -> str:\n", + " \"\"\"Searches the API for the query.\"\"\"\n", + " return \"Results\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1085a4bd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tool(name='search', func=, description='search(query: str) -> str - Searches the API for the query.', return_direct=True)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "search_api" + ] + }, { "cell_type": "markdown", "id": "1d0430d6", @@ -432,7 +522,7 @@ }, "vscode": { "interpreter": { - "hash": "cb23c3a7a387ab03496baa08507270f8e0861b23170e79d5edc545893cdca840" + "hash": "e90c8aa204a57276aa905271aff2d11799d0acb3547adabc5892e639a5e45e34" } } }, diff --git a/langchain/agents/__init__.py b/langchain/agents/__init__.py index b25d17a1..314814b7 100644 --- a/langchain/agents/__init__.py +++ b/langchain/agents/__init__.py @@ -7,7 +7,7 @@ from langchain.agents.loading import load_agent from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent from langchain.agents.react.base import ReActChain, ReActTextWorldAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain -from langchain.agents.tools import Tool +from langchain.agents.tools import Tool, tool __all__ = [ "MRKLChain", @@ -16,6 +16,7 @@ __all__ = [ "AgentExecutor", "Agent", "Tool", + "tool", "initialize_agent", "ZeroShotAgent", "ReActTextWorldAgent", diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index cdcd0650..fb65bdbc 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -1,6 +1,7 @@ """Interface for tools.""" from dataclasses import dataclass -from typing import Callable, Optional +from inspect import signature +from typing import Any, Callable, Optional, Union @dataclass @@ -11,3 +12,65 @@ class Tool: func: Callable[[str], str] description: Optional[str] = None return_direct: bool = False + + def __call__(self, *args: Any, **kwargs: Any) -> str: + """Make tools callable by piping through to `func`.""" + return self.func(*args, **kwargs) + + +def tool( + *args: Union[str, Callable], return_direct: bool = False +) -> Union[Callable, Tool]: + """Make tools out of functions, can be used with or without arguments. + + Requires: + - Function must be of type (str) -> str + - Function must have a docstring + + Examples: + .. code-block:: python + + @tool + def search_api(query: str) -> str: + # Searches the API for the query. + return + + @tool("search", return_direct=True) + def search_api(query: str) -> str: + # Searches the API for the query. + return + """ + + def _make_with_name(tool_name: str) -> Callable: + def _make_tool(func: Callable[[str], str]) -> Tool: + assert func.__doc__, "Function must have a docstring" + # Description example: + # search_api(query: str) - Searches the API for the query. + description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" + tool = Tool( + name=tool_name, + func=func, + description=description, + return_direct=return_direct, + ) + return tool + + return _make_tool + + if len(args) == 1 and isinstance(args[0], str): + # if the argument is a string, then we use the string as the tool name + # Example usage: @tool("search", return_direct=True) + return _make_with_name(args[0]) + elif len(args) == 1 and callable(args[0]): + # if the argument is a function, then we use the function name as the tool name + # Example usage: @tool + return _make_with_name(args[0].__name__)(args[0]) + elif len(args) == 0: + # if there are no arguments, then we use the function name as the tool name + # Example usage: @tool(return_direct=True) + def _partial(func: Callable[[str], str]) -> Tool: + return _make_with_name(func.__name__)(func) + + return _partial + else: + raise ValueError("Too many arguments for tool decorator") diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py new file mode 100644 index 00000000..f3f27d85 --- /dev/null +++ b/tests/unit_tests/agents/test_tools.py @@ -0,0 +1,67 @@ +"""Test tool utils.""" +import pytest + +from langchain.agents.tools import Tool, tool + + +def test_unnamed_decorator() -> None: + """Test functionality with unnamed decorator.""" + + @tool + def search_api(query: str) -> str: + """Search the API for the query.""" + return "API result" + + assert isinstance(search_api, Tool) + assert search_api.name == "search_api" + assert not search_api.return_direct + assert search_api("test") == "API result" + + +def test_named_tool_decorator() -> None: + """Test functionality when arguments are provided as input to decorator.""" + + @tool("search") + def search_api(query: str) -> str: + """Search the API for the query.""" + return "API result" + + assert isinstance(search_api, Tool) + assert search_api.name == "search" + assert not search_api.return_direct + + +def test_named_tool_decorator_return_direct() -> None: + """Test functionality when arguments and return direct are provided as input.""" + + @tool("search", return_direct=True) + def search_api(query: str) -> str: + """Search the API for the query.""" + return "API result" + + assert isinstance(search_api, Tool) + assert search_api.name == "search" + assert search_api.return_direct + + +def test_unnamed_tool_decorator_return_direct() -> None: + """Test functionality when only return direct is provided.""" + + @tool(return_direct=True) + def search_api(query: str) -> str: + """Search the API for the query.""" + return "API result" + + assert isinstance(search_api, Tool) + assert search_api.name == "search_api" + assert search_api.return_direct + + +def test_missing_docstring() -> None: + """Test error is raised when docstring is missing.""" + # expect to throw a value error if theres no docstring + with pytest.raises(AssertionError): + + @tool + def search_api(query: str) -> str: + return "API result"