Dynamic tool -> single purpose (#3697)

I think the logic of
https://github.com/hwchase17/langchain/pull/3684#pullrequestreview-1405358565
is too confusing.

I prefer this alternative because:
- All `Tool()` implementations by default will be treated the same as
before. No breaking changes.
- Less reliance on pydantic magic
- The decorator (which is only typed as returning a callable) can infer
the schema and generate a structured tool
- Either way, the recommended way to create a custom tool is through
inheriting from the base tool
This commit is contained in:
Zander Chase 2023-04-28 09:38:41 -07:00 committed by GitHub
parent 1bf1c37c0c
commit da7b51455c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 164 additions and 58 deletions

View File

@ -1,15 +1,10 @@
"""Interface for tools.""" """Interface for tools."""
from functools import partial from functools import partial
from inspect import signature from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from typing import Any, Awaitable, Callable, Optional, Type, Union
from pydantic import BaseModel, validate_arguments, validator from pydantic import BaseModel, validator
from langchain.tools.base import ( from langchain.tools.base import BaseTool, StructuredTool
BaseTool,
create_schema_from_function,
get_filtered_args,
)
class Tool(BaseTool): class Tool(BaseTool):
@ -30,17 +25,30 @@ class Tool(BaseTool):
@property @property
def args(self) -> dict: def args(self) -> dict:
"""The tool's input arguments."""
if self.args_schema is not None: if self.args_schema is not None:
return self.args_schema.schema()["properties"] return self.args_schema.schema()["properties"]
else: # For backwards compatibility, if the function signature is ambiguous,
inferred_model = validate_arguments(self.func).model # type: ignore # assume it takes a single string input.
return get_filtered_args(inferred_model, self.func) return {"tool_input": {"type": "string"}}
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
    """Normalize ``tool_input`` into exactly one positional argument.

    The parent implementation splits the input into positional and keyword
    arguments; this override flattens both into a single positional tuple
    and insists that it holds exactly one value.

    Args:
        tool_input: Either a plain string or a dict of named inputs.

    Returns:
        A ``(args, kwargs)`` pair where ``args`` holds the single input and
        ``kwargs`` is always empty.

    Raises:
        ValueError: If the input resolves to zero values or to more than one.
    """
    args, kwargs = super()._to_args_and_kwargs(tool_input)
    # For backwards compatibility. The tool must be run with a single input
    all_args = list(args) + list(kwargs.values())
    if not all_args:
        # An empty dict yields no values at all; reporting that as
        # "too many arguments" would send the caller looking the wrong way.
        raise ValueError(
            f"Too few arguments to single-input tool {self.name}."
            " Expected exactly one input, got none."
        )
    if len(all_args) > 1:
        raise ValueError(
            f"Too many arguments to single-input tool {self.name}."
            f" Args: {all_args}"
        )
    return tuple(all_args), {}
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool.""" """Use the tool."""
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
async def _arun(self, *args: Any, **kwargs: Any) -> str: async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
if self.coroutine: if self.coroutine:
return await self.coroutine(*args, **kwargs) return await self.coroutine(*args, **kwargs)
@ -48,7 +56,7 @@ class Tool(BaseTool):
# TODO: this is for backwards compatibility, remove in future # TODO: this is for backwards compatibility, remove in future
def __init__( def __init__(
self, name: str, func: Callable[[str], str], description: str, **kwargs: Any self, name: str, func: Callable, description: str, **kwargs: Any
) -> None: ) -> None:
"""Initialize tool.""" """Initialize tool."""
super(Tool, self).__init__( super(Tool, self).__init__(
@ -107,22 +115,24 @@ def tool(
""" """
def _make_with_name(tool_name: str) -> Callable: def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable) -> Tool: def _make_tool(func: Callable) -> BaseTool:
assert func.__doc__, "Function must have a docstring" if infer_schema or args_schema is not None:
# Description example: return StructuredTool.from_function(
# search_api(query: str) - Searches the API for the query. func,
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" name=tool_name,
_args_schema = args_schema return_direct=return_direct,
if _args_schema is None and infer_schema: args_schema=args_schema,
_args_schema = create_schema_from_function(f"{tool_name}Schema", func) infer_schema=infer_schema,
tool_ = Tool( )
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
assert func.__doc__ is not None, "Function must have a docstring"
return Tool(
name=tool_name, name=tool_name,
func=func, func=func,
args_schema=_args_schema, description=f"{tool_name} tool",
description=description,
return_direct=return_direct, return_direct=return_direct,
) )
return tool_
return _make_tool return _make_tool

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
@ -19,15 +19,6 @@ from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input
class SchemaAnnotationError(TypeError): class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation.""" """Raised when 'args_schema' is missing or has an incorrect type annotation."""
@ -81,14 +72,20 @@ def _create_subset_model(
return create_model(name, **fields) # type: ignore return create_model(name, **fields) # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict: def get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
"""Get the arguments from a function's signature.""" """Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"] schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys} return {k: schema[k] for k in valid_keys}
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]: def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.""" """Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip # Pydantic adds placeholder virtual fields we need to strip
@ -102,12 +99,23 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
"""Interface LangChain tools must implement.""" """Interface LangChain tools must implement."""
name: str name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str description: str
"""Used to tell the model how/when/why to use the tool.
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[Type[BaseModel]] = None args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments.""" """Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means
that after the tool is called, the AgentExecutor will stop looping.
"""
verbose: bool = False verbose: bool = False
"""Whether to log the tool's progress."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
"""Callback manager for this tool."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -160,6 +168,14 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
async def _arun(self, *args: Any, **kwargs: Any) -> Any: async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
return (), tool_input
def run( def run(
self, self,
tool_input: Union[str, Dict], tool_input: Union[str, Dict],
@ -182,7 +198,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
**kwargs, **kwargs,
) )
try: try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input) tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs) observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_) self.callback_manager.on_tool_error(e, verbose=verbose_)
@ -224,8 +240,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
) )
try: try:
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input) tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs) observation = await self._arun(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_) await self.callback_manager.on_tool_error(e, verbose=verbose_)
@ -249,3 +265,62 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
def __call__(self, tool_input: Union[str, dict]) -> Any: def __call__(self, tool_input: Union[str, dict]) -> Any:
"""Make tool callable.""" """Make tool callable."""
return self.run(tool_input) return self.run(tool_input)
class StructuredTool(BaseTool):
    """Tool that can operate on any number of inputs.

    Wraps a plain callable (and, optionally, an async counterpart) together
    with a pydantic ``args_schema`` describing the arguments it accepts.
    """

    description: str = ""
    args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
    """The input arguments' schema."""
    func: Callable[..., str]
    """The function to run when the tool is called."""
    coroutine: Optional[Callable[..., Awaitable[str]]] = None
    """The asynchronous version of the function."""

    @property
    def args(self) -> dict:
        """The tool's input arguments."""
        return self.args_schema.schema()["properties"]

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        """Use the tool."""
        return self.func(*args, **kwargs)

    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        """Use the tool asynchronously.

        Raises:
            NotImplementedError: If no ``coroutine`` was supplied.
        """
        if self.coroutine:
            return await self.coroutine(*args, **kwargs)
        raise NotImplementedError("Tool does not support async")

    @classmethod
    def from_function(
        cls,
        func: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        infer_schema: bool = True,
        **kwargs: Any,
    ) -> StructuredTool:
        """Build a structured tool from a plain function.

        Args:
            func: The callable the tool will invoke.
            name: Tool name; falls back to ``func.__name__``.
            description: Tool description; falls back to ``func.__doc__``.
            return_direct: Forwarded unchanged to the tool constructor.
            args_schema: Explicit schema for the tool's arguments.
            infer_schema: When no ``args_schema`` is given, derive one from
                the signature of ``func``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            The constructed tool. Its description is prefixed with the
            tool name and the signature of ``func``.

        Raises:
            AssertionError: If neither ``description`` nor a docstring on
                ``func`` is available.
        """
        name = name or func.__name__
        description = description or func.__doc__
        if description is None:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; the exception type and message are
            # kept for callers that already expect an AssertionError.
            raise AssertionError(
                "Function must have a docstring if description not provided."
            )

        # Description example:
        # search_api(query: str) - Searches the API for the query.
        description = f"{name}{signature(func)} - {description.strip()}"
        _args_schema = args_schema
        if _args_schema is None and infer_schema:
            _args_schema = create_schema_from_function(f"{name}Schema", func)
        # NOTE(review): with args_schema=None and infer_schema=False,
        # _args_schema is still None here although the field is declared
        # required - presumably rejected by model validation; confirm.
        return cls(
            name=name,
            func=func,
            args_schema=_args_schema,
            description=description,
            return_direct=return_direct,
            **kwargs,
        )

View File

@ -1,7 +1,7 @@
"""Test tool utils.""" """Test tool utils."""
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Optional, Type, Union from typing import Any, Optional, Type, Union
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pydantic import pydantic
@ -16,7 +16,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool, tool from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool, SchemaAnnotationError from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool
def test_unnamed_decorator() -> None: def test_unnamed_decorator() -> None:
@ -27,7 +27,7 @@ def test_unnamed_decorator() -> None:
"""Search the API for the query.""" """Search the API for the query."""
return "API result" return "API result"
assert isinstance(search_api, Tool) assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api" assert search_api.name == "search_api"
assert not search_api.return_direct assert not search_api.return_direct
assert search_api("test") == "API result" assert search_api("test") == "API result"
@ -145,7 +145,7 @@ def test_decorator_with_specified_schema() -> None:
"""Return the arguments directly.""" """Return the arguments directly."""
return f"{arg1} {arg2} {arg3}" return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func, Tool) assert isinstance(tool_func, BaseTool)
assert tool_func.args_schema == _MockSchema assert tool_func.args_schema == _MockSchema
@ -159,7 +159,7 @@ def test_decorated_function_schema_equivalent() -> None:
"""Return the arguments directly.""" """Return the arguments directly."""
return f"{arg1} {arg2} {arg3}" return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, Tool) assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.args_schema is not None assert structured_tool_input.args_schema is not None
assert ( assert (
structured_tool_input.args_schema.schema()["properties"] structured_tool_input.args_schema.schema()["properties"]
@ -171,14 +171,14 @@ def test_decorated_function_schema_equivalent() -> None:
def test_structured_args_decorator_no_infer_schema() -> None: def test_structured_args_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator.""" """Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False) @tool
def structured_tool_input( def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str: ) -> str:
"""Return the arguments directly.""" """Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}" return f"{arg1}, {arg2}, {opt_arg}"
assert isinstance(structured_tool_input, Tool) assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.name == "structured_tool_input" assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}} args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}" expected_result = "1, 0.001, {'foo': 'bar'}"
@ -193,8 +193,9 @@ def test_structured_single_str_decorator_no_infer_schema() -> None:
"""Return the arguments directly.""" """Return the arguments directly."""
return f"{tool_input}" return f"{tool_input}"
assert isinstance(unstructured_tool_input, Tool) assert isinstance(unstructured_tool_input, BaseTool)
assert unstructured_tool_input.args_schema is None assert unstructured_tool_input.args_schema is None
assert unstructured_tool_input.run("foo") == "foo"
def test_base_tool_inheritance_base_schema() -> None: def test_base_tool_inheritance_base_schema() -> None:
@ -225,18 +226,18 @@ def test_tool_lambda_args_schema() -> None:
func=lambda tool_input: tool_input, func=lambda tool_input: tool_input,
) )
assert tool.args_schema is None assert tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input"}} expected_args = {"tool_input": {"type": "string"}}
assert tool.args == expected_args assert tool.args == expected_args
def test_tool_lambda_multi_args_schema() -> None: def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function.""" """Test args schema inference when the tool argument is a lambda function."""
tool = Tool( tool = StructuredTool.from_function(
name="tool", name="tool",
description="A tool", description="A tool",
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
) )
assert tool.args_schema is None assert tool.args_schema is not None
expected_args = { expected_args = {
"tool_input": {"title": "Tool Input"}, "tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"}, "other_arg": {"title": "Other Arg"},
@ -268,7 +269,7 @@ def test_empty_args_decorator() -> None:
"""Return a constant.""" """Return a constant."""
return "the empty result" return "the empty result"
assert isinstance(empty_tool_input, Tool) assert isinstance(empty_tool_input, BaseTool)
assert empty_tool_input.name == "empty_tool_input" assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {} assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result" assert empty_tool_input.run({}) == "the empty result"
@ -282,7 +283,7 @@ def test_named_tool_decorator() -> None:
"""Search the API for the query.""" """Search the API for the query."""
return "API result" return "API result"
assert isinstance(search_api, Tool) assert isinstance(search_api, BaseTool)
assert search_api.name == "search" assert search_api.name == "search"
assert not search_api.return_direct assert not search_api.return_direct
@ -295,7 +296,7 @@ def test_named_tool_decorator_return_direct() -> None:
"""Search the API for the query.""" """Search the API for the query."""
return "API result" return "API result"
assert isinstance(search_api, Tool) assert isinstance(search_api, BaseTool)
assert search_api.name == "search" assert search_api.name == "search"
assert search_api.return_direct assert search_api.return_direct
@ -308,7 +309,7 @@ def test_unnamed_tool_decorator_return_direct() -> None:
"""Search the API for the query.""" """Search the API for the query."""
return "API result" return "API result"
assert isinstance(search_api, Tool) assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api" assert search_api.name == "search_api"
assert search_api.return_direct assert search_api.return_direct
@ -325,7 +326,7 @@ def test_tool_with_kwargs() -> None:
"""Search the API for the query.""" """Search the API for the query."""
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}" return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
assert isinstance(search_api, Tool) assert isinstance(search_api, BaseTool)
result = search_api.run( result = search_api.run(
tool_input={ tool_input={
"arg_0": "foo", "arg_0": "foo",
@ -423,3 +424,23 @@ def test_single_input_agent_raises_error_on_structured_tool(
f" multi-input tool {the_tool.name}.", f" multi-input tool {the_tool.name}.",
): ):
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore
def test_tool_no_args_specified_assumes_str() -> None:
    """A ``Tool`` wrapping a ``*args, **kwargs`` callable takes one string input."""

    def ambiguous_function(*args: Any, **kwargs: Any) -> str:
        """An ambiguously defined function."""
        return args[0]

    legacy_tool = Tool(
        func=ambiguous_function,
        name="chain_run",
        description="Run the chain",
    )

    # Without an explicit schema the tool advertises a single string input.
    assert legacy_tool.args == {"tool_input": {"type": "string"}}

    # A bare string and a one-key dict must both reach the function as the
    # sole positional argument.
    for single_input in ("foobar", {"tool_input": "foobar"}):
        assert legacy_tool.run(single_input) == "foobar"

    # More than one value cannot be mapped onto a single-input tool.
    with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
        legacy_tool.run({"tool_input": "foobar", "other_input": "bar"})