Harrison/tool decorator (#790)

Co-authored-by: Jason Liu <jxnl@users.noreply.github.com>
Co-authored-by: Jason Liu <jason@jxnl.coA>
ankush/async-llmchain
Harrison Chase 1 year ago committed by GitHub
parent 5f73d06502
commit 1ad7973cc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,15 +10,17 @@
"When constructing your own agent, you will need to provide it with a list of Tools that it can use. A Tool is defined as below.\n",
"\n",
"```python\n",
"class Tool(NamedTuple):\n",
"@dataclass \n",
"class Tool:\n",
" \"\"\"Interface for tools.\"\"\"\n",
"\n",
" name: str\n",
" func: Callable[[str], str]\n",
" description: Optional[str] = None\n",
" return_direct: bool = True\n",
"```\n",
"\n",
"The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all."
"The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all. You can create these tools directly, but we also provide a decorator to easily convert any function into a tool."
]
},
{
@ -151,6 +153,94 @@
"agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")"
]
},
{
"cell_type": "markdown",
"id": "824eaf74",
"metadata": {},
"source": [
"## Using the `tool` decorator\n",
"\n",
"To make it easier to define custom tools, a `@tool` decorator is provided. This decorator can be used to quickly create a `Tool` from a simple function. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. Additionally, the decorator will use the function's docstring as the tool's description."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8f15307d",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import tool\n",
"\n",
"@tool\n",
"def search_api(query: str) -> str:\n",
" \"\"\"Searches the API for the query.\"\"\"\n",
" return \"Results\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0a23b91b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tool(name='search_api', func=<function search_api at 0x10dad7d90>, description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"search_api"
]
},
{
"cell_type": "markdown",
"id": "cc6ee8c1",
"metadata": {},
"source": [
"You can also provide arguments like the tool name and whether to return directly."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "28cdf04d",
"metadata": {},
"outputs": [],
"source": [
"@tool(\"search\", return_direct=True)\n",
"def search_api(query: str) -> str:\n",
" \"\"\"Searches the API for the query.\"\"\"\n",
" return \"Results\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1085a4bd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tool(name='search', func=<function search_api at 0x112301bd0>, description='search(query: str) -> str - Searches the API for the query.', return_direct=True)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"search_api"
]
},
{
"cell_type": "markdown",
"id": "1d0430d6",
@ -432,7 +522,7 @@
},
"vscode": {
"interpreter": {
"hash": "cb23c3a7a387ab03496baa08507270f8e0861b23170e79d5edc545893cdca840"
"hash": "e90c8aa204a57276aa905271aff2d11799d0acb3547adabc5892e639a5e45e34"
}
}
},

@ -7,7 +7,7 @@ from langchain.agents.loading import load_agent
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.agents.tools import Tool
from langchain.agents.tools import Tool, tool
__all__ = [
"MRKLChain",
@ -16,6 +16,7 @@ __all__ = [
"AgentExecutor",
"Agent",
"Tool",
"tool",
"initialize_agent",
"ZeroShotAgent",
"ReActTextWorldAgent",

@ -1,6 +1,7 @@
"""Interface for tools."""
from dataclasses import dataclass
from typing import Callable, Optional
from inspect import signature
from typing import Any, Callable, Optional, Union
@dataclass
@ -11,3 +12,65 @@ class Tool:
func: Callable[[str], str]
description: Optional[str] = None
return_direct: bool = False
def __call__(self, *args: Any, **kwargs: Any) -> str:
"""Make tools callable by piping through to `func`."""
return self.func(*args, **kwargs)
def tool(
*args: Union[str, Callable], return_direct: bool = False
) -> Union[Callable, Tool]:
"""Make tools out of functions, can be used with or without arguments.
Requires:
- Function must be of type (str) -> str
- Function must have a docstring
Examples:
.. code-block:: python
@tool
def search_api(query: str) -> str:
# Searches the API for the query.
return
@tool("search", return_direct=True)
def search_api(query: str) -> str:
# Searches the API for the query.
return
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable[[str], str]) -> Tool:
assert func.__doc__, "Function must have a docstring"
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
tool = Tool(
name=tool_name,
func=func,
description=description,
return_direct=return_direct,
)
return tool
return _make_tool
if len(args) == 1 and isinstance(args[0], str):
# if the argument is a string, then we use the string as the tool name
# Example usage: @tool("search", return_direct=True)
return _make_with_name(args[0])
elif len(args) == 1 and callable(args[0]):
# if the argument is a function, then we use the function name as the tool name
# Example usage: @tool
return _make_with_name(args[0].__name__)(args[0])
elif len(args) == 0:
# if there are no arguments, then we use the function name as the tool name
# Example usage: @tool(return_direct=True)
def _partial(func: Callable[[str], str]) -> Tool:
return _make_with_name(func.__name__)(func)
return _partial
else:
raise ValueError("Too many arguments for tool decorator")

@ -0,0 +1,67 @@
"""Test tool utils."""
import pytest
from langchain.agents.tools import Tool, tool
def test_unnamed_decorator() -> None:
"""Test functionality with unnamed decorator."""
@tool
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert search_api.name == "search_api"
assert not search_api.return_direct
assert search_api("test") == "API result"
def test_named_tool_decorator() -> None:
"""Test functionality when arguments are provided as input to decorator."""
@tool("search")
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert search_api.name == "search"
assert not search_api.return_direct
def test_named_tool_decorator_return_direct() -> None:
"""Test functionality when arguments and return direct are provided as input."""
@tool("search", return_direct=True)
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert search_api.name == "search"
assert search_api.return_direct
def test_unnamed_tool_decorator_return_direct() -> None:
"""Test functionality when only return direct is provided."""
@tool(return_direct=True)
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert search_api.name == "search_api"
assert search_api.return_direct
def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing."""
# expect to throw a value error if theres no docstring
with pytest.raises(AssertionError):
@tool
def search_api(query: str) -> str:
return "API result"
Loading…
Cancel
Save