"""Chain that takes in an input and produces an action and action input.""" from __future__ import annotations import logging from abc import abstractmethod from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, root_validator from langchain.agents.tools import Tool from langchain.callbacks.base import BaseCallbackManager from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import AgentAction, AgentFinish logger = logging.getLogger() class Agent(BaseModel): """Class responsible for calling the language model and deciding the action. This is driven by an LLMChain. The prompt in the LLMChain MUST include a variable called "agent_scratchpad" where the agent can put its intermediary work. """ llm_chain: LLMChain return_values: List[str] = ["output"] @abstractmethod def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]: """Extract tool and tool input from llm output.""" def _fix_text(self, text: str) -> str: """Fix the text.""" raise ValueError("fix_text not implemented for this agent.") @property def _stop(self) -> List[str]: return [f"\n{self.observation_prefix}"] def plan( self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations **kwargs: User inputs. Returns: Action specifying what tool to use. """ thoughts = "" for action, observation in intermediate_steps: thoughts += action.log thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} full_inputs = {**kwargs, **new_inputs} full_output = self.llm_chain.predict(**full_inputs) parsed_output = self._extract_tool_and_input(full_output) while parsed_output is None: full_output = self._fix_text(full_output) full_inputs["agent_scratchpad"] += full_output output = self.llm_chain.predict(**full_inputs) full_output += output parsed_output = self._extract_tool_and_input(full_output) tool, tool_input = parsed_output if tool == self.finish_tool_name: return AgentFinish({"output": tool_input}, full_output) return AgentAction(tool, tool_input, full_output) def prepare_for_new_call(self) -> None: """Prepare the agent for new call, if needed.""" pass @property def finish_tool_name(self) -> str: """Name of the tool to use to finish the chain.""" return "Final Answer" @property def input_keys(self) -> List[str]: """Return the input keys. :meta private: """ return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"}) @root_validator() def validate_prompt(cls, values: Dict) -> Dict: """Validate that prompt matches format.""" prompt = values["llm_chain"].prompt if "agent_scratchpad" not in prompt.input_variables: logger.warning( "`agent_scratchpad` should be a variable in prompt.input_variables." " Did not find it, so adding it at the end." 
            )
            prompt.input_variables.append("agent_scratchpad")
            if isinstance(prompt, PromptTemplate):
                prompt.template += "\n{agent_scratchpad}"
            elif isinstance(prompt, FewShotPromptTemplate):
                prompt.suffix += "\n{agent_scratchpad}"
            else:
                raise ValueError(f"Got unexpected prompt type {type(prompt)}")
        return values

    @property
    @abstractmethod
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""

    @property
    @abstractmethod
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with."""

    @classmethod
    @abstractmethod
    def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
        """Create a prompt for this class."""

    @classmethod
    def _validate_tools(cls, tools: List[Tool]) -> None:
        """Validate that appropriate tools are passed in."""
        pass

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLLM,
        tools: List[Tool],
        callback_manager: Optional[BaseCallbackManager] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        llm_chain = LLMChain(
            llm=llm,
            prompt=cls.create_prompt(tools),
            callback_manager=callback_manager,
        )
        return cls(llm_chain=llm_chain, **kwargs)

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        """Return response when agent has been stopped due to max iterations."""
        if early_stopping_method == "force":
            # `force` just returns a constant string
            return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
        elif early_stopping_method == "generate":
            # Generate does one final forward pass
            thoughts = ""
            for action, observation in intermediate_steps:
                thoughts += action.log
                thoughts += (
                    f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
                )
            # Adding to the previous steps, we now tell the LLM to make a final pred
            thoughts += (
                "\n\nI now need to return a final answer based on the previous steps:"
            )
            new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
            full_inputs = {**kwargs, **new_inputs}
            full_output = self.llm_chain.predict(**full_inputs)
            # We try to extract a final answer
            parsed_output = self._extract_tool_and_input(full_output)
            if parsed_output is None:
                # If we cannot extract, we just return the full output
                return AgentFinish({"output": full_output}, full_output)
            tool, tool_input = parsed_output
            if tool == self.finish_tool_name:
                # If we can extract, we send the correct stuff
                return AgentFinish({"output": tool_input}, full_output)
            else:
                # If we can extract, but the tool is not the final tool,
                # we just return the full output
                return AgentFinish({"output": full_output}, full_output)
        else:
            raise ValueError(
                "early_stopping_method should be one of `force` or `generate`, "
                f"got {early_stopping_method}"
            )


class AgentExecutor(Chain, BaseModel):
    """Consists of an agent using tools."""

    agent: Agent
    tools: List[Tool]
    return_intermediate_steps: bool = False
    max_iterations: Optional[int] = None
    early_stopping_method: str = "force"

    @classmethod
    def from_agent_and_tools(
        cls,
        agent: Agent,
        tools: List[Tool],
        callback_manager: Optional[BaseCallbackManager] = None,
        **kwargs: Any,
    ) -> AgentExecutor:
        """Create from agent and tools."""
        return cls(
            agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
        )

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return self.agent.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.
        :meta private:
        """
        if self.return_intermediate_steps:
            return self.agent.return_values + ["intermediate_steps"]
        else:
            return self.agent.return_values

    def _should_continue(self, iterations: int) -> bool:
        if self.max_iterations is None:
            return True
        else:
            return iterations < self.max_iterations

    def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
        self.callback_manager.on_agent_finish(
            output, color="green", verbose=self.verbose
        )
        final_output = output.return_values
        if self.return_intermediate_steps:
            final_output["intermediate_steps"] = intermediate_steps
        return final_output

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run text through and get agent response."""
        # Do any preparation necessary when receiving a new input.
        self.agent.prepare_for_new_call()
        # Construct a mapping of tool name to tool for easy lookup
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # We construct a mapping from each tool to a color, used for logging.
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        # Let's start tracking the iterations the agent has gone through
        iterations = 0
        # We now enter the agent loop (until it returns something).
        while self._should_continue(iterations):
            # Call the LLM to see what to do.
            output = self.agent.plan(intermediate_steps, **inputs)
            # If the tool chosen is the finishing tool, then we end and return.
            if isinstance(output, AgentFinish):
                return self._return(output, intermediate_steps)
            # Otherwise we look up the tool
            if output.tool in name_to_tool_map:
                tool = name_to_tool_map[output.tool]
                self.callback_manager.on_tool_start(
                    {"name": str(tool.func)[:60] + "..."},
                    output,
                    color="green",
                    verbose=self.verbose,
                )
                try:
                    # We then call the tool on the tool input to get an observation
                    observation = tool.func(output.tool_input)
                    color = color_mapping[output.tool]
                    return_direct = tool.return_direct
                except (KeyboardInterrupt, Exception) as e:
                    self.callback_manager.on_tool_error(e, verbose=self.verbose)
                    raise e
            else:
                self.callback_manager.on_tool_start(
                    {"name": "N/A"}, output, color="green", verbose=self.verbose
                )
                observation = f"{output.tool} is not a valid tool, try another one."
                color = None
                return_direct = False
            llm_prefix = "" if return_direct else self.agent.llm_prefix
            self.callback_manager.on_tool_end(
                observation,
                color=color,
                observation_prefix=self.agent.observation_prefix,
                llm_prefix=llm_prefix,
                verbose=self.verbose,
            )
            intermediate_steps.append((output, observation))
            if return_direct:
                # Set the log to "" because we do not want to log it.
                output = AgentFinish({self.agent.return_values[0]: observation}, "")
                return self._return(output, intermediate_steps)
            iterations += 1
        output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        return self._return(output, intermediate_steps)
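

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library surface: it shows how a
    # concrete Agent subclass and an AgentExecutor are typically wired together
    # via the `from_llm_and_tools` and `from_agent_and_tools` constructors above.
    # Assumptions: `ZeroShotAgent` is importable from `langchain.agents`, `OpenAI`
    # from `langchain.llms`, and an OpenAI API key is configured in the
    # environment. The `word_length` tool is a made-up stand-in for illustration.
    from langchain.agents import ZeroShotAgent
    from langchain.llms import OpenAI

    def word_length(query: str) -> str:
        """Toy tool: return the number of characters in the query string."""
        return str(len(query))

    tools = [
        Tool(
            name="WordLength",
            func=word_length,
            description="Useful for counting the characters in a string.",
        )
    ]
    # Build the agent (LLM + prompt) and the executor (agent + tools + loop control).
    agent = ZeroShotAgent.from_llm_and_tools(llm=OpenAI(temperature=0), tools=tools)
    executor = AgentExecutor.from_agent_and_tools(
        agent=agent, tools=tools, max_iterations=3, verbose=True
    )
    print(executor.run("How many characters are in the word 'scratchpad'?"))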