Dynamic tool -> single purpose (#3697)

I think the logic of
https://github.com/hwchase17/langchain/pull/3684#pullrequestreview-1405358565
is too confusing.

I prefer this alternative because:
- All `Tool()` implementations by default will be treated the same as
before. No breaking changes.
- Less reliance on pydantic magic
- The decorator (which is only typed as returning a callable) can infer
a schema and generate a structured tool
- Either way, the recommended way to create a custom tool is by
inheriting from the base tool
This commit is contained in:
Zander Chase 2023-04-28 09:38:41 -07:00 committed by GitHub
parent 1bf1c37c0c
commit da7b51455c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 164 additions and 58 deletions

View File

@ -1,15 +1,10 @@
"""Interface for tools."""
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from pydantic import BaseModel, validate_arguments, validator
from pydantic import BaseModel, validator
from langchain.tools.base import (
BaseTool,
create_schema_from_function,
get_filtered_args,
)
from langchain.tools.base import BaseTool, StructuredTool
class Tool(BaseTool):
@ -30,17 +25,30 @@ class Tool(BaseTool):
@property
def args(self) -> dict:
"""The tool's input arguments."""
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self.func).model # type: ignore
return get_filtered_args(inferred_model, self.func)
# For backwards compatibility, if the function signature is ambiguous,
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}
def _run(self, *args: Any, **kwargs: Any) -> str:
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
    """Normalise ``tool_input`` into exactly one positional argument.

    The parent implementation splits the input into positional and keyword
    arguments. For backwards compatibility this single-input tool then
    flattens both into one positional value, so ``_run`` / ``_arun`` are
    always invoked with a single argument and no keyword arguments.

    Args:
        tool_input: A string, or a dict that holds exactly one value.

    Returns:
        A one-element tuple of positional arguments and an empty dict of
        keyword arguments.

    Raises:
        ValueError: If the input does not resolve to exactly one argument.
    """
    args, kwargs = super()._to_args_and_kwargs(tool_input)
    # Keyword names are dropped on purpose: only the values are forwarded.
    all_args = list(args) + list(kwargs.values())
    if not all_args:
        # An empty input used to be reported as "Too many arguments",
        # which pointed the caller in the wrong direction.
        raise ValueError(
            f"Missing argument to single-input tool {self.name}."
            f" Expected exactly one argument, got none."
        )
    if len(all_args) > 1:
        raise ValueError(
            f"Too many arguments to single-input tool {self.name}."
            f" Args: {all_args}"
        )
    return tuple(all_args), {}
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
return self.func(*args, **kwargs)
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
@ -48,7 +56,7 @@ class Tool(BaseTool):
# TODO: this is for backwards compatibility, remove in future
def __init__(
self, name: str, func: Callable[[str], str], description: str, **kwargs: Any
self, name: str, func: Callable, description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(
@ -107,22 +115,24 @@ def tool(
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable) -> Tool:
assert func.__doc__, "Function must have a docstring"
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
tool_ = Tool(
def _make_tool(func: Callable) -> BaseTool:
if infer_schema or args_schema is not None:
return StructuredTool.from_function(
func,
name=tool_name,
return_direct=return_direct,
args_schema=args_schema,
infer_schema=infer_schema,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
assert func.__doc__ is not None, "Function must have a docstring"
return Tool(
name=tool_name,
func=func,
args_schema=_args_schema,
description=description,
description=f"{tool_name} tool",
return_direct=return_direct,
)
return tool_
return _make_tool

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from pydantic import (
BaseModel,
@ -19,15 +19,6 @@ from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input
class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
@ -81,14 +72,20 @@ def _create_subset_model(
return create_model(name, **fields) # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
def get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip
@ -102,12 +99,23 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
"""Interface LangChain tools must implement."""
name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool.
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means
that after the tool is called, the AgentExecutor will stop looping.
"""
verbose: bool = False
"""Whether to log the tool's progress."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
"""Callback manager for this tool."""
class Config:
"""Configuration for this pydantic object."""
@ -160,6 +168,14 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
return (), tool_input
def run(
self,
tool_input: Union[str, Dict],
@ -182,7 +198,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
**kwargs,
)
try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
@ -224,8 +240,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = await self._arun(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
@ -249,3 +265,62 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
def __call__(self, tool_input: Union[str, dict]) -> Any:
"""Make tool callable."""
return self.run(tool_input)
class StructuredTool(BaseTool):
    """Tool that can operate on any number of inputs."""

    description: str = ""
    args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
    """The input arguments' schema."""
    func: Callable[..., str]
    """The function to run when the tool is called."""
    coroutine: Optional[Callable[..., Awaitable[str]]] = None
    """The asynchronous version of the function."""

    @property
    def args(self) -> dict:
        """The tool's input arguments."""
        return self.args_schema.schema()["properties"]

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        """Use the tool."""
        return self.func(*args, **kwargs)

    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
        """Use the tool asynchronously."""
        if self.coroutine:
            return await self.coroutine(*args, **kwargs)
        raise NotImplementedError("Tool does not support async")

    @classmethod
    def from_function(
        cls,
        func: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        infer_schema: bool = True,
        **kwargs: Any,
    ) -> StructuredTool:
        """Build a structured tool around ``func``.

        Args:
            func: The callable the tool runs.
            name: Tool name; falls back to ``func.__name__``.
            description: Tool description; falls back to ``func.__doc__``.
            return_direct: Forwarded unchanged to the tool.
            args_schema: Explicit input schema. When omitted, one is derived
                from the signature of ``func`` if ``infer_schema`` is true.
            infer_schema: Whether a missing ``args_schema`` may be inferred.
            **kwargs: Extra fields forwarded to the tool constructor.

        Returns:
            The constructed tool.

        Raises:
            AssertionError: If neither ``description`` nor a docstring exists.
            ValueError: If no ``args_schema`` is given and inference is
                disabled, since this tool requires a schema.
        """
        name = name or func.__name__
        description = description or func.__doc__
        assert (
            description is not None
        ), "Function must have a docstring if description not provided."

        # Description example:
        # search_api(query: str) - Searches the API for the query.
        description = f"{name}{signature(func)} - {description.strip()}"
        _args_schema = args_schema
        if _args_schema is None:
            if not infer_schema:
                # `args_schema` is a required field; fail here with an
                # actionable message rather than letting field validation
                # reject a None value later.
                raise ValueError(
                    f"Cannot create structured tool {name}: no args_schema was"
                    " provided and infer_schema is False."
                )
            _args_schema = create_schema_from_function(f"{name}Schema", func)
        return cls(
            name=name,
            func=func,
            args_schema=_args_schema,
            description=description,
            return_direct=return_direct,
            **kwargs,
        )

View File

@ -1,7 +1,7 @@
"""Test tool utils."""
from datetime import datetime
from functools import partial
from typing import Optional, Type, Union
from typing import Any, Optional, Type, Union
from unittest.mock import MagicMock
import pydantic
@ -16,7 +16,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool, SchemaAnnotationError
from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool
def test_unnamed_decorator() -> None:
@ -27,7 +27,7 @@ def test_unnamed_decorator() -> None:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert not search_api.return_direct
assert search_api("test") == "API result"
@ -145,7 +145,7 @@ def test_decorator_with_specified_schema() -> None:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func, Tool)
assert isinstance(tool_func, BaseTool)
assert tool_func.args_schema == _MockSchema
@ -159,7 +159,7 @@ def test_decorated_function_schema_equivalent() -> None:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, Tool)
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.args_schema is not None
assert (
structured_tool_input.args_schema.schema()["properties"]
@ -171,14 +171,14 @@ def test_decorated_function_schema_equivalent() -> None:
def test_structured_args_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
@tool
def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}"
assert isinstance(structured_tool_input, Tool)
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}"
@ -193,8 +193,9 @@ def test_structured_single_str_decorator_no_infer_schema() -> None:
"""Return the arguments directly."""
return f"{tool_input}"
assert isinstance(unstructured_tool_input, Tool)
assert isinstance(unstructured_tool_input, BaseTool)
assert unstructured_tool_input.args_schema is None
assert unstructured_tool_input.run("foo") == "foo"
def test_base_tool_inheritance_base_schema() -> None:
@ -225,18 +226,18 @@ def test_tool_lambda_args_schema() -> None:
func=lambda tool_input: tool_input,
)
assert tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input"}}
expected_args = {"tool_input": {"type": "string"}}
assert tool.args == expected_args
def test_tool_lambda_multi_args_schema() -> None:
def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
tool = StructuredTool.from_function(
name="tool",
description="A tool",
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
)
assert tool.args_schema is None
assert tool.args_schema is not None
expected_args = {
"tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"},
@ -268,7 +269,7 @@ def test_empty_args_decorator() -> None:
"""Return a constant."""
return "the empty result"
assert isinstance(empty_tool_input, Tool)
assert isinstance(empty_tool_input, BaseTool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result"
@ -282,7 +283,7 @@ def test_named_tool_decorator() -> None:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert not search_api.return_direct
@ -295,7 +296,7 @@ def test_named_tool_decorator_return_direct() -> None:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert search_api.return_direct
@ -308,7 +309,7 @@ def test_unnamed_tool_decorator_return_direct() -> None:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert search_api.return_direct
@ -325,7 +326,7 @@ def test_tool_with_kwargs() -> None:
"""Search the API for the query."""
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
result = search_api.run(
tool_input={
"arg_0": "foo",
@ -423,3 +424,23 @@ def test_single_input_agent_raises_error_on_structured_tool(
f" multi-input tool {the_tool.name}.",
):
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore
def test_tool_no_args_specified_assumes_str() -> None:
    """A ``Tool`` wrapping a ``*args, **kwargs`` callable takes one string input."""

    def ambiguous_function(*args: Any, **kwargs: Any) -> str:
        """An ambiguously defined function."""
        return args[0]

    legacy_tool = Tool(
        func=ambiguous_function,
        description="Run the chain",
        name="chain_run",
    )

    # The advertised arguments fall back to a single string field.
    assert legacy_tool.args == {"tool_input": {"type": "string"}}

    # A bare string and a one-value dict are both accepted ...
    assert legacy_tool.run("foobar") == "foobar"
    assert legacy_tool.run({"tool_input": "foobar"}) == "foobar"

    # ... but a dict carrying more than one value is rejected.
    oversized_input = {"tool_input": "foobar", "other_input": "bar"}
    with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
        legacy_tool.run(oversized_input)