diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index f60646ad52..505cecc99f 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -897,7 +897,10 @@ s else: tool_run_kwargs = self.agent.tool_run_logging_kwargs() observation = InvalidTool().run( - agent_action.tool, + { + "requested_tool_name": agent_action.tool, + "available_tool_names": list(name_to_tool_map.keys()), + }, verbose=self.verbose, color=None, callbacks=run_manager.get_child() if run_manager else None, @@ -992,7 +995,10 @@ s else: tool_run_kwargs = self.agent.tool_run_logging_kwargs() observation = await InvalidTool().arun( - agent_action.tool, + { + "requested_tool_name": agent_action.tool, + "available_tool_names": list(name_to_tool_map.keys()), + }, verbose=self.verbose, color=None, callbacks=run_manager.get_child() if run_manager else None, diff --git a/libs/langchain/langchain/agents/tools.py b/libs/langchain/langchain/agents/tools.py index 39925fd759..e13acc42bd 100644 --- a/libs/langchain/langchain/agents/tools.py +++ b/libs/langchain/langchain/agents/tools.py @@ -1,5 +1,5 @@ """Interface for tools.""" -from typing import Optional +from typing import List, Optional from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, @@ -12,23 +12,33 @@ class InvalidTool(BaseTool): """Tool that is run when invalid tool name is encountered by agent.""" name = "invalid_tool" - """Name of the tool.""" - description = "Called when tool name is invalid." - """Description of the tool.""" + description = "Called when tool name is invalid. Suggests valid tool names." def _run( - self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None + self, + requested_tool_name: str, + available_tool_names: List[str], + run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" - return f"{tool_name} is not a valid tool, try another one." + available_tool_names_str = ", ".join([tool for tool in available_tool_names]) + return ( + f"{requested_tool_name} is not a valid tool, " + f"try one of [{available_tool_names_str}]." + ) async def _arun( self, - tool_name: str, + requested_tool_name: str, + available_tool_names: List[str], run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool asynchronously.""" - return f"{tool_name} is not a valid tool, try another one." + available_tool_names_str = ", ".join([tool for tool in available_tool_names]) + return ( + f"{requested_tool_name} is not a valid tool, " + f"try one of [{available_tool_names_str}]." + ) __all__ = ["InvalidTool", "BaseTool", "tool", "Tool"] diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index cc4ab557c4..9ba2a31897 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -257,3 +257,26 @@ def test_agent_lookup_tool() -> None: ) assert agent.lookup_tool("Search") == tools[0] + + +def test_agent_invalid_tool() -> None: + """Test agent invalid tool and correct suggestions.""" + fake_llm = FakeListLLM(responses=["FooBarBaz\nAction: Foo\nAction Input: Bar"]) + tools = [ + Tool( + name="Search", + func=lambda x: x, + description="Useful for searching", + return_direct=True, + ), + ] + agent = initialize_agent( + tools=tools, + llm=fake_llm, + agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, + return_intermediate_steps=True, + max_iterations=1, + ) + + resp = agent("when was langchain made") + resp["intermediate_steps"][0][1] == "Foo is not a valid tool, try one of [Search]."