Improving the text of the invalid tool to list the available tools. (#8767)

Description: When using a ReAct Agent with tools and no tool is found,
the InvalidTool gets called. Previously it just asked for a different
action, but I've found that if you list the available actions it
improves the chances of getting a valid action in the next round. I've
added a UnitTest for it also.

@hinthornw
pull/8808/head
Paul Hager 1 year ago committed by GitHub
parent d9bc46186d
commit 2111ed3c75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -897,7 +897,10 @@ s
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run(
agent_action.tool,
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
@ -992,7 +995,10 @@ s
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = await InvalidTool().arun(
agent_action.tool,
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,

@ -1,5 +1,5 @@
"""Interface for tools."""
from typing import Optional
from typing import List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
@ -12,23 +12,33 @@ class InvalidTool(BaseTool):
"""Tool that is run when invalid tool name is encountered by agent."""
name = "invalid_tool"
"""Name of the tool."""
description = "Called when tool name is invalid."
"""Description of the tool."""
description = "Called when tool name is invalid. Suggests valid tool names."
def _run(
self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None
self,
requested_tool_name: str,
available_tool_names: List[str],
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
return f"{tool_name} is not a valid tool, try another one."
available_tool_names_str = ", ".join([tool for tool in available_tool_names])
return (
f"{requested_tool_name} is not a valid tool, "
f"try one of [{available_tool_names_str}]."
)
async def _arun(
self,
tool_name: str,
requested_tool_name: str,
available_tool_names: List[str],
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""
return f"{tool_name} is not a valid tool, try another one."
available_tool_names_str = ", ".join([tool for tool in available_tool_names])
return (
f"{requested_tool_name} is not a valid tool, "
f"try one of [{available_tool_names_str}]."
)
__all__ = ["InvalidTool", "BaseTool", "tool", "Tool"]

@ -257,3 +257,26 @@ def test_agent_lookup_tool() -> None:
)
assert agent.lookup_tool("Search") == tools[0]
def test_agent_invalid_tool() -> None:
"""Test agent invalid tool and correct suggestions."""
fake_llm = FakeListLLM(responses=["FooBarBaz\nAction: Foo\nAction Input: Bar"])
tools = [
Tool(
name="Search",
func=lambda x: x,
description="Useful for searching",
return_direct=True,
),
]
agent = initialize_agent(
tools=tools,
llm=fake_llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
return_intermediate_steps=True,
max_iterations=1,
)
resp = agent("when was langchain made")
resp["intermediate_steps"][0][1] == "Foo is not a valid tool, try one of [Search]."

Loading…
Cancel
Save