diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index b98b48da..da27db35 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -375,6 +375,59 @@ class AgentExecutor(Chain, BaseModel): final_output["intermediate_steps"] = intermediate_steps return final_output + def _take_next_step( + self, + name_to_tool_map: Dict[str, Tool], + color_mapping: Dict[str, str], + inputs: Dict[str, str], + intermediate_steps: List[Tuple[AgentAction, str]], + ) -> Union[AgentFinish, Tuple[AgentAction, str]]: + """Take a single step in the thought-action-observation loop. + + Override this to take control of how the agent makes and acts on choices. + """ + # Call the LLM to see what to do. + output = self.agent.plan(intermediate_steps, **inputs) + # If the tool chosen is the finishing tool, then we end and return. + if isinstance(output, AgentFinish): + return output + # Otherwise we lookup the tool + if output.tool in name_to_tool_map: + tool = name_to_tool_map[output.tool] + self.callback_manager.on_tool_start( + {"name": str(tool.func)[:60] + "..."}, + output, + color="green", + verbose=self.verbose, + ) + try: + # We then call the tool on the tool input to get an observation + observation = tool.func(output.tool_input) + color = color_mapping[output.tool] + return_direct = tool.return_direct + except (KeyboardInterrupt, Exception) as e: + self.callback_manager.on_tool_error(e, verbose=self.verbose) + raise e + else: + self.callback_manager.on_tool_start( + {"name": "N/A"}, output, color="green", verbose=self.verbose + ) + observation = f"{output.tool} is not a valid tool, try another one." + color = None + return_direct = False + llm_prefix = "" if return_direct else self.agent.llm_prefix + self.callback_manager.on_tool_end( + observation, + color=color, + observation_prefix=self.agent.observation_prefix, + llm_prefix=llm_prefix, + verbose=self.verbose, + ) + if return_direct: + # Set the log to "" because we do not want to log it. + return AgentFinish({self.agent.return_values[0]: observation}, "") + return output, observation + def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: """Run text through and get agent response.""" # Make sure that every tool is synchronous (not a coroutine) @@ -398,49 +451,13 @@ class AgentExecutor(Chain, BaseModel): iterations = 0 # We now enter the agent loop (until it returns something). while self._should_continue(iterations): - # Call the LLM to see what to do. - output = self.agent.plan(intermediate_steps, **inputs) - # If the tool chosen is the finishing tool, then we end and return. - if isinstance(output, AgentFinish): - return self._return(output, intermediate_steps) - - # Otherwise we lookup the tool - if output.tool in name_to_tool_map: - tool = name_to_tool_map[output.tool] - self.callback_manager.on_tool_start( - {"name": str(tool.func)[:60] + "..."}, - output, - color="green", - verbose=self.verbose, - ) - try: - # We then call the tool on the tool input to get an observation - observation = tool.func(output.tool_input) - color = color_mapping[output.tool] - return_direct = tool.return_direct - except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_tool_error(e, verbose=self.verbose) - raise e - else: - self.callback_manager.on_tool_start( - {"name": "N/A"}, output, color="green", verbose=self.verbose - ) - observation = f"{output.tool} is not a valid tool, try another one." - color = None - return_direct = False - llm_prefix = "" if return_direct else self.agent.llm_prefix - self.callback_manager.on_tool_end( - observation, - color=color, - observation_prefix=self.agent.observation_prefix, - llm_prefix=llm_prefix, - verbose=self.verbose, + next_step_output = self._take_next_step( + name_to_tool_map, color_mapping, inputs, intermediate_steps ) - intermediate_steps.append((output, observation)) - if return_direct: - # Set the log to "" because we do not want to log it. - output = AgentFinish({self.agent.return_values[0]: observation}, "") - return self._return(output, intermediate_steps) + if isinstance(next_step_output, AgentFinish): + return self._return(next_step_output, intermediate_steps) + + intermediate_steps.append(next_step_output) iterations += 1 output = self.agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs