diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 2ac0d539..eb9b859f 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -497,11 +497,7 @@ class Agent(BaseSingleActionAgent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: """Validate that appropriate tools are passed in.""" - for tool in tools: - if not tool.is_single_input: - raise ValueError( - f"{cls.__name__} does not support multi-input tool {tool.name}." - ) + pass @classmethod @abstractmethod diff --git a/langchain/agents/chat/base.py b/langchain/agents/chat/base.py index 04ceca71..72c5b845 100644 --- a/langchain/agents/chat/base.py +++ b/langchain/agents/chat/base.py @@ -5,6 +5,7 @@ from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.chat.output_parser import ChatOutputParser from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX +from langchain.agents.utils import validate_tools_single_input from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain @@ -15,7 +16,7 @@ from langchain.prompts.chat import ( SystemMessagePromptTemplate, ) from langchain.schema import AgentAction -from langchain.tools import BaseTool +from langchain.tools.base import BaseTool class ChatAgent(Agent): @@ -50,6 +51,11 @@ class ChatAgent(Agent): def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: return ChatOutputParser() + @classmethod + def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + super()._validate_tools(tools) + validate_tools_single_input(class_name=cls.__name__, tools=tools) + @property def _stop(self) -> List[str]: return ["Observation:"] diff --git a/langchain/agents/conversational/base.py b/langchain/agents/conversational/base.py index 16a43a90..ce91d9c6 100644 --- a/langchain/agents/conversational/base.py +++ b/langchain/agents/conversational/base.py @@ -9,6 +9,7 @@ from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.agent_types import AgentType from langchain.agents.conversational.output_parser import ConvoOutputParser from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX +from langchain.agents.utils import validate_tools_single_input from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain @@ -80,6 +81,11 @@ class ConversationalAgent(Agent): input_variables = ["input", "chat_history", "agent_scratchpad"] return PromptTemplate(template=template, input_variables=input_variables) + @classmethod + def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + super()._validate_tools(tools) + validate_tools_single_input(cls.__name__, tools) + @classmethod def from_llm_and_tools( cls, diff --git a/langchain/agents/conversational_chat/base.py b/langchain/agents/conversational_chat/base.py index d9b83ecc..17128543 100644 --- a/langchain/agents/conversational_chat/base.py +++ b/langchain/agents/conversational_chat/base.py @@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import ( SUFFIX, TEMPLATE_TOOL_RESPONSE, ) +from langchain.agents.utils import validate_tools_single_input from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain @@ -55,6 +56,11 @@ class ConversationalChatAgent(Agent): """Prefix to append the llm call with.""" return "Thought:" + @classmethod + def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + super()._validate_tools(tools) + validate_tools_single_input(cls.__name__, tools) + @classmethod def create_prompt( cls, diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 8ce4411c..f28a01ab 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -10,6 +10,7 @@ from langchain.agents.agent_types import AgentType from langchain.agents.mrkl.output_parser import MRKLOutputParser from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.tools import Tool +from langchain.agents.utils import validate_tools_single_input from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain @@ -122,13 +123,14 @@ class ZeroShotAgent(Agent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: - super()._validate_tools(tools) + validate_tools_single_input(cls.__name__, tools) for tool in tools: if tool.description is None: raise ValueError( f"Got a tool {tool.name} without a description. For this agent, " f"a description must always be provided." ) + super()._validate_tools(tools) class MRKLChain(AgentExecutor): diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index afb199c2..a1210be9 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -9,6 +9,7 @@ from langchain.agents.react.output_parser import ReActOutputParser from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT from langchain.agents.react.wiki_prompt import WIKI_PROMPT from langchain.agents.tools import Tool +from langchain.agents.utils import validate_tools_single_input from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.llms.base import BaseLLM @@ -37,6 +38,7 @@ class ReActDocstoreAgent(Agent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + validate_tools_single_input(cls.__name__, tools) super()._validate_tools(tools) if len(tools) != 2: raise ValueError(f"Exactly two tools must be specified, but got {tools}") @@ -120,6 +122,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + validate_tools_single_input(cls.__name__, tools) super()._validate_tools(tools) if len(tools) != 1: raise ValueError(f"Exactly one tool must be specified, but got {tools}") diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index 5e1905b5..a445e07e 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -8,6 +8,7 @@ from langchain.agents.agent_types import AgentType from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputParser from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool +from langchain.agents.utils import validate_tools_single_input from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.tools.base import BaseTool @@ -36,6 +37,7 @@ class SelfAskWithSearchAgent(Agent): @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + validate_tools_single_input(cls.__name__, tools) super()._validate_tools(tools) if len(tools) != 1: raise ValueError(f"Exactly one tool must be specified, but got {tools}") diff --git a/langchain/agents/utils.py b/langchain/agents/utils.py new file mode 100644 index 00000000..5e852458 --- /dev/null +++ b/langchain/agents/utils.py @@ -0,0 +1,12 @@ +from typing import Sequence + +from langchain.tools.base import BaseTool + + +def validate_tools_single_input(class_name: str, tools: Sequence[BaseTool]) -> None: + """Validate tools for single input.""" + for tool in tools: + if not tool.is_single_input: + raise ValueError( + f"{class_name} does not support multi-input tool {tool.name}." + ) diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 055720d8..094774fb 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -400,8 +400,8 @@ async def test_create_async_tool() -> None: @pytest.mark.parametrize( "agent_cls", [ - ChatAgent, ZeroShotAgent, + ChatAgent, ConversationalChatAgent, ConversationalAgent, ReActDocstoreAgent,