From aed9f9febe6154790fb3b4773dd5bebf80fc02c8 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 13 Mar 2023 07:54:29 -0700 Subject: [PATCH] Harrison/return intermediate (#1633) Co-authored-by: Mario Kostelac --- langchain/agents/agent.py | 31 +++++++++++++++++++++------ tests/unit_tests/agents/test_agent.py | 30 ++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 01bc233a..e3eeb98c 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -435,10 +435,6 @@ class AgentExecutor(Chain, BaseModel): llm_prefix="", observation_prefix=self.agent.observation_prefix, ) - return_direct = False - if return_direct: - # Set the log to "" because we do not want to log it. - return AgentFinish({self.agent.return_values[0]: observation}, "") return output, observation async def _atake_next_step( @@ -483,9 +479,6 @@ class AgentExecutor(Chain, BaseModel): observation_prefix=self.agent.observation_prefix, ) return_direct = False - if return_direct: - # Set the log to "" because we do not want to log it. - return AgentFinish({self.agent.return_values[0]: observation}, "") return output, observation def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: @@ -510,6 +503,10 @@ class AgentExecutor(Chain, BaseModel): return self._return(next_step_output, intermediate_steps) intermediate_steps.append(next_step_output) + # See if tool should return directly + tool_return = self._get_tool_return(next_step_output) + if tool_return is not None: + return self._return(tool_return, intermediate_steps) iterations += 1 output = self.agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs @@ -538,8 +535,28 @@ class AgentExecutor(Chain, BaseModel): return await self._areturn(next_step_output, intermediate_steps) intermediate_steps.append(next_step_output) + # See if tool should return directly + tool_return = self._get_tool_return(next_step_output) + if tool_return is not None: + return await self._areturn(tool_return, intermediate_steps) + iterations += 1 output = self.agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs ) return await self._areturn(output, intermediate_steps) + + def _get_tool_return( + self, next_step_output: Tuple[AgentAction, str] + ) -> Optional[AgentFinish]: + """Check if the tool is a returning tool.""" + agent_action, observation = next_step_output + name_to_tool_map = {tool.name: tool for tool in self.tools} + # Invalid tools won't be in the map, so we return False. + if agent_action.tool in name_to_tool_map: + if name_to_tool_map[agent_action.tool].return_direct: + return AgentFinish( + {self.agent.return_values[0]: observation}, + "", + ) + return None diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 6d906a83..e9d08317 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -230,6 +230,36 @@ def test_agent_tool_return_direct() -> None: assert output == "misalignment" +def test_agent_tool_return_direct_in_intermediate_steps() -> None: + """Test agent using tools that return directly.""" + tool = "Search" + responses = [ + f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", + "Oh well\nAction: Final Answer\nAction Input: curses foiled again", + ] + fake_llm = FakeListLLM(responses=responses) + tools = [ + Tool( + name="Search", + func=lambda x: x, + description="Useful for searching", + return_direct=True, + ), + ] + agent = initialize_agent( + tools, + fake_llm, + agent="zero-shot-react-description", + return_intermediate_steps=True, + ) + + resp = agent("when was langchain made") + assert resp["output"] == "misalignment" + assert len(resp["intermediate_steps"]) == 1 + action, _action_intput = resp["intermediate_steps"][0] + assert action.tool == "Search" + + def test_agent_with_new_prefix_suffix() -> None: """Test agent initilization kwargs with new prefix and suffix.""" fake_llm = FakeListLLM(