Move Tool Validation (#3923)

Move tool validation to each implementation of the Agent.

Another alternative would be to adjust the `_validate_tools()` signature
to accept the output parser (and format instructions) and add logic
there. Something like

`parser.outputs_structured_actions(format_instructions)`

But don't think that's needed right now.
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent 7cce68a051
commit 84ea17b786
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -497,11 +497,7 @@ class Agent(BaseSingleActionAgent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
"""Validate that appropriate tools are passed in."""
for tool in tools:
if not tool.is_single_input:
raise ValueError(
f"{cls.__name__} does not support multi-input tool {tool.name}."
)
pass
@classmethod
@abstractmethod

@ -5,6 +5,7 @@ from pydantic import Field
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.chat.output_parser import ChatOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
@ -15,7 +16,7 @@ from langchain.prompts.chat import (
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction
from langchain.tools import BaseTool
from langchain.tools.base import BaseTool
class ChatAgent(Agent):
@ -50,6 +51,11 @@ class ChatAgent(Agent):
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return ChatOutputParser()
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
validate_tools_single_input(class_name=cls.__name__, tools=tools)
@property
def _stop(self) -> List[str]:
return ["Observation:"]

@ -9,6 +9,7 @@ from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent_types import AgentType
from langchain.agents.conversational.output_parser import ConvoOutputParser
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
@ -80,6 +81,11 @@ class ConversationalAgent(Agent):
input_variables = ["input", "chat_history", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
validate_tools_single_input(cls.__name__, tools)
@classmethod
def from_llm_and_tools(
cls,

@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import (
SUFFIX,
TEMPLATE_TOOL_RESPONSE,
)
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
@ -55,6 +56,11 @@ class ConversationalChatAgent(Agent):
"""Prefix to append the llm call with."""
return "Thought:"
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
validate_tools_single_input(cls.__name__, tools)
@classmethod
def create_prompt(
cls,

@ -10,6 +10,7 @@ from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
@ -122,13 +123,14 @@ class ZeroShotAgent(Agent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
super()._validate_tools(tools)
validate_tools_single_input(cls.__name__, tools)
for tool in tools:
if tool.description is None:
raise ValueError(
f"Got a tool {tool.name} without a description. For this agent, "
f"a description must always be provided."
)
super()._validate_tools(tools)
class MRKLChain(AgentExecutor):

@ -9,6 +9,7 @@ from langchain.agents.react.output_parser import ReActOutputParser
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
@ -37,6 +38,7 @@ class ReActDocstoreAgent(Agent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
validate_tools_single_input(cls.__name__, tools)
super()._validate_tools(tools)
if len(tools) != 2:
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
@ -120,6 +122,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
validate_tools_single_input(cls.__name__, tools)
super()._validate_tools(tools)
if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}")

@ -8,6 +8,7 @@ from langchain.agents.agent_types import AgentType
from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputParser
from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool
@ -36,6 +37,7 @@ class SelfAskWithSearchAgent(Agent):
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
validate_tools_single_input(cls.__name__, tools)
super()._validate_tools(tools)
if len(tools) != 1:
raise ValueError(f"Exactly one tool must be specified, but got {tools}")

@ -0,0 +1,12 @@
from typing import Sequence
from langchain.tools.base import BaseTool
def validate_tools_single_input(class_name: str, tools: Sequence[BaseTool]) -> None:
"""Validate tools for single input."""
for tool in tools:
if not tool.is_single_input:
raise ValueError(
f"{class_name} does not support multi-input tool {tool.name}."
)

@ -400,8 +400,8 @@ async def test_create_async_tool() -> None:
@pytest.mark.parametrize(
"agent_cls",
[
ChatAgent,
ZeroShotAgent,
ChatAgent,
ConversationalChatAgent,
ConversationalAgent,
ReActDocstoreAgent,

Loading…
Cancel
Save