From da7b51455c3739cec6e2617cee9cd4f24508b569 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Fri, 28 Apr 2023 09:38:41 -0700 Subject: [PATCH] Dynamic tool -> single purpose (#3697) I think the logic of https://github.com/hwchase17/langchain/pull/3684#pullrequestreview-1405358565 is too confusing. I prefer this alternative because: - All `Tool()` implementations by default will be treated the same as before. No breaking changes. - Less reliance on pydantic magic - The decorator (which only is typed as returning a callable) can infer schema and generate a structured tool - Either way, the recommended way to create a custom tool is through inheriting from the base tool --- langchain/agents/tools.py | 62 ++++++++------- langchain/tools/base.py | 105 ++++++++++++++++++++++---- tests/unit_tests/agents/test_tools.py | 55 +++++++++----- 3 files changed, 164 insertions(+), 58 deletions(-) diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 913110ba..7a6637c2 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -1,15 +1,10 @@ """Interface for tools.""" from functools import partial -from inspect import signature -from typing import Any, Awaitable, Callable, Optional, Type, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union -from pydantic import BaseModel, validate_arguments, validator +from pydantic import BaseModel, validator -from langchain.tools.base import ( - BaseTool, - create_schema_from_function, - get_filtered_args, -) +from langchain.tools.base import BaseTool, StructuredTool class Tool(BaseTool): @@ -30,17 +25,30 @@ class Tool(BaseTool): @property def args(self) -> dict: + """The tool's input arguments.""" if self.args_schema is not None: return self.args_schema.schema()["properties"] - else: - inferred_model = validate_arguments(self.func).model # type: ignore - return get_filtered_args(inferred_model, self.func) + # For backwards 
compatibility, if the function signature is ambiguous, + # assume it takes a single string input. + return {"tool_input": {"type": "string"}} + + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + """Convert tool input to positional args and kwargs.""" + args, kwargs = super()._to_args_and_kwargs(tool_input) + # For backwards compatibility. The tool must be run with a single input + all_args = list(args) + list(kwargs.values()) + if len(all_args) != 1: + raise ValueError( + f"Too many arguments to single-input tool {self.name}." + f" Args: {all_args}" + ) + return tuple(all_args), {} - def _run(self, *args: Any, **kwargs: Any) -> str: + def _run(self, *args: Any, **kwargs: Any) -> Any: """Use the tool.""" return self.func(*args, **kwargs) - async def _arun(self, *args: Any, **kwargs: Any) -> str: + async def _arun(self, *args: Any, **kwargs: Any) -> Any: """Use the tool asynchronously.""" if self.coroutine: return await self.coroutine(*args, **kwargs) @@ -48,7 +56,7 @@ class Tool(BaseTool): # TODO: this is for backwards compatibility, remove in future def __init__( - self, name: str, func: Callable[[str], str], description: str, **kwargs: Any + self, name: str, func: Callable, description: str, **kwargs: Any ) -> None: """Initialize tool.""" super(Tool, self).__init__( @@ -107,22 +115,24 @@ def tool( """ def _make_with_name(tool_name: str) -> Callable: - def _make_tool(func: Callable) -> Tool: - assert func.__doc__, "Function must have a docstring" - # Description example: - # search_api(query: str) - Searches the API for the query.
- description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" - _args_schema = args_schema - if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{tool_name}Schema", func) - tool_ = Tool( + def _make_tool(func: Callable) -> BaseTool: + if infer_schema or args_schema is not None: + return StructuredTool.from_function( + func, + name=tool_name, + return_direct=return_direct, + args_schema=args_schema, + infer_schema=infer_schema, + ) + # If someone doesn't want a schema applied, we must treat it as + # a simple string->string function + assert func.__doc__ is not None, "Function must have a docstring" + return Tool( name=tool_name, func=func, - args_schema=_args_schema, - description=description, + description=f"{tool_name} tool", return_direct=return_direct, ) - return tool_ return _make_tool diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 95aae536..9090c683 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from inspect import signature -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union from pydantic import ( BaseModel, @@ -19,15 +19,6 @@ from langchain.callbacks import get_callback_manager from langchain.callbacks.base import BaseCallbackManager -def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]: - # For backwards compatability, if run_input is a string, - # pass as a positional argument. 
- if isinstance(run_input, str): - return (run_input,), {} - else: - return [], run_input - - class SchemaAnnotationError(TypeError): """Raised when 'args_schema' is missing or has an incorrect type annotation.""" @@ -81,14 +72,20 @@ def _create_subset_model( return create_model(name, **fields) # type: ignore -def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict: +def get_filtered_args( + inferred_model: Type[BaseModel], + func: Callable, +) -> dict: """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters return {k: schema[k] for k in valid_keys} -def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]: +def create_schema_from_function( + model_name: str, + func: Callable, +) -> Type[BaseModel]: """Create a pydantic schema from a function's signature.""" inferred_model = validate_arguments(func).model # type: ignore # Pydantic adds placeholder virtual fields we need to strip @@ -102,12 +99,23 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): """Interface LangChain tools must implement.""" name: str + """The unique name of the tool that clearly communicates its purpose.""" description: str + """Used to tell the model how/when/why to use the tool. + + You can provide few-shot examples as a part of the description. + """ args_schema: Optional[Type[BaseModel]] = None """Pydantic model class to validate and parse the tool's input arguments.""" return_direct: bool = False + """Whether to return the tool's output directly. Setting this to True means + + that after the tool is called, the AgentExecutor will stop looping. 
+ """ verbose: bool = False + """Whether to log the tool's progress.""" callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) + """Callback manager for this tool.""" class Config: """Configuration for this pydantic object.""" @@ -160,6 +168,14 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): async def _arun(self, *args: Any, **kwargs: Any) -> Any: """Use the tool asynchronously.""" + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + # For backwards compatibility, if tool_input is a string, + # pass as a positional argument. + if isinstance(tool_input, str): + return (tool_input,), {} + else: + return (), tool_input + def run( self, tool_input: Union[str, Dict], @@ -182,7 +198,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): **kwargs, ) try: - tool_args, tool_kwargs = _to_args_and_kwargs(tool_input) + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) observation = self._run(*tool_args, **tool_kwargs) except (Exception, KeyboardInterrupt) as e: self.callback_manager.on_tool_error(e, verbose=verbose_) @@ -224,8 +240,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): ) try: # We then call the tool on the tool input to get an observation - args, kwargs = _to_args_and_kwargs(tool_input) - observation = await self._arun(*args, **kwargs) + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) + observation = await self._arun(*tool_args, **tool_kwargs) except (Exception, KeyboardInterrupt) as e: if self.callback_manager.is_async: await self.callback_manager.on_tool_error(e, verbose=verbose_) @@ -249,3 +265,62 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): def __call__(self, tool_input: Union[str, dict]) -> Any: """Make tool callable.""" return self.run(tool_input) + + +class StructuredTool(BaseTool): + """Tool that can operate on any number of inputs.""" + + description: str = "" + args_schema: Type[BaseModel] = Field(..., description="The tool
schema.") + """The input arguments' schema.""" + func: Callable[..., str] + """The function to run when the tool is called.""" + coroutine: Optional[Callable[..., Awaitable[str]]] = None + """The asynchronous version of the function.""" + + @property + def args(self) -> dict: + """The tool's input arguments.""" + return self.args_schema.schema()["properties"] + + def _run(self, *args: Any, **kwargs: Any) -> Any: + """Use the tool.""" + return self.func(*args, **kwargs) + + async def _arun(self, *args: Any, **kwargs: Any) -> Any: + """Use the tool asynchronously.""" + if self.coroutine: + return await self.coroutine(*args, **kwargs) + raise NotImplementedError("Tool does not support async") + + @classmethod + def from_function( + cls, + func: Callable, + name: Optional[str] = None, + description: Optional[str] = None, + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + infer_schema: bool = True, + **kwargs: Any, + ) -> StructuredTool: + name = name or func.__name__ + description = description or func.__doc__ + assert ( + description is not None + ), "Function must have a docstring if description not provided." + + # Description example: + # search_api(query: str) - Searches the API for the query. 
+ description = f"{name}{signature(func)} - {description.strip()}" + _args_schema = args_schema + if _args_schema is None and infer_schema: + _args_schema = create_schema_from_function(f"{name}Schema", func) + return cls( + name=name, + func=func, + args_schema=_args_schema, + description=description, + return_direct=return_direct, + **kwargs, + ) diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 6cac896f..965d44db 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -1,7 +1,7 @@ """Test tool utils.""" from datetime import datetime from functools import partial -from typing import Optional, Type, Union +from typing import Any, Optional, Type, Union from unittest.mock import MagicMock import pydantic @@ -16,7 +16,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool, tool -from langchain.tools.base import BaseTool, SchemaAnnotationError +from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool def test_unnamed_decorator() -> None: @@ -27,7 +27,7 @@ def test_unnamed_decorator() -> None: """Search the API for the query.""" return "API result" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) assert search_api.name == "search_api" assert not search_api.return_direct assert search_api("test") == "API result" @@ -145,7 +145,7 @@ def test_decorator_with_specified_schema() -> None: """Return the arguments directly.""" return f"{arg1} {arg2} {arg3}" - assert isinstance(tool_func, Tool) + assert isinstance(tool_func, BaseTool) assert tool_func.args_schema == _MockSchema @@ -159,7 +159,7 @@ def test_decorated_function_schema_equivalent() -> None: """Return the arguments directly.""" return f"{arg1} {arg2} {arg3}" - assert 
isinstance(structured_tool_input, Tool) + assert isinstance(structured_tool_input, BaseTool) assert structured_tool_input.args_schema is not None assert ( structured_tool_input.args_schema.schema()["properties"] @@ -171,14 +171,14 @@ def test_decorated_function_schema_equivalent() -> None: def test_structured_args_decorator_no_infer_schema() -> None: """Test functionality with structured arguments parsed as a decorator.""" - @tool(infer_schema=False) + @tool def structured_tool_input( arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None ) -> str: """Return the arguments directly.""" return f"{arg1}, {arg2}, {opt_arg}" - assert isinstance(structured_tool_input, Tool) + assert isinstance(structured_tool_input, BaseTool) assert structured_tool_input.name == "structured_tool_input" args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}} expected_result = "1, 0.001, {'foo': 'bar'}" @@ -193,8 +193,9 @@ def test_structured_single_str_decorator_no_infer_schema() -> None: """Return the arguments directly.""" return f"{tool_input}" - assert isinstance(unstructured_tool_input, Tool) + assert isinstance(unstructured_tool_input, BaseTool) assert unstructured_tool_input.args_schema is None + assert unstructured_tool_input.run("foo") == "foo" def test_base_tool_inheritance_base_schema() -> None: @@ -225,18 +226,18 @@ def test_tool_lambda_args_schema() -> None: func=lambda tool_input: tool_input, ) assert tool.args_schema is None - expected_args = {"tool_input": {"title": "Tool Input"}} + expected_args = {"tool_input": {"type": "string"}} assert tool.args == expected_args -def test_tool_lambda_multi_args_schema() -> None: +def test_structured_tool_lambda_multi_args_schema() -> None: """Test args schema inference when the tool argument is a lambda function.""" - tool = Tool( + tool = StructuredTool.from_function( name="tool", description="A tool", func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore ) - assert tool.args_schema is 
None + assert tool.args_schema is not None expected_args = { "tool_input": {"title": "Tool Input"}, "other_arg": {"title": "Other Arg"}, @@ -268,7 +269,7 @@ def test_empty_args_decorator() -> None: """Return a constant.""" return "the empty result" - assert isinstance(empty_tool_input, Tool) + assert isinstance(empty_tool_input, BaseTool) assert empty_tool_input.name == "empty_tool_input" assert empty_tool_input.args == {} assert empty_tool_input.run({}) == "the empty result" @@ -282,7 +283,7 @@ def test_named_tool_decorator() -> None: """Search the API for the query.""" return "API result" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) assert search_api.name == "search" assert not search_api.return_direct @@ -295,7 +296,7 @@ def test_named_tool_decorator_return_direct() -> None: """Search the API for the query.""" return "API result" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) assert search_api.name == "search" assert search_api.return_direct @@ -308,7 +309,7 @@ def test_unnamed_tool_decorator_return_direct() -> None: """Search the API for the query.""" return "API result" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) assert search_api.name == "search_api" assert search_api.return_direct @@ -325,7 +326,7 @@ def test_tool_with_kwargs() -> None: """Search the API for the query.""" return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) result = search_api.run( tool_input={ "arg_0": "foo", @@ -423,3 +424,23 @@ def test_single_input_agent_raises_error_on_structured_tool( f" multi-input tool {the_tool.name}.", ): agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore + + +def test_tool_no_args_specified_assumes_str() -> None: + """Older tools could assume *args and **kwargs were passed in.""" + + def ambiguous_function(*args: Any, **kwargs: Any) -> str: + """An ambiguously 
defined function.""" + return args[0] + + some_tool = Tool( + name="chain_run", + description="Run the chain", + func=ambiguous_function, + ) + expected_args = {"tool_input": {"type": "string"}} + assert some_tool.args == expected_args + assert some_tool.run("foobar") == "foobar" + assert some_tool.run({"tool_input": "foobar"}) == "foobar" + with pytest.raises(ValueError, match="Too many arguments to single-input tool"): + some_tool.run({"tool_input": "foobar", "other_input": "bar"})