From 7198a1cb22936f389b3833faebe20745fe7b36b8 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 28 Jan 2023 08:24:13 -0800 Subject: [PATCH] Harrison/refactor agent (#781) Co-authored-by: Amos Ng --- langchain/agents/agent.py | 45 +++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index f520c56248..9eb9f97bca 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -48,6 +48,29 @@ class Agent(BaseModel): def _stop(self) -> List[str]: return [f"\n{self.observation_prefix}"] + def _construct_scratchpad( + self, intermediate_steps: List[Tuple[AgentAction, str]] + ) -> str: + """Construct the scratchpad that lets the agent continue its thought process.""" + thoughts = "" + for action, observation in intermediate_steps: + thoughts += action.log + thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" + return thoughts + + def _get_next_action(self, full_inputs: Dict[str, str]) -> AgentAction: + full_output = self.llm_chain.predict(**full_inputs) + parsed_output = self._extract_tool_and_input(full_output) + while parsed_output is None: + full_output = self._fix_text(full_output) + full_inputs["agent_scratchpad"] += full_output + output = self.llm_chain.predict(**full_inputs) + full_output += output + parsed_output = self._extract_tool_and_input(full_output) + return AgentAction( + tool=parsed_output[0], tool_input=parsed_output[1], log=full_output + ) + def plan( self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any ) -> Union[AgentAction, AgentFinish]: @@ -61,24 +84,14 @@ class Agent(BaseModel): Returns: Action specifying what tool to use. """ - thoughts = "" - for action, observation in intermediate_steps: - thoughts += action.log - thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" + thoughts = self._construct_scratchpad(intermediate_steps) new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} full_inputs = {**kwargs, **new_inputs} - full_output = self.llm_chain.predict(**full_inputs) - parsed_output = self._extract_tool_and_input(full_output) - while parsed_output is None: - full_output = self._fix_text(full_output) - full_inputs["agent_scratchpad"] += full_output - output = self.llm_chain.predict(**full_inputs) - full_output += output - parsed_output = self._extract_tool_and_input(full_output) - tool, tool_input = parsed_output - if tool == self.finish_tool_name: - return AgentFinish({"output": tool_input}, full_output) - return AgentAction(tool, tool_input, full_output) + + action = self._get_next_action(full_inputs) + if action.tool == self.finish_tool_name: + return AgentFinish({"output": action.tool_input}, action.log) + return action def prepare_for_new_call(self) -> None: """Prepare the agent for new call, if needed."""