Add validation on agent instantiation for multi-input tools (#3681)

Tradeoffs here:
- No lint-time checking for compatibility
- Differs from JS package
- The signature inference, etc. in the base tool isn't simple
- The `args_schema` is optional 

Pros:
- Forwards compatibility retained
- Doesn't break backwards compatibility
- User doesn't have to think about which class to subclass (single base
tool or dynamic `Tool` interface regardless of input)
-  No need to change the load_tools, etc. interfaces

Co-authored-by: Hasan Patel <mangafield@gmail.com>
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent 212aadd4af
commit 4654c58f72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -454,7 +454,11 @@ class Agent(BaseSingleActionAgent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
"""Validate that appropriate tools are passed in."""
pass
for tool in tools:
if not tool.is_single_input:
raise ValueError(
f"{cls.__name__} does not support multi-input tool {tool.name}."
)
@classmethod
@abstractmethod

@ -122,6 +122,7 @@ class ZeroShotAgent(Agent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
for tool in tools:
if tool.description is None:
raise ValueError(

@ -37,6 +37,7 @@ class ReActDocstoreAgent(Agent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
if len(tools) != 2:
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
tool_names = {tool.name for tool in tools}
@ -119,6 +120,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
tool_names = {tool.name for tool in tools}

@ -36,6 +36,7 @@ class SelfAskWithSearchAgent(Agent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
tool_names = {tool.name for tool in tools}

@ -115,6 +115,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def is_single_input(self) -> bool:
"""Whether the tool only accepts a single input."""
return len(self.args) == 1
@property
def args(self) -> dict:
if self.args_schema is not None:
@ -148,11 +153,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
return callback_manager or get_callback_manager()
@abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
@abstractmethod
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
def run(
@ -183,7 +188,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
@ -194,7 +199,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> str:
) -> Any:
"""Run the tool asynchronously."""
self._parse_input(tool_input)
if not self.verbose and verbose is not None:
@ -229,7 +234,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
str(observation),
verbose=verbose_,
color=color,
name=self.name,
**kwargs,
)
else:
self.callback_manager.on_tool_end(
@ -237,6 +246,6 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
)
return observation
def __call__(self, tool_input: str) -> str:
def __call__(self, tool_input: Union[str, dict]) -> Any:
"""Make tool callable."""
return self.run(tool_input)

@ -2,11 +2,19 @@
from datetime import datetime
from functools import partial
from typing import Optional, Type, Union
from unittest.mock import MagicMock
import pydantic
import pytest
from pydantic import BaseModel
from langchain.agents.agent import Agent
from langchain.agents.chat.base import ChatAgent
from langchain.agents.conversational.base import ConversationalAgent
from langchain.agents.conversational_chat.base import ConversationalChatAgent
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool, SchemaAnnotationError
@ -152,6 +160,7 @@ def test_decorated_function_schema_equivalent() -> None:
return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, Tool)
assert structured_tool_input.args_schema is not None
assert (
structured_tool_input.args_schema.schema()["properties"]
== _MockSchema.schema()["properties"]
@ -309,33 +318,38 @@ def test_tool_with_kwargs() -> None:
@tool(return_direct=True)
def search_api(
arg_1: float,
arg_0: str,
arg_1: float = 4.3,
ping: str = "hi",
) -> str:
"""Search the API for the query."""
return f"arg_1={arg_1}, ping={ping}"
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
assert isinstance(search_api, Tool)
result = search_api.run(
tool_input={
"arg_0": "foo",
"arg_1": 3.2,
"ping": "pong",
}
)
assert result == "arg_1=3.2, ping=pong"
assert result == "arg_0=foo, arg_1=3.2, ping=pong"
result = search_api.run(
tool_input={
"arg_1": 3.2,
"arg_0": "foo",
}
)
assert result == "arg_1=3.2, ping=hi"
assert result == "arg_0=foo, arg_1=4.3, ping=hi"
# For backwards compatibility, we still accept a single str arg
result = search_api.run("foobar")
assert result == "arg_0=foobar, arg_1=4.3, ping=hi"
def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing."""
# expect to throw a value error if theres no docstring
with pytest.raises(AssertionError):
with pytest.raises(AssertionError, match="Function must have a docstring"):
@tool
def search_api(query: str) -> str:
@ -348,11 +362,13 @@ def test_create_tool_positional_args() -> None:
assert test_tool("foo") == "foo"
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
assert test_tool.is_single_input
def test_create_tool_keyword_args() -> None:
"""Test that keyword arguments are allowed."""
test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")
assert test_tool.is_single_input
assert test_tool("foo") == "foo"
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
@ -371,8 +387,39 @@ async def test_create_async_tool() -> None:
description="test_description",
coroutine=_test_func,
)
assert test_tool.is_single_input
assert test_tool("foo") == "foo"
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
assert test_tool.coroutine is not None
assert await test_tool.arun("foo") == "foo"
@pytest.mark.parametrize(
"agent_cls",
[
ChatAgent,
ZeroShotAgent,
ConversationalChatAgent,
ConversationalAgent,
ReActDocstoreAgent,
ReActTextWorldAgent,
SelfAskWithSearchAgent,
],
)
def test_single_input_agent_raises_error_on_structured_tool(
agent_cls: Type[Agent],
) -> None:
"""Test that older agents raise errors on older tools."""
@tool
def the_tool(foo: str, bar: str) -> str:
"""Return the concat of foo and bar."""
return foo + bar
with pytest.raises(
ValueError,
match=f"{agent_cls.__name__} does not support" # type: ignore
f" multi-input tool {the_tool.name}.",
):
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore

Loading…
Cancel
Save