Move Tool Validation (#3923)

Move tool validation to each implementation of the Agent.

Another alternative would be to adjust the `_validate_tools()` signature
to accept the output parser (and format instructions) and add logic
there. Something like

`parser.outputs_structured_actions(format_instructions)`

But don't think that's needed right now.
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent 7cce68a051
commit 84ea17b786
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -497,11 +497,7 @@ class Agent(BaseSingleActionAgent):
@classmethod @classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
"""Validate that appropriate tools are passed in.""" """Validate that appropriate tools are passed in."""
for tool in tools: pass
if not tool.is_single_input:
raise ValueError(
f"{cls.__name__} does not support multi-input tool {tool.name}."
)
@classmethod @classmethod
@abstractmethod @abstractmethod

@ -5,6 +5,7 @@ from pydantic import Field
from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.chat.output_parser import ChatOutputParser from langchain.agents.chat.output_parser import ChatOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -15,7 +16,7 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
) )
from langchain.schema import AgentAction from langchain.schema import AgentAction
from langchain.tools import BaseTool from langchain.tools.base import BaseTool
class ChatAgent(Agent): class ChatAgent(Agent):
@ -50,6 +51,11 @@ class ChatAgent(Agent):
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return ChatOutputParser() return ChatOutputParser()
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
validate_tools_single_input(class_name=cls.__name__, tools=tools)
@property @property
def _stop(self) -> List[str]: def _stop(self) -> List[str]:
return ["Observation:"] return ["Observation:"]

@ -9,6 +9,7 @@ from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent_types import AgentType from langchain.agents.agent_types import AgentType
from langchain.agents.conversational.output_parser import ConvoOutputParser from langchain.agents.conversational.output_parser import ConvoOutputParser
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain from langchain.chains import LLMChain
@ -80,6 +81,11 @@ class ConversationalAgent(Agent):
input_variables = ["input", "chat_history", "agent_scratchpad"] input_variables = ["input", "chat_history", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables) return PromptTemplate(template=template, input_variables=input_variables)
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
validate_tools_single_input(cls.__name__, tools)
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,

@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import (
SUFFIX, SUFFIX,
TEMPLATE_TOOL_RESPONSE, TEMPLATE_TOOL_RESPONSE,
) )
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain from langchain.chains import LLMChain
@ -55,6 +56,11 @@ class ConversationalChatAgent(Agent):
"""Prefix to append the llm call with.""" """Prefix to append the llm call with."""
return "Thought:" return "Thought:"
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
validate_tools_single_input(cls.__name__, tools)
@classmethod @classmethod
def create_prompt( def create_prompt(
cls, cls,

@ -10,6 +10,7 @@ from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.output_parser import MRKLOutputParser from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain from langchain.chains import LLMChain
@ -122,13 +123,14 @@ class ZeroShotAgent(Agent):
@classmethod @classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools) validate_tools_single_input(cls.__name__, tools)
for tool in tools: for tool in tools:
if tool.description is None: if tool.description is None:
raise ValueError( raise ValueError(
f"Got a tool {tool.name} without a description. For this agent, " f"Got a tool {tool.name} without a description. For this agent, "
f"a description must always be provided." f"a description must always be provided."
) )
super()._validate_tools(tools)
class MRKLChain(AgentExecutor): class MRKLChain(AgentExecutor):

@ -9,6 +9,7 @@ from langchain.agents.react.output_parser import ReActOutputParser
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
from langchain.agents.react.wiki_prompt import WIKI_PROMPT from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.docstore.base import Docstore from langchain.docstore.base import Docstore
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
@ -37,6 +38,7 @@ class ReActDocstoreAgent(Agent):
@classmethod @classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
validate_tools_single_input(cls.__name__, tools)
super()._validate_tools(tools) super()._validate_tools(tools)
if len(tools) != 2: if len(tools) != 2:
raise ValueError(f"Exactly two tools must be specified, but got {tools}") raise ValueError(f"Exactly two tools must be specified, but got {tools}")
@ -120,6 +122,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
@classmethod @classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
validate_tools_single_input(cls.__name__, tools)
super()._validate_tools(tools) super()._validate_tools(tools)
if len(tools) != 1: if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}") raise ValueError(f"Exactly one tool must be specified, but got {tools}")

@ -8,6 +8,7 @@ from langchain.agents.agent_types import AgentType
from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputParser from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputParser
from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
@ -36,6 +37,7 @@ class SelfAskWithSearchAgent(Agent):
@classmethod @classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
validate_tools_single_input(cls.__name__, tools)
super()._validate_tools(tools) super()._validate_tools(tools)
if len(tools) != 1: if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}") raise ValueError(f"Exactly one tool must be specified, but got {tools}")

@ -0,0 +1,12 @@
from typing import Sequence
from langchain.tools.base import BaseTool
def validate_tools_single_input(class_name: str, tools: Sequence[BaseTool]) -> None:
"""Validate tools for single input."""
for tool in tools:
if not tool.is_single_input:
raise ValueError(
f"{class_name} does not support multi-input tool {tool.name}."
)

@ -400,8 +400,8 @@ async def test_create_async_tool() -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"agent_cls", "agent_cls",
[ [
ChatAgent,
ZeroShotAgent, ZeroShotAgent,
ChatAgent,
ConversationalChatAgent, ConversationalChatAgent,
ConversationalAgent, ConversationalAgent,
ReActDocstoreAgent, ReActDocstoreAgent,

Loading…
Cancel
Save