From 68666d6a22a56ae5bcc1e2f8e9916d1c9dd557ad Mon Sep 17 00:00:00 2001 From: John McDonnell Date: Tue, 6 Dec 2022 21:52:48 -0800 Subject: [PATCH] Gracefully degrade when model asks for nonexistent tool (#268) Not yet tested, but very simple change, assumption is that we're cool with just producing a generic output when tool is not found --- langchain/agents/agent.py | 13 +++++--- tests/unit_tests/agents/test_agent.py | 45 +++++++++++++++++++++++++++ tests/unit_tests/agents/test_react.py | 10 +++--- 3 files changed, 59 insertions(+), 9 deletions(-) create mode 100644 tests/unit_tests/agents/test_agent.py diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 80ee3ab1..62c4dcca 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -150,12 +150,17 @@ class Agent(Chain, BaseModel, ABC): if output.tool == self.finish_tool_name: return {self.output_key: output.tool_input} # Otherwise we lookup the tool - chain = name_to_tool_map[output.tool] - # We then call the tool on the tool input to get an observation - observation = chain(output.tool_input) + if output.tool in name_to_tool_map: + chain = name_to_tool_map[output.tool] + # We then call the tool on the tool input to get an observation + observation = chain(output.tool_input) + color = color_mapping[output.tool] + else: + observation = f"{output.tool} is not a valid tool, try another one." + color = None # We then log the observation chained_input.add(f"\n{self.observation_prefix}") - chained_input.add(observation, color=color_mapping[output.tool]) + chained_input.add(observation, color=color) # We then add the LLM prefix into the prompt to get the LLM to start # thinking, and start the loop all over. chained_input.add(f"\n{self.llm_prefix}") diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py new file mode 100644 index 00000000..85853e70 --- /dev/null +++ b/tests/unit_tests/agents/test_agent.py @@ -0,0 +1,45 @@ +"""Unit tests for agents.""" + +from typing import Any, List, Mapping, Optional + +from langchain.agents import Tool, initialize_agent +from langchain.llms.base import LLM + + +class FakeListLLM(LLM): + """Fake LLM for testing that outputs elements of a list.""" + + def __init__(self, responses: List[str]): + """Initialize with list of responses.""" + self.responses = responses + self.i = -1 + + def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Increment counter, and then return response in that index.""" + self.i += 1 + print(self.i) + print(self.responses) + return self.responses[self.i] + + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {} + + +def test_agent_bad_action() -> None: + """Test react chain when bad action given.""" + bad_action_name = "BadAction" + responses = [ + f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", + "Oh well\nAction: Final Answer\nAction Input: curses foiled again", + ] + fake_llm = FakeListLLM(responses) + tools = [ + Tool("Search", lambda x: x, "Useful for searching"), + Tool("Lookup", lambda x: x, "Useful for looking up things in a table"), + ] + agent = initialize_agent( + tools, fake_llm, agent="zero-shot-react-description", verbose=True + ) + output = agent.run("when was langchain made") + assert output == "curses foiled again" diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index e9ea6a01..2d6c4cca 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -2,8 +2,6 @@ from typing import Any, List, Mapping, Optional, Union -import pytest - from langchain.agents.react.base import ReActChain, ReActDocstoreAgent from langchain.agents.tools import Tool from langchain.docstore.base import Docstore @@ -94,10 +92,12 @@ def test_react_chain() -> None: def test_react_chain_bad_action() -> None: """Test react chain when bad action given.""" + bad_action_name = "BadAction" responses = [ - "I should probably search\nAction 1: BadAction[langchain]", + f"I'm turning evil\nAction 1: {bad_action_name}[langchain]", + "Oh well\nAction 2: Finish[curses foiled again]", ] fake_llm = FakeListLLM(responses) react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore()) - with pytest.raises(KeyError): - react_chain.run("when was langchain made") + output = react_chain.run("when was langchain made") + assert output == "curses foiled again"