|
|
|
@ -2,10 +2,11 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from typing import Any, ClassVar, Dict, List, Optional, Tuple
|
|
|
|
|
from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
|
|
|
|
|
|
import langchain
|
|
|
|
|
from langchain.agents.input import ChainedInput
|
|
|
|
|
from langchain.agents.tools import Tool
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
@ -14,11 +15,6 @@ from langchain.input import get_color_mapping
|
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
|
from langchain.schema import AgentAction, AgentFinish
|
|
|
|
|
import langchain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import NamedTuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Planner(BaseModel):
|
|
|
|
@ -28,8 +24,9 @@ class Planner(BaseModel):
|
|
|
|
|
a variable called "agent_scratchpad" where the agent can put its
|
|
|
|
|
intermediary work.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
llm_chain: LLMChain
|
|
|
|
|
return_values: List[str]
|
|
|
|
|
return_values: List[str] = ["output"]
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
|
|
|
|
@ -43,7 +40,9 @@ class Planner(BaseModel):
|
|
|
|
|
def _stop(self) -> List[str]:
|
|
|
|
|
return [f"\n{self.observation_prefix}"]
|
|
|
|
|
|
|
|
|
|
def plan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> AgentAction:
|
|
|
|
|
def plan(
|
|
|
|
|
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
|
|
|
|
) -> Union[AgentFinish, AgentAction]:
|
|
|
|
|
"""Given input, decided what to do.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -68,11 +67,18 @@ class Planner(BaseModel):
|
|
|
|
|
full_output += output
|
|
|
|
|
parsed_output = self._extract_tool_and_input(full_output)
|
|
|
|
|
tool, tool_input = parsed_output
|
|
|
|
|
if tool == self.finish_tool_name:
|
|
|
|
|
return AgentFinish({"output": tool_input}, full_output)
|
|
|
|
|
return AgentAction(tool, tool_input, full_output)
|
|
|
|
|
|
|
|
|
|
def prepare_for_new_call(self):
|
|
|
|
|
def prepare_for_new_call(self) -> None:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def finish_tool_name(self) -> str:
|
|
|
|
|
"""Name of the tool to use to finish the chain."""
|
|
|
|
|
return "Final Answer"
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def input_keys(self) -> List[str]:
|
|
|
|
|
"""Return the input keys.
|
|
|
|
@ -101,8 +107,25 @@ class Planner(BaseModel):
|
|
|
|
|
def llm_prefix(self) -> str:
|
|
|
|
|
"""Prefix to append the LLM call with."""
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
@classmethod
|
|
|
|
|
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
|
|
|
|
|
"""Create a prompt for this class."""
|
|
|
|
|
|
|
|
|
|
class NewAgent(Chain, BaseModel):
|
|
|
|
|
@classmethod
|
|
|
|
|
def _validate_tools(cls, tools: List[Tool]) -> None:
|
|
|
|
|
"""Validate that appropriate tools are passed in."""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Planner:
|
|
|
|
|
"""Construct an agent from an LLM and tools."""
|
|
|
|
|
cls._validate_tools(tools)
|
|
|
|
|
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
|
|
|
|
|
return cls(llm_chain=llm_chain)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Agent(Chain, BaseModel):
|
|
|
|
|
|
|
|
|
|
planner: Planner
|
|
|
|
|
tools: List[Tool]
|
|
|
|
@ -138,7 +161,7 @@ class NewAgent(Chain, BaseModel):
|
|
|
|
|
[tool.name for tool in self.tools], excluded_colors=["green"]
|
|
|
|
|
)
|
|
|
|
|
planner_inputs = inputs.copy()
|
|
|
|
|
intermediate_steps = []
|
|
|
|
|
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
|
|
|
|
# We now enter the agent loop (until it returns something).
|
|
|
|
|
while True:
|
|
|
|
|
# Call the LLM to see what to do.
|
|
|
|
@ -165,122 +188,3 @@ class NewAgent(Chain, BaseModel):
|
|
|
|
|
if self.verbose:
|
|
|
|
|
langchain.logger.log_agent_observation(observation, color=color)
|
|
|
|
|
intermediate_steps.append((output, observation))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Agent(Chain, BaseModel, ABC):
|
|
|
|
|
"""Agent that uses an LLM."""
|
|
|
|
|
|
|
|
|
|
prompt: ClassVar[BasePromptTemplate]
|
|
|
|
|
llm_chain: LLMChain
|
|
|
|
|
tools: List[Tool]
|
|
|
|
|
return_intermediate_steps: bool = False
|
|
|
|
|
input_key: str = "input" #: :meta private:
|
|
|
|
|
output_key: str = "output" #: :meta private:
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def output_keys(self) -> List[str]:
|
|
|
|
|
"""Return the singular output key.
|
|
|
|
|
|
|
|
|
|
:meta private:
|
|
|
|
|
"""
|
|
|
|
|
if self.return_intermediate_steps:
|
|
|
|
|
return [self.output_key, "intermediate_steps"]
|
|
|
|
|
else:
|
|
|
|
|
return [self.output_key]
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_prompt(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Validate that prompt matches format."""
|
|
|
|
|
prompt = values["llm_chain"].prompt
|
|
|
|
|
if "agent_scratchpad" not in prompt.input_variables:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"`agent_scratchpad` should be a variable in prompt.input_variables"
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def observation_prefix(self) -> str:
|
|
|
|
|
"""Prefix to append the observation with."""
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def llm_prefix(self) -> str:
|
|
|
|
|
"""Prefix to append the LLM call with."""
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def finish_tool_name(self) -> str:
|
|
|
|
|
"""Name of the tool to use to finish the chain."""
|
|
|
|
|
return "Final Answer"
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def starter_string(self) -> str:
|
|
|
|
|
"""Put this string after user input but before first LLM call."""
|
|
|
|
|
return "\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _validate_tools(cls, tools: List[Tool]) -> None:
|
|
|
|
|
"""Validate that appropriate tools are passed in."""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
|
|
|
|
|
"""Create a prompt for this class."""
|
|
|
|
|
return cls.prompt
|
|
|
|
|
|
|
|
|
|
def _prepare_for_new_call(self) -> None:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent:
|
|
|
|
|
"""Construct an agent from an LLM and tools."""
|
|
|
|
|
cls._validate_tools(tools)
|
|
|
|
|
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
|
|
|
|
|
return cls(llm_chain=llm_chain, tools=tools, **kwargs)
|
|
|
|
|
|
|
|
|
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
|
|
|
|
"""Run text through and get agent response."""
|
|
|
|
|
# Do any preparation necessary when receiving a new input.
|
|
|
|
|
self._prepare_for_new_call()
|
|
|
|
|
# Construct a mapping of tool name to tool for easy lookup
|
|
|
|
|
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
|
|
|
|
|
# We use the ChainedInput class to iteratively add to the input over time.
|
|
|
|
|
chained_input = ChainedInput(self.llm_prefix, verbose=self.verbose)
|
|
|
|
|
# We construct a mapping from each tool to a color, used for logging.
|
|
|
|
|
color_mapping = get_color_mapping(
|
|
|
|
|
[tool.name for tool in self.tools], excluded_colors=["green"]
|
|
|
|
|
)
|
|
|
|
|
# We now enter the agent loop (until it returns something).
|
|
|
|
|
while True:
|
|
|
|
|
# Call the LLM to see what to do.
|
|
|
|
|
output = self.get_action(chained_input.input, inputs)
|
|
|
|
|
# If the tool chosen is the finishing tool, then we end and return.
|
|
|
|
|
if output.tool == self.finish_tool_name:
|
|
|
|
|
final_output: dict = {self.output_key: output.tool_input}
|
|
|
|
|
if self.return_intermediate_steps:
|
|
|
|
|
final_output[
|
|
|
|
|
"intermediate_steps"
|
|
|
|
|
] = chained_input.intermediate_steps
|
|
|
|
|
return final_output
|
|
|
|
|
# Other we add the log to the Chained Input.
|
|
|
|
|
chained_input.add_action(output, color="green")
|
|
|
|
|
# And then we lookup the tool
|
|
|
|
|
if output.tool in name_to_tool_map:
|
|
|
|
|
chain = name_to_tool_map[output.tool]
|
|
|
|
|
# We then call the tool on the tool input to get an observation
|
|
|
|
|
observation = chain(output.tool_input)
|
|
|
|
|
color = color_mapping[output.tool]
|
|
|
|
|
else:
|
|
|
|
|
observation = f"{output.tool} is not a valid tool, try another one."
|
|
|
|
|
color = None
|
|
|
|
|
# We then log the observation
|
|
|
|
|
chained_input.add_observation(
|
|
|
|
|
observation,
|
|
|
|
|
self.observation_prefix,
|
|
|
|
|
self.llm_prefix,
|
|
|
|
|
color=color,
|
|
|
|
|
)
|
|
|
|
|