diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index c675b558..0099a035 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -454,7 +454,11 @@ class Agent(BaseSingleActionAgent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: """Validate that appropriate tools are passed in.""" - pass + for tool in tools: + if not tool.is_single_input: + raise ValueError( + f"{cls.__name__} does not support multi-input tool {tool.name}." + ) @classmethod @abstractmethod diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 4bb1d519..60f7c981 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -122,6 +122,7 @@ class ZeroShotAgent(Agent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + super()._validate_tools(tools) for tool in tools: if tool.description is None: raise ValueError( diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index 08d27d4d..afb199c2 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -37,6 +37,7 @@ class ReActDocstoreAgent(Agent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + super()._validate_tools(tools) if len(tools) != 2: raise ValueError(f"Exactly two tools must be specified, but got {tools}") tool_names = {tool.name for tool in tools} @@ -119,6 +120,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + super()._validate_tools(tools) if len(tools) != 1: raise ValueError(f"Exactly one tool must be specified, but got {tools}") tool_names = {tool.name for tool in tools} diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index 297b0345..5e1905b5 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -36,6 +36,7 @@ class SelfAskWithSearchAgent(Agent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + super()._validate_tools(tools) if len(tools) != 1: raise ValueError(f"Exactly one tool must be specified, but got {tools}") tool_names = {tool.name for tool in tools} diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 54d8db5d..95aae536 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -115,6 +115,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): extra = Extra.forbid arbitrary_types_allowed = True + @property + def is_single_input(self) -> bool: + """Whether the tool only accepts a single input.""" + return len(self.args) == 1 + @property def args(self) -> dict: if self.args_schema is not None: @@ -148,11 +153,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): return callback_manager or get_callback_manager() @abstractmethod - def _run(self, *args: Any, **kwargs: Any) -> str: + def _run(self, *args: Any, **kwargs: Any) -> Any: """Use the tool.""" @abstractmethod - async def _arun(self, *args: Any, **kwargs: Any) -> str: + async def _arun(self, *args: Any, **kwargs: Any) -> Any: """Use the tool asynchronously.""" def run( @@ -183,7 +188,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): self.callback_manager.on_tool_error(e, verbose=verbose_) raise e self.callback_manager.on_tool_end( - observation, verbose=verbose_, color=color, name=self.name, **kwargs + str(observation), verbose=verbose_, color=color, name=self.name, **kwargs ) return observation @@ -194,7 +199,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): start_color: Optional[str] = "green", color: Optional[str] = "green", **kwargs: Any, - ) -> str: + ) -> Any: """Run the tool asynchronously.""" self._parse_input(tool_input) if not self.verbose and verbose is not None: @@ -229,7 +234,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): raise e if self.callback_manager.is_async: await self.callback_manager.on_tool_end( - observation, verbose=verbose_, color=color, name=self.name, **kwargs + str(observation), + verbose=verbose_, + color=color, + name=self.name, + **kwargs, ) else: self.callback_manager.on_tool_end( @@ -237,6 +246,6 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): ) return observation - def __call__(self, tool_input: str) -> str: + def __call__(self, tool_input: Union[str, dict]) -> Any: """Make tool callable.""" return self.run(tool_input) diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index d76011dc..6cac896f 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -2,11 +2,19 @@ from datetime import datetime from functools import partial from typing import Optional, Type, Union +from unittest.mock import MagicMock import pydantic import pytest from pydantic import BaseModel +from langchain.agents.agent import Agent +from langchain.agents.chat.base import ChatAgent +from langchain.agents.conversational.base import ConversationalAgent +from langchain.agents.conversational_chat.base import ConversationalChatAgent +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent +from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool, tool from langchain.tools.base import BaseTool, SchemaAnnotationError @@ -152,6 +160,7 @@ def test_decorated_function_schema_equivalent() -> None: return f"{arg1} {arg2} {arg3}" assert isinstance(structured_tool_input, Tool) + assert structured_tool_input.args_schema is not None assert ( structured_tool_input.args_schema.schema()["properties"] == _MockSchema.schema()["properties"] @@ -309,33 +318,38 @@ def test_tool_with_kwargs() -> None: @tool(return_direct=True) def search_api( - arg_1: float, + arg_0: str, + arg_1: float = 4.3, ping: str = "hi", ) -> str: """Search the API for the query.""" - return f"arg_1={arg_1}, ping={ping}" + return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}" assert isinstance(search_api, Tool) result = search_api.run( tool_input={ + "arg_0": "foo", "arg_1": 3.2, "ping": "pong", } ) - assert result == "arg_1=3.2, ping=pong" + assert result == "arg_0=foo, arg_1=3.2, ping=pong" result = search_api.run( tool_input={ - "arg_1": 3.2, + "arg_0": "foo", } ) - assert result == "arg_1=3.2, ping=hi" + assert result == "arg_0=foo, arg_1=4.3, ping=hi" + # For backwards compatibility, we still accept a single str arg + result = search_api.run("foobar") + assert result == "arg_0=foobar, arg_1=4.3, ping=hi" def test_missing_docstring() -> None: """Test error is raised when docstring is missing.""" # expect to throw a value error if theres no docstring - with pytest.raises(AssertionError): + with pytest.raises(AssertionError, match="Function must have a docstring"): @tool def search_api(query: str) -> str: @@ -348,11 +362,13 @@ def test_create_tool_positional_args() -> None: assert test_tool("foo") == "foo" assert test_tool.name == "test_name" assert test_tool.description == "test_description" + assert test_tool.is_single_input def test_create_tool_keyword_args() -> None: """Test that keyword arguments are allowed.""" test_tool = Tool(name="test_name", func=lambda x: x, description="test_description") + assert test_tool.is_single_input assert test_tool("foo") == "foo" assert test_tool.name == "test_name" assert test_tool.description == "test_description" @@ -371,8 +387,39 @@ async def test_create_async_tool() -> None: description="test_description", coroutine=_test_func, ) + assert test_tool.is_single_input assert test_tool("foo") == "foo" assert test_tool.name == "test_name" assert test_tool.description == "test_description" assert test_tool.coroutine is not None assert await test_tool.arun("foo") == "foo" + + +@pytest.mark.parametrize( + "agent_cls", + [ + ChatAgent, + ZeroShotAgent, + ConversationalChatAgent, + ConversationalAgent, + ReActDocstoreAgent, + ReActTextWorldAgent, + SelfAskWithSearchAgent, + ], +) +def test_single_input_agent_raises_error_on_structured_tool( + agent_cls: Type[Agent], +) -> None: + """Test that older agents raise errors on older tools.""" + + @tool + def the_tool(foo: str, bar: str) -> str: + """Return the concat of foo and bar.""" + return foo + bar + + with pytest.raises( + ValueError, + match=f"{agent_cls.__name__} does not support" # type: ignore + f" multi-input tool {the_tool.name}.", + ): + agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore