agent refactor

This commit is contained in:
Harrison Chase 2022-12-17 20:29:12 -08:00
parent 85e7c5fd6c
commit ac208f85c8
10 changed files with 95 additions and 171 deletions

View File

@ -224,7 +224,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.10.8"
}
},
"nbformat": 4,

View File

@ -2,10 +2,11 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, List, Optional, Tuple
from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Union
from pydantic import BaseModel, root_validator
import langchain
from langchain.agents.input import ChainedInput
from langchain.agents.tools import Tool
from langchain.chains.base import Chain
@ -14,11 +15,6 @@ from langchain.input import get_color_mapping
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import AgentAction, AgentFinish
import langchain
from typing import NamedTuple
class Planner(BaseModel):
@ -28,8 +24,9 @@ class Planner(BaseModel):
a variable called "agent_scratchpad" where the agent can put its
intermediary work.
"""
llm_chain: LLMChain
return_values: List[str]
return_values: List[str] = ["output"]
@abstractmethod
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
@ -43,7 +40,9 @@ class Planner(BaseModel):
def _stop(self) -> List[str]:
return [f"\n{self.observation_prefix}"]
def plan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> AgentAction:
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentFinish, AgentAction]:
"""Given input, decided what to do.
Args:
@ -68,11 +67,18 @@ class Planner(BaseModel):
full_output += output
parsed_output = self._extract_tool_and_input(full_output)
tool, tool_input = parsed_output
if tool == self.finish_tool_name:
return AgentFinish({"output": tool_input}, full_output)
return AgentAction(tool, tool_input, full_output)
def prepare_for_new_call(self):
def prepare_for_new_call(self) -> None:
pass
@property
def finish_tool_name(self) -> str:
"""Name of the tool to use to finish the chain."""
return "Final Answer"
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
@ -101,8 +107,25 @@ class Planner(BaseModel):
def llm_prefix(self) -> str:
"""Prefix to append the LLM call with."""
@abstractmethod
@classmethod
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
"""Create a prompt for this class."""
class NewAgent(Chain, BaseModel):
@classmethod
def _validate_tools(cls, tools: List[Tool]) -> None:
"""Validate that appropriate tools are passed in."""
pass
@classmethod
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Planner:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
return cls(llm_chain=llm_chain)
class Agent(Chain, BaseModel):
planner: Planner
tools: List[Tool]
@ -138,7 +161,7 @@ class NewAgent(Chain, BaseModel):
[tool.name for tool in self.tools], excluded_colors=["green"]
)
planner_inputs = inputs.copy()
intermediate_steps = []
intermediate_steps: List[Tuple[AgentAction, str]] = []
# We now enter the agent loop (until it returns something).
while True:
# Call the LLM to see what to do.
@ -165,122 +188,3 @@ class NewAgent(Chain, BaseModel):
if self.verbose:
langchain.logger.log_agent_observation(observation, color=color)
intermediate_steps.append((output, observation))
class Agent(Chain, BaseModel, ABC):
"""Agent that uses an LLM."""
prompt: ClassVar[BasePromptTemplate]
llm_chain: LLMChain
tools: List[Tool]
return_intermediate_steps: bool = False
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
:meta private:
"""
if self.return_intermediate_steps:
return [self.output_key, "intermediate_steps"]
else:
return [self.output_key]
@root_validator()
def validate_prompt(cls, values: Dict) -> Dict:
"""Validate that prompt matches format."""
prompt = values["llm_chain"].prompt
if "agent_scratchpad" not in prompt.input_variables:
raise ValueError(
"`agent_scratchpad` should be a variable in prompt.input_variables"
)
return values
@property
@abstractmethod
def observation_prefix(self) -> str:
"""Prefix to append the observation with."""
@property
@abstractmethod
def llm_prefix(self) -> str:
"""Prefix to append the LLM call with."""
@property
def finish_tool_name(self) -> str:
"""Name of the tool to use to finish the chain."""
return "Final Answer"
@property
def starter_string(self) -> str:
"""Put this string after user input but before first LLM call."""
return "\n"
@classmethod
def _validate_tools(cls, tools: List[Tool]) -> None:
"""Validate that appropriate tools are passed in."""
pass
@classmethod
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
"""Create a prompt for this class."""
return cls.prompt
def _prepare_for_new_call(self) -> None:
pass
@classmethod
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
return cls(llm_chain=llm_chain, tools=tools, **kwargs)
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Do any preparation necessary when receiving a new input.
self._prepare_for_new_call()
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
# We use the ChainedInput class to iteratively add to the input over time.
chained_input = ChainedInput(self.llm_prefix, verbose=self.verbose)
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"]
)
# We now enter the agent loop (until it returns something).
while True:
# Call the LLM to see what to do.
output = self.get_action(chained_input.input, inputs)
# If the tool chosen is the finishing tool, then we end and return.
if output.tool == self.finish_tool_name:
final_output: dict = {self.output_key: output.tool_input}
if self.return_intermediate_steps:
final_output[
"intermediate_steps"
] = chained_input.intermediate_steps
return final_output
# Other we add the log to the Chained Input.
chained_input.add_action(output, color="green")
# And then we lookup the tool
if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool]
# We then call the tool on the tool input to get an observation
observation = chain(output.tool_input)
color = color_mapping[output.tool]
else:
observation = f"{output.tool} is not a valid tool, try another one."
color = None
# We then log the observation
chained_input.add_observation(
observation,
self.observation_prefix,
self.llm_prefix,
color=color,
)

View File

@ -1,17 +1,17 @@
"""Load agent."""
from typing import Any, List
from langchain.agents.agent import Agent
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.agent import Agent, Planner
from langchain.agents.mrkl.base import ZeroShotPlanner
from langchain.agents.react.base import ReActDocstorePlanner
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchPlanner
from langchain.agents.tools import Tool
from langchain.llms.base import LLM
AGENT_TO_CLASS = {
"zero-shot-react-description": ZeroShotAgent,
"react-docstore": ReActDocstoreAgent,
"self-ask-with-search": SelfAskWithSearchAgent,
"zero-shot-react-description": ZeroShotPlanner,
"react-docstore": ReActDocstorePlanner,
"self-ask-with-search": SelfAskWithSearchPlanner,
}
@ -39,4 +39,5 @@ def initialize_agent(
f"Valid types are: {AGENT_TO_CLASS.keys()}."
)
agent_cls = AGENT_TO_CLASS[agent]
return agent_cls.from_llm_and_tools(llm, tools, **kwargs)
planner = agent_cls.from_llm_and_tools(llm, tools, **kwargs)
return Agent(planner=planner, tools=tools)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
from langchain.agents.agent import Agent
from langchain.agents.agent import Agent, Planner
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.llms.base import LLM
@ -47,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
return action, action_input.strip(" ").strip('"')
class ZeroShotAgent(Agent):
class ZeroShotPlanner(Planner):
"""Agent for the MRKL chain."""
@property
@ -101,7 +101,10 @@ class ZeroShotAgent(Agent):
return get_action_and_input(text)
class MRKLChain(ZeroShotAgent):
ZeroShotAgent = ZeroShotPlanner
class MRKLChain(Agent):
"""Chain that implements the MRKL system.
Example:
@ -156,4 +159,5 @@ class MRKLChain(ZeroShotAgent):
Tool(name=c.action_name, func=c.action, description=c.action_description)
for c in chains
]
return cls.from_llm_and_tools(llm, tools, **kwargs)
planner = ZeroShotPlanner.from_llm_and_tools(llm, tools)
return cls(planner=planner, tools=tools, **kwargs)

View File

@ -13,4 +13,4 @@ Final Answer: the final answer to the original input question"""
SUFFIX = """Begin!
Question: {input}
{agent_scratchpad}"""
Thought:{agent_scratchpad}"""

View File

@ -4,7 +4,7 @@ from typing import Any, ClassVar, List, Optional, Tuple
from pydantic import BaseModel
from langchain.agents.agent import Agent
from langchain.agents.agent import Agent, Planner
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool
@ -15,10 +15,13 @@ from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
class ReActDocstoreAgent(Agent, BaseModel):
class ReActDocstorePlanner(Planner, BaseModel):
"""Agent for the ReAct chin."""
prompt: ClassVar[BasePromptTemplate] = WIKI_PROMPT
@classmethod
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
"""Return default prompt."""
return WIKI_PROMPT
i: int = 1
@ -72,6 +75,9 @@ class ReActDocstoreAgent(Agent, BaseModel):
return f"Thought {self.i}:"
ReActDocstoreAgent = ReActDocstorePlanner
class DocstoreExplorer:
"""Class to assist with exploration of a document store."""
@ -97,12 +103,13 @@ class DocstoreExplorer:
return self.document.lookup(term)
class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel):
class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel):
"""Agent for the ReAct TextWorld chain."""
prompt: ClassVar[BasePromptTemplate] = TEXTWORLD_PROMPT
i: int = 1
@classmethod
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
"""Return default prompt."""
return TEXTWORLD_PROMPT
@classmethod
def _validate_tools(cls, tools: List[Tool]) -> None:
@ -113,7 +120,10 @@ class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel):
raise ValueError(f"Tool name should be Play, got {tool_names}")
class ReActChain(ReActDocstoreAgent):
ReActTextWorldAgent = ReActTextWorldPlanner
class ReActChain(Agent):
"""Chain that implements the ReAct paper.
Example:
@ -130,5 +140,5 @@ class ReActChain(ReActDocstoreAgent):
Tool(name="Search", func=docstore_explorer.search),
Tool(name="Lookup", func=docstore_explorer.lookup),
]
llm_chain = LLMChain(llm=llm, prompt=WIKI_PROMPT)
super().__init__(llm_chain=llm_chain, tools=tools, **kwargs)
planner = ReActDocstorePlanner.from_llm_and_tools(llm, tools)
super().__init__(planner=planner, tools=tools, **kwargs)

View File

@ -1,19 +1,21 @@
"""Chain that does self ask with search."""
from typing import Any, ClassVar, List, Optional, Tuple
from typing import Any, List, Optional, Tuple
from langchain.agents.agent import Agent
from langchain.agents.agent import Agent, Planner
from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.chains.llm import LLMChain
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
from langchain.serpapi import SerpAPIWrapper
class SelfAskWithSearchAgent(Agent):
class SelfAskWithSearchPlanner(Planner):
"""Agent for the self-ask-with-search paper."""
prompt: ClassVar[BasePromptTemplate] = PROMPT
@classmethod
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
"""Prompt does not depend on tools."""
return PROMPT
@classmethod
def _validate_tools(cls, tools: List[Tool]) -> None:
@ -61,7 +63,10 @@ class SelfAskWithSearchAgent(Agent):
return "Are follow up questions needed here:"
class SelfAskWithSearchChain(SelfAskWithSearchAgent):
SelfAskWithSearchAgent = SelfAskWithSearchPlanner
class SelfAskWithSearchChain(Agent):
"""Chain that does self ask with search.
Example:
@ -75,5 +80,5 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
"""Initialize with just an LLM and a search chain."""
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
super().__init__(llm_chain=llm_chain, tools=[search_tool], **kwargs)
planner = SelfAskWithSearchPlanner.from_llm_and_tools(llm, [search_tool])
super().__init__(planner=planner, tools=[search_tool], **kwargs)

View File

@ -38,7 +38,7 @@ Intermediate answer: New Zealand.
So the final answer is: No
Question: {input}
{agent_scratchpad}"""
Are followup questions needed here:{agent_scratchpad}"""
PROMPT = PromptTemplate(
input_variables=["input", "agent_scratchpad"], template=_DEFAULT_TEMPLATE
)

View File

@ -13,6 +13,7 @@ class AgentAction(NamedTuple):
class AgentFinish(NamedTuple):
"""Agent's return value."""
return_values: dict
log: str

View File

@ -10,6 +10,7 @@ from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import LLM
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AgentAction
_PAGE_CONTENT = """This is a page about LangChain.
@ -61,10 +62,9 @@ def test_predict_until_observation_normal() -> None:
Tool("Lookup", lambda x: x),
]
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
output = agent.get_action("", {"input": ""})
assert output.log == outputs[0]
assert output.tool == "Search"
assert output.tool_input == "foo"
output = agent.plan([], input="")
expected_output = AgentAction("Search", "foo", outputs[0])
assert output == expected_output
def test_predict_until_observation_repeat() -> None:
@ -76,10 +76,9 @@ def test_predict_until_observation_repeat() -> None:
Tool("Lookup", lambda x: x),
]
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
output = agent.get_action("", {"input": ""})
assert output.log == "foo\nAction 1: Search[foo]"
assert output.tool == "Search"
assert output.tool_input == "foo"
output = agent.plan([], input="")
expected_output = AgentAction("Search", "foo", "foo\nAction 1: Search[foo]")
assert output == expected_output
def test_react_chain() -> None: