|
|
|
"""Chain that takes in an input and produces an action and action input."""
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import logging
|
|
|
|
from abc import abstractmethod
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
|
|
|
|
from langchain.agents.tools import Tool
|
|
|
|
from langchain.callbacks.base import BaseCallbackManager
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
from langchain.input import get_color_mapping
|
|
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
|
|
|
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
from langchain.schema import AgentAction, AgentFinish
|
|
|
|
|
|
|
|
# Module-scoped logger so messages are attributed to this module rather than
# to the root logger, and can be filtered/configured independently.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class Agent(BaseModel):
    """Class responsible for calling the language model and deciding the action.

    This is driven by an LLMChain. The prompt in the LLMChain MUST include
    a variable called "agent_scratchpad" where the agent can put its
    intermediary work.
    """

    llm_chain: LLMChain
    # Keys of the dict carried by ``AgentFinish.return_values``; the first key
    # receives the agent's final answer (AgentExecutor relies on the same key).
    return_values: List[str] = ["output"]

    @abstractmethod
    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Extract tool and tool input from llm output.

        Returning ``None`` signals that ``text`` could not be parsed; ``plan``
        then asks ``_fix_text`` to repair it and calls the LLM again.
        """

    def _fix_text(self, text: str) -> str:
        """Fix the text.

        Called by ``plan`` when ``_extract_tool_and_input`` returns ``None``.

        Raises:
            ValueError: always, unless a subclass overrides this method.
        """
        raise ValueError("fix_text not implemented for this agent.")

    @property
    def _stop(self) -> List[str]:
        """Stop sequences passed as ``stop`` to the LLM chain.

        Presumably used so the model halts before writing its own observation.
        """
        return [f"\n{self.observation_prefix}"]

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        """Render prior actions and their observations as scratchpad text."""
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
        return thoughts

    def plan(
        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use, or ``AgentFinish`` when the
            parsed tool is ``finish_tool_name``.
        """
        thoughts = self._construct_scratchpad(intermediate_steps)
        full_inputs = {**kwargs, "agent_scratchpad": thoughts, "stop": self._stop}
        full_output = self.llm_chain.predict(**full_inputs)
        parsed_output = self._extract_tool_and_input(full_output)
        # NOTE(review): this retries until parsing succeeds; there is no cap on
        # the number of LLM calls beyond _fix_text raising.
        while parsed_output is None:
            full_output = self._fix_text(full_output)
            # ``full_output`` is cumulative, so rebuild the scratchpad from the
            # base thoughts instead of appending (which duplicated earlier text
            # on every retry after the first).
            full_inputs["agent_scratchpad"] = thoughts + full_output
            output = self.llm_chain.predict(**full_inputs)
            full_output += output
            parsed_output = self._extract_tool_and_input(full_output)
        tool, tool_input = parsed_output
        if tool == self.finish_tool_name:
            return AgentFinish({self.return_values[0]: tool_input}, full_output)
        return AgentAction(tool, tool_input, full_output)

    def prepare_for_new_call(self) -> None:
        """Prepare the agent for new call, if needed."""
        pass

    @property
    def finish_tool_name(self) -> str:
        """Name of the tool to use to finish the chain."""
        return "Final Answer"

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        The LLM chain's inputs minus ``agent_scratchpad`` (which the agent
        fills itself), deduplicated in their original order.

        :meta private:
        """
        return list(
            dict.fromkeys(
                key for key in self.llm_chain.input_keys if key != "agent_scratchpad"
            )
        )

    @root_validator()
    def validate_prompt(cls, values: Dict) -> Dict:
        """Validate that prompt matches format.

        If ``agent_scratchpad`` is missing from the prompt, it is appended to
        the template (PromptTemplate) or suffix (FewShotPromptTemplate).

        Raises:
            ValueError: if the variable is missing and the prompt type is not
                one that can be patched.
        """
        llm_chain = values.get("llm_chain")
        if llm_chain is None:
            # Field validation already failed; let pydantic report that error
            # instead of masking it with a KeyError here.
            return values
        prompt = llm_chain.prompt
        if "agent_scratchpad" not in prompt.input_variables:
            logger.warning(
                "`agent_scratchpad` should be a variable in prompt.input_variables."
                " Did not find it, so adding it at the end."
            )
            # Check the prompt type before mutating anything so an unsupported
            # prompt is not left half-modified.
            if isinstance(prompt, PromptTemplate):
                prompt.template += "\n{agent_scratchpad}"
            elif isinstance(prompt, FewShotPromptTemplate):
                prompt.suffix += "\n{agent_scratchpad}"
            else:
                raise ValueError(f"Got unexpected prompt type {type(prompt)}")
            prompt.input_variables.append("agent_scratchpad")
        return values

    @property
    @abstractmethod
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""

    @property
    @abstractmethod
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with."""

    @classmethod
    @abstractmethod
    def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
        """Create a prompt for this class."""

    @classmethod
    def _validate_tools(cls, tools: List[Tool]) -> None:
        """Validate that appropriate tools are passed in."""
        pass

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLLM,
        tools: List[Tool],
        callback_manager: Optional[BaseCallbackManager] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools.

        Validates ``tools``, builds an LLMChain around ``cls.create_prompt``,
        and forwards any extra ``kwargs`` to the agent constructor.
        """
        cls._validate_tools(tools)
        llm_chain = LLMChain(
            llm=llm,
            prompt=cls.create_prompt(tools),
            callback_manager=callback_manager,
        )
        return cls(llm_chain=llm_chain, **kwargs)

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        """Return response when agent has been stopped due to max iterations.

        Args:
            early_stopping_method: ``"force"`` returns a constant message;
                ``"generate"`` makes one final LLM call to produce an answer.
            intermediate_steps: Steps taken so far, with observations.
            **kwargs: User inputs.

        Raises:
            ValueError: for any other ``early_stopping_method``.
        """
        output_key = self.return_values[0]
        if early_stopping_method == "force":
            # `force` just returns a constant string
            return AgentFinish(
                {output_key: "Agent stopped due to max iterations."}, ""
            )
        elif early_stopping_method == "generate":
            # Generate does one final forward pass
            thoughts = self._construct_scratchpad(intermediate_steps)
            # Adding to the previous steps, we now tell the LLM to make a final pred
            thoughts += (
                "\n\nI now need to return a final answer based on the previous steps:"
            )
            full_inputs = {
                **kwargs,
                "agent_scratchpad": thoughts,
                "stop": self._stop,
            }
            full_output = self.llm_chain.predict(**full_inputs)
            # We try to extract a final answer
            parsed_output = self._extract_tool_and_input(full_output)
            if parsed_output is not None:
                tool, tool_input = parsed_output
                if tool == self.finish_tool_name:
                    # If we can extract, we send the correct stuff
                    return AgentFinish({output_key: tool_input}, full_output)
            # Unparseable output, or a non-final tool: return the raw output.
            return AgentFinish({output_key: full_output}, full_output)
        else:
            raise ValueError(
                "early_stopping_method should be one of `force` or `generate`, "
                f"got {early_stopping_method}"
            )
|
|
|
|
|
|
|
|
|
|
|
|
class AgentExecutor(Chain, BaseModel):
    """Consists of an agent using tools.

    Repeatedly asks ``agent`` to plan, runs the chosen tool, and feeds the
    observation back until the agent finishes, a ``return_direct`` tool is
    used, or ``max_iterations`` is reached.
    """

    agent: Agent
    tools: List[Tool]
    return_intermediate_steps: bool = False
    # None means no iteration limit.
    max_iterations: Optional[int] = None
    # Passed to Agent.return_stopped_response when max_iterations is hit.
    early_stopping_method: str = "force"

    @classmethod
    def from_agent_and_tools(
        cls,
        agent: Agent,
        tools: List[Tool],
        callback_manager: Optional[BaseCallbackManager] = None,
        **kwargs: Any,
    ) -> AgentExecutor:
        """Create from agent and tools."""
        return cls(
            agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
        )

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return self.agent.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        The agent's return values, plus ``intermediate_steps`` when
        ``return_intermediate_steps`` is set. A fresh list is returned so
        callers cannot mutate the agent's own ``return_values``.

        :meta private:
        """
        if self.return_intermediate_steps:
            return [*self.agent.return_values, "intermediate_steps"]
        return list(self.agent.return_values)

    def _should_continue(self, iterations: int) -> bool:
        """Return whether another agent step is allowed."""
        return self.max_iterations is None or iterations < self.max_iterations

    def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
        """Build the chain's output dict from a finishing step."""
        if self.verbose:
            self.callback_manager.on_agent_finish(output, color="green")
        # Copy so the AgentFinish (already handed to callbacks) is not mutated.
        final_output = dict(output.return_values)
        if self.return_intermediate_steps:
            final_output["intermediate_steps"] = intermediate_steps
        return final_output

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run text through and get agent response."""
        # Do any preparation necessary when receiving a new input.
        self.agent.prepare_for_new_call()
        # Construct a mapping of tool name to tool for easy lookup
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # We construct a mapping from each tool to a color, used for logging.
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        # Let's start tracking the iterations the agent has gone through
        iterations = 0
        # We now enter the agent loop (until it returns something).
        while self._should_continue(iterations):
            # Call the LLM to see what to do.
            output = self.agent.plan(intermediate_steps, **inputs)
            # If the tool chosen is the finishing tool, then we end and return.
            if isinstance(output, AgentFinish):
                return self._return(output, intermediate_steps)

            # Otherwise we lookup the tool
            if output.tool in name_to_tool_map:
                tool = name_to_tool_map[output.tool]
                if self.verbose:
                    self.callback_manager.on_tool_start(
                        {"name": str(tool.func)[:60] + "..."}, output, color="green"
                    )
                try:
                    # We then call the tool on the tool input to get an observation
                    observation = tool.func(output.tool_input)
                except Exception as e:
                    if self.verbose:
                        self.callback_manager.on_tool_error(e)
                    # Bare raise re-raises with the original traceback intact.
                    raise
                color: Optional[str] = color_mapping[output.tool]
                return_direct = tool.return_direct
            else:
                if self.verbose:
                    self.callback_manager.on_tool_start(
                        {"name": "N/A"}, output, color="green"
                    )
                observation = f"{output.tool} is not a valid tool, try another one."
                color = None
                return_direct = False
            if self.verbose:
                llm_prefix = "" if return_direct else self.agent.llm_prefix
                self.callback_manager.on_tool_end(
                    observation,
                    color=color,
                    observation_prefix=self.agent.observation_prefix,
                    llm_prefix=llm_prefix,
                )
            intermediate_steps.append((output, observation))
            if return_direct:
                # Set the log to "" because we do not want to log it.
                output = AgentFinish({self.agent.return_values[0]: observation}, "")
                return self._return(output, intermediate_steps)
            iterations += 1
        output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        return self._return(output, intermediate_steps)
|