harrison/agent-refactor
Harrison Chase 1 year ago
parent ac208f85c8
commit f646c94bc1

@ -1,5 +1,5 @@
"""Routing chains.""" """Routing chains."""
from langchain.agents.agent import Agent from langchain.agents.agent import AgentWithTools
from langchain.agents.loading import initialize_agent from langchain.agents.loading import initialize_agent
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
@ -10,7 +10,7 @@ __all__ = [
"MRKLChain", "MRKLChain",
"SelfAskWithSearchChain", "SelfAskWithSearchChain",
"ReActChain", "ReActChain",
"Agent", "AgentWithTools",
"Tool", "Tool",
"initialize_agent", "initialize_agent",
"ZeroShotAgent", "ZeroShotAgent",

@ -1,13 +1,12 @@
"""Chain that takes in an input and produces an action and action input.""" """Chain that takes in an input and produces an action and action input."""
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import abstractmethod
from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
import langchain import langchain
from langchain.agents.input import ChainedInput
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -17,7 +16,7 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.schema import AgentAction, AgentFinish from langchain.schema import AgentAction, AgentFinish
class Planner(BaseModel): class Agent(BaseModel):
"""Class responsible for calling the language model and deciding the action. """Class responsible for calling the language model and deciding the action.
This is driven by an LLMChain. The prompt in the LLMChain MUST include This is driven by an LLMChain. The prompt in the LLMChain MUST include
@ -72,6 +71,7 @@ class Planner(BaseModel):
return AgentAction(tool, tool_input, full_output) return AgentAction(tool, tool_input, full_output)
def prepare_for_new_call(self) -> None: def prepare_for_new_call(self) -> None:
"""Prepare the agent for new call, if needed."""
pass pass
@property @property
@ -107,8 +107,8 @@ class Planner(BaseModel):
def llm_prefix(self) -> str: def llm_prefix(self) -> str:
"""Prefix to append the LLM call with.""" """Prefix to append the LLM call with."""
@abstractmethod
@classmethod @classmethod
@abstractmethod
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate: def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
"""Create a prompt for this class.""" """Create a prompt for this class."""
@ -118,16 +118,17 @@ class Planner(BaseModel):
pass pass
@classmethod @classmethod
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Planner: def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Agent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
cls._validate_tools(tools) cls._validate_tools(tools)
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
return cls(llm_chain=llm_chain) return cls(llm_chain=llm_chain)
class Agent(Chain, BaseModel): class AgentWithTools(Chain, BaseModel):
"""Consists of an agent using tools."""
planner: Planner agent: Agent
tools: List[Tool] tools: List[Tool]
return_intermediate_steps: bool = False return_intermediate_steps: bool = False
@ -137,7 +138,7 @@ class Agent(Chain, BaseModel):
:meta private: :meta private:
""" """
return self.planner.input_keys return self.agent.input_keys
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> List[str]:
@ -146,26 +147,25 @@ class Agent(Chain, BaseModel):
:meta private: :meta private:
""" """
if self.return_intermediate_steps: if self.return_intermediate_steps:
return self.planner.return_values + ["intermediate_steps"] return self.agent.return_values + ["intermediate_steps"]
else: else:
return self.planner.return_values return self.agent.return_values
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response.""" """Run text through and get agent response."""
# Do any preparation necessary when receiving a new input. # Do any preparation necessary when receiving a new input.
self.planner.prepare_for_new_call() self.agent.prepare_for_new_call()
# Construct a mapping of tool name to tool for easy lookup # Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool.func for tool in self.tools} name_to_tool_map = {tool.name: tool.func for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging. # We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping( color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"] [tool.name for tool in self.tools], excluded_colors=["green"]
) )
planner_inputs = inputs.copy()
intermediate_steps: List[Tuple[AgentAction, str]] = [] intermediate_steps: List[Tuple[AgentAction, str]] = []
# We now enter the agent loop (until it returns something). # We now enter the agent loop (until it returns something).
while True: while True:
# Call the LLM to see what to do. # Call the LLM to see what to do.
output = self.planner.plan(intermediate_steps, **planner_inputs) output = self.agent.plan(intermediate_steps, **inputs)
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
if self.verbose: if self.verbose:

@ -1,61 +0,0 @@
"""Input manager for agents."""
from typing import List, Optional
import langchain
from langchain.schema import AgentAction
class ChainedInput:
"""Class for working with input that is the result of chains."""
def __init__(self, text: str, verbose: bool = False):
"""Initialize with verbose flag and initial text."""
self._verbose = verbose
if self._verbose:
langchain.logger.log_agent_start(text)
self._input = text
self._intermediate_actions: List[AgentAction] = []
self._intermediate_observations: List[str] = []
@property
def intermediate_steps(self) -> List:
"""Return intermediate steps the agent took."""
steps = []
for i, action in enumerate(self._intermediate_actions):
step = {
"log": action.log,
"tool": action.tool,
"tool_input": action.tool_input,
"observation": self._intermediate_observations[i],
}
steps.append(step)
return steps
def add_action(self, action: AgentAction, color: Optional[str] = None) -> None:
"""Add text to input, print if in verbose mode."""
self._input += action.log
self._intermediate_actions.append(action)
def add_observation(
self,
observation: str,
observation_prefix: str,
llm_prefix: str,
color: Optional[str],
) -> None:
"""Add observation to input, print if in verbose mode."""
if self._verbose:
langchain.logger.log_agent_observation(
observation,
color=color,
observation_prefix=observation_prefix,
llm_prefix=llm_prefix,
)
self._input += f"\n{observation_prefix}{observation}\n{llm_prefix}"
self._intermediate_observations.append(observation)
@property
def input(self) -> str:
"""Return the accumulated input."""
return self._input

@ -1,17 +1,17 @@
"""Load agent.""" """Load agent."""
from typing import Any, List from typing import Any, List
from langchain.agents.agent import Agent, Planner from langchain.agents.agent import AgentWithTools
from langchain.agents.mrkl.base import ZeroShotPlanner from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstorePlanner from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchPlanner from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.llms.base import LLM from langchain.llms.base import LLM
AGENT_TO_CLASS = { AGENT_TO_CLASS = {
"zero-shot-react-description": ZeroShotPlanner, "zero-shot-react-description": ZeroShotAgent,
"react-docstore": ReActDocstorePlanner, "react-docstore": ReActDocstoreAgent,
"self-ask-with-search": SelfAskWithSearchPlanner, "self-ask-with-search": SelfAskWithSearchAgent,
} }
@ -20,7 +20,7 @@ def initialize_agent(
llm: LLM, llm: LLM,
agent: str = "zero-shot-react-description", agent: str = "zero-shot-react-description",
**kwargs: Any, **kwargs: Any,
) -> Agent: ) -> AgentWithTools:
"""Load agent given tools and LLM. """Load agent given tools and LLM.
Args: Args:
@ -39,5 +39,5 @@ def initialize_agent(
f"Valid types are: {AGENT_TO_CLASS.keys()}." f"Valid types are: {AGENT_TO_CLASS.keys()}."
) )
agent_cls = AGENT_TO_CLASS[agent] agent_cls = AGENT_TO_CLASS[agent]
planner = agent_cls.from_llm_and_tools(llm, tools, **kwargs) agent_obj = agent_cls.from_llm_and_tools(llm, tools)
return Agent(planner=planner, tools=tools) return AgentWithTools(agent=agent_obj, tools=tools, **kwargs)

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Callable, List, NamedTuple, Optional, Tuple from typing import Any, Callable, List, NamedTuple, Optional, Tuple
from langchain.agents.agent import Agent, Planner from langchain.agents.agent import Agent, AgentWithTools
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.llms.base import LLM from langchain.llms.base import LLM
@ -47,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
return action, action_input.strip(" ").strip('"') return action, action_input.strip(" ").strip('"')
class ZeroShotPlanner(Planner): class ZeroShotAgent(Agent):
"""Agent for the MRKL chain.""" """Agent for the MRKL chain."""
@property @property
@ -101,10 +101,7 @@ class ZeroShotPlanner(Planner):
return get_action_and_input(text) return get_action_and_input(text)
ZeroShotAgent = ZeroShotPlanner class MRKLChain(AgentWithTools):
class MRKLChain(Agent):
"""Chain that implements the MRKL system. """Chain that implements the MRKL system.
Example: Example:
@ -119,7 +116,9 @@ class MRKLChain(Agent):
""" """
@classmethod @classmethod
def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent: def from_chains(
cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any
) -> AgentWithTools:
"""User friendly way to initialize the MRKL chain. """User friendly way to initialize the MRKL chain.
This is intended to be an easy way to get up and running with the This is intended to be an easy way to get up and running with the
@ -159,5 +158,5 @@ class MRKLChain(Agent):
Tool(name=c.action_name, func=c.action, description=c.action_description) Tool(name=c.action_name, func=c.action, description=c.action_description)
for c in chains for c in chains
] ]
planner = ZeroShotPlanner.from_llm_and_tools(llm, tools) agent = ZeroShotAgent.from_llm_and_tools(llm, tools)
return cls(planner=planner, tools=tools, **kwargs) return cls(agent=agent, tools=tools, **kwargs)

@ -1,21 +1,20 @@
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf.""" """Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
import re import re
from typing import Any, ClassVar, List, Optional, Tuple from typing import Any, List, Optional, Tuple
from pydantic import BaseModel from pydantic import BaseModel
from langchain.agents.agent import Agent, Planner from langchain.agents.agent import Agent, AgentWithTools
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
from langchain.agents.react.wiki_prompt import WIKI_PROMPT from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.chains.llm import LLMChain
from langchain.docstore.base import Docstore from langchain.docstore.base import Docstore
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
class ReActDocstorePlanner(Planner, BaseModel): class ReActDocstoreAgent(Agent, BaseModel):
"""Agent for the ReAct chin.""" """Agent for the ReAct chin."""
@classmethod @classmethod
@ -75,9 +74,6 @@ class ReActDocstorePlanner(Planner, BaseModel):
return f"Thought {self.i}:" return f"Thought {self.i}:"
ReActDocstoreAgent = ReActDocstorePlanner
class DocstoreExplorer: class DocstoreExplorer:
"""Class to assist with exploration of a document store.""" """Class to assist with exploration of a document store."""
@ -103,7 +99,7 @@ class DocstoreExplorer:
return self.document.lookup(term) return self.document.lookup(term)
class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel): class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel):
"""Agent for the ReAct TextWorld chain.""" """Agent for the ReAct TextWorld chain."""
@classmethod @classmethod
@ -120,10 +116,7 @@ class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel):
raise ValueError(f"Tool name should be Play, got {tool_names}") raise ValueError(f"Tool name should be Play, got {tool_names}")
ReActTextWorldAgent = ReActTextWorldPlanner class ReActChain(AgentWithTools):
class ReActChain(Agent):
"""Chain that implements the ReAct paper. """Chain that implements the ReAct paper.
Example: Example:
@ -140,5 +133,5 @@ class ReActChain(Agent):
Tool(name="Search", func=docstore_explorer.search), Tool(name="Search", func=docstore_explorer.search),
Tool(name="Lookup", func=docstore_explorer.lookup), Tool(name="Lookup", func=docstore_explorer.lookup),
] ]
planner = ReActDocstorePlanner.from_llm_and_tools(llm, tools) agent = ReActDocstoreAgent.from_llm_and_tools(llm, tools)
super().__init__(planner=planner, tools=tools, **kwargs) super().__init__(agent=agent, tools=tools, **kwargs)

@ -1,7 +1,7 @@
"""Chain that does self ask with search.""" """Chain that does self ask with search."""
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
from langchain.agents.agent import Agent, Planner from langchain.agents.agent import Agent, AgentWithTools
from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.llms.base import LLM from langchain.llms.base import LLM
@ -9,7 +9,7 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.serpapi import SerpAPIWrapper from langchain.serpapi import SerpAPIWrapper
class SelfAskWithSearchPlanner(Planner): class SelfAskWithSearchAgent(Agent):
"""Agent for the self-ask-with-search paper.""" """Agent for the self-ask-with-search paper."""
@classmethod @classmethod
@ -63,10 +63,7 @@ class SelfAskWithSearchPlanner(Planner):
return "Are follow up questions needed here:" return "Are follow up questions needed here:"
SelfAskWithSearchAgent = SelfAskWithSearchPlanner class SelfAskWithSearchChain(AgentWithTools):
class SelfAskWithSearchChain(Agent):
"""Chain that does self ask with search. """Chain that does self ask with search.
Example: Example:
@ -80,5 +77,5 @@ class SelfAskWithSearchChain(Agent):
def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any): def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
"""Initialize with just an LLM and a search chain.""" """Initialize with just an LLM and a search chain."""
search_tool = Tool(name="Intermediate Answer", func=search_chain.run) search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
planner = SelfAskWithSearchPlanner.from_llm_and_tools(llm, [search_tool]) agent = SelfAskWithSearchAgent.from_llm_and_tools(llm, [search_tool])
super().__init__(planner=planner, tools=[search_tool], **kwargs) super().__init__(agent=agent, tools=[search_tool], **kwargs)

@ -1,75 +0,0 @@
"""Test input manipulating logic."""
import sys
from io import StringIO
from langchain.agents.input import ChainedInput
from langchain.input import get_color_mapping
def test_chained_input_not_verbose() -> None:
"""Test chained input logic."""
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input = ChainedInput("foo")
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == ""
assert chained_input.input == "foo"
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input.add_observation("bar", "1", "2", None)
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == ""
assert chained_input.input == "foo\n1bar\n2"
def test_chained_input_verbose() -> None:
"""Test chained input logic, making sure verbose doesn't mess it up."""
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input = ChainedInput("foo", verbose=True)
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == "foo"
assert chained_input.input == "foo"
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input.add_observation("bar", "1", "2", None)
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == "\n1bar\n2"
assert chained_input.input == "foo\n1bar\n2"
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input.add_observation("baz", "3", "4", "blue")
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == "\n3\x1b[36;1m\x1b[1;3mbaz\x1b[0m\n4"
assert chained_input.input == "foo\n1bar\n2\n3baz\n4"
def test_get_color_mapping() -> None:
"""Test getting of color mapping."""
# Test on few inputs.
items = ["foo", "bar"]
output = get_color_mapping(items)
expected_output = {"foo": "blue", "bar": "yellow"}
assert output == expected_output
# Test on a lot of inputs.
items = [f"foo-{i}" for i in range(20)]
output = get_color_mapping(items)
assert len(output) == 20
def test_get_color_mapping_excluded_colors() -> None:
"""Test getting of color mapping with excluded colors."""
items = ["foo", "bar"]
output = get_color_mapping(items, excluded_colors=["blue"])
expected_output = {"foo": "yellow", "bar": "pink"}
assert output == expected_output
Loading…
Cancel
Save