harrison/agent-refactor
Harrison Chase 1 year ago
parent ac208f85c8
commit f646c94bc1

@ -1,5 +1,5 @@
"""Routing chains."""
from langchain.agents.agent import Agent
from langchain.agents.agent import AgentWithTools
from langchain.agents.loading import initialize_agent
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
@ -10,7 +10,7 @@ __all__ = [
"MRKLChain",
"SelfAskWithSearchChain",
"ReActChain",
"Agent",
"AgentWithTools",
"Tool",
"initialize_agent",
"ZeroShotAgent",

@ -1,13 +1,12 @@
"""Chain that takes in an input and produces an action and action input."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Union
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, root_validator
import langchain
from langchain.agents.input import ChainedInput
from langchain.agents.tools import Tool
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
@ -17,7 +16,7 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.schema import AgentAction, AgentFinish
class Planner(BaseModel):
class Agent(BaseModel):
"""Class responsible for calling the language model and deciding the action.
This is driven by an LLMChain. The prompt in the LLMChain MUST include
@ -72,6 +71,7 @@ class Planner(BaseModel):
return AgentAction(tool, tool_input, full_output)
def prepare_for_new_call(self) -> None:
"""Prepare the agent for new call, if needed."""
pass
@property
@ -107,8 +107,8 @@ class Planner(BaseModel):
def llm_prefix(self) -> str:
"""Prefix to append the LLM call with."""
@abstractmethod
@classmethod
@abstractmethod
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
"""Create a prompt for this class."""
@ -118,16 +118,17 @@ class Planner(BaseModel):
pass
@classmethod
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Planner:
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
return cls(llm_chain=llm_chain)
class Agent(Chain, BaseModel):
class AgentWithTools(Chain, BaseModel):
"""Consists of an agent using tools."""
planner: Planner
agent: Agent
tools: List[Tool]
return_intermediate_steps: bool = False
@ -137,7 +138,7 @@ class Agent(Chain, BaseModel):
:meta private:
"""
return self.planner.input_keys
return self.agent.input_keys
@property
def output_keys(self) -> List[str]:
@ -146,26 +147,25 @@ class Agent(Chain, BaseModel):
:meta private:
"""
if self.return_intermediate_steps:
return self.planner.return_values + ["intermediate_steps"]
return self.agent.return_values + ["intermediate_steps"]
else:
return self.planner.return_values
return self.agent.return_values
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Do any preparation necessary when receiving a new input.
self.planner.prepare_for_new_call()
self.agent.prepare_for_new_call()
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"]
)
planner_inputs = inputs.copy()
intermediate_steps: List[Tuple[AgentAction, str]] = []
# We now enter the agent loop (until it returns something).
while True:
# Call the LLM to see what to do.
output = self.planner.plan(intermediate_steps, **planner_inputs)
output = self.agent.plan(intermediate_steps, **inputs)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.verbose:

@ -1,61 +0,0 @@
"""Input manager for agents."""
from typing import List, Optional
import langchain
from langchain.schema import AgentAction
class ChainedInput:
"""Class for working with input that is the result of chains."""
def __init__(self, text: str, verbose: bool = False):
"""Initialize with verbose flag and initial text."""
self._verbose = verbose
if self._verbose:
langchain.logger.log_agent_start(text)
self._input = text
self._intermediate_actions: List[AgentAction] = []
self._intermediate_observations: List[str] = []
@property
def intermediate_steps(self) -> List:
"""Return intermediate steps the agent took."""
steps = []
for i, action in enumerate(self._intermediate_actions):
step = {
"log": action.log,
"tool": action.tool,
"tool_input": action.tool_input,
"observation": self._intermediate_observations[i],
}
steps.append(step)
return steps
def add_action(self, action: AgentAction, color: Optional[str] = None) -> None:
"""Add text to input, print if in verbose mode."""
self._input += action.log
self._intermediate_actions.append(action)
def add_observation(
self,
observation: str,
observation_prefix: str,
llm_prefix: str,
color: Optional[str],
) -> None:
"""Add observation to input, print if in verbose mode."""
if self._verbose:
langchain.logger.log_agent_observation(
observation,
color=color,
observation_prefix=observation_prefix,
llm_prefix=llm_prefix,
)
self._input += f"\n{observation_prefix}{observation}\n{llm_prefix}"
self._intermediate_observations.append(observation)
@property
def input(self) -> str:
"""Return the accumulated input."""
return self._input

@ -1,17 +1,17 @@
"""Load agent."""
from typing import Any, List
from langchain.agents.agent import Agent, Planner
from langchain.agents.mrkl.base import ZeroShotPlanner
from langchain.agents.react.base import ReActDocstorePlanner
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchPlanner
from langchain.agents.agent import AgentWithTools
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool
from langchain.llms.base import LLM
AGENT_TO_CLASS = {
"zero-shot-react-description": ZeroShotPlanner,
"react-docstore": ReActDocstorePlanner,
"self-ask-with-search": SelfAskWithSearchPlanner,
"zero-shot-react-description": ZeroShotAgent,
"react-docstore": ReActDocstoreAgent,
"self-ask-with-search": SelfAskWithSearchAgent,
}
@ -20,7 +20,7 @@ def initialize_agent(
llm: LLM,
agent: str = "zero-shot-react-description",
**kwargs: Any,
) -> Agent:
) -> AgentWithTools:
"""Load agent given tools and LLM.
Args:
@ -39,5 +39,5 @@ def initialize_agent(
f"Valid types are: {AGENT_TO_CLASS.keys()}."
)
agent_cls = AGENT_TO_CLASS[agent]
planner = agent_cls.from_llm_and_tools(llm, tools, **kwargs)
return Agent(planner=planner, tools=tools)
agent_obj = agent_cls.from_llm_and_tools(llm, tools)
return AgentWithTools(agent=agent_obj, tools=tools, **kwargs)

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
from langchain.agents.agent import Agent, Planner
from langchain.agents.agent import Agent, AgentWithTools
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.llms.base import LLM
@ -47,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
return action, action_input.strip(" ").strip('"')
class ZeroShotPlanner(Planner):
class ZeroShotAgent(Agent):
"""Agent for the MRKL chain."""
@property
@ -101,10 +101,7 @@ class ZeroShotPlanner(Planner):
return get_action_and_input(text)
ZeroShotAgent = ZeroShotPlanner
class MRKLChain(Agent):
class MRKLChain(AgentWithTools):
"""Chain that implements the MRKL system.
Example:
@ -119,7 +116,9 @@ class MRKLChain(Agent):
"""
@classmethod
def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent:
def from_chains(
cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any
) -> AgentWithTools:
"""User friendly way to initialize the MRKL chain.
This is intended to be an easy way to get up and running with the
@ -159,5 +158,5 @@ class MRKLChain(Agent):
Tool(name=c.action_name, func=c.action, description=c.action_description)
for c in chains
]
planner = ZeroShotPlanner.from_llm_and_tools(llm, tools)
return cls(planner=planner, tools=tools, **kwargs)
agent = ZeroShotAgent.from_llm_and_tools(llm, tools)
return cls(agent=agent, tools=tools, **kwargs)

@ -1,21 +1,20 @@
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
import re
from typing import Any, ClassVar, List, Optional, Tuple
from typing import Any, List, Optional, Tuple
from pydantic import BaseModel
from langchain.agents.agent import Agent, Planner
from langchain.agents.agent import Agent, AgentWithTools
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool
from langchain.chains.llm import LLMChain
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
class ReActDocstorePlanner(Planner, BaseModel):
class ReActDocstoreAgent(Agent, BaseModel):
"""Agent for the ReAct chin."""
@classmethod
@ -75,9 +74,6 @@ class ReActDocstorePlanner(Planner, BaseModel):
return f"Thought {self.i}:"
ReActDocstoreAgent = ReActDocstorePlanner
class DocstoreExplorer:
"""Class to assist with exploration of a document store."""
@ -103,7 +99,7 @@ class DocstoreExplorer:
return self.document.lookup(term)
class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel):
class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel):
"""Agent for the ReAct TextWorld chain."""
@classmethod
@ -120,10 +116,7 @@ class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel):
raise ValueError(f"Tool name should be Play, got {tool_names}")
ReActTextWorldAgent = ReActTextWorldPlanner
class ReActChain(Agent):
class ReActChain(AgentWithTools):
"""Chain that implements the ReAct paper.
Example:
@ -140,5 +133,5 @@ class ReActChain(Agent):
Tool(name="Search", func=docstore_explorer.search),
Tool(name="Lookup", func=docstore_explorer.lookup),
]
planner = ReActDocstorePlanner.from_llm_and_tools(llm, tools)
super().__init__(planner=planner, tools=tools, **kwargs)
agent = ReActDocstoreAgent.from_llm_and_tools(llm, tools)
super().__init__(agent=agent, tools=tools, **kwargs)

@ -1,7 +1,7 @@
"""Chain that does self ask with search."""
from typing import Any, List, Optional, Tuple
from langchain.agents.agent import Agent, Planner
from langchain.agents.agent import Agent, AgentWithTools
from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.llms.base import LLM
@ -9,7 +9,7 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.serpapi import SerpAPIWrapper
class SelfAskWithSearchPlanner(Planner):
class SelfAskWithSearchAgent(Agent):
"""Agent for the self-ask-with-search paper."""
@classmethod
@ -63,10 +63,7 @@ class SelfAskWithSearchPlanner(Planner):
return "Are follow up questions needed here:"
SelfAskWithSearchAgent = SelfAskWithSearchPlanner
class SelfAskWithSearchChain(Agent):
class SelfAskWithSearchChain(AgentWithTools):
"""Chain that does self ask with search.
Example:
@ -80,5 +77,5 @@ class SelfAskWithSearchChain(Agent):
def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
"""Initialize with just an LLM and a search chain."""
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
planner = SelfAskWithSearchPlanner.from_llm_and_tools(llm, [search_tool])
super().__init__(planner=planner, tools=[search_tool], **kwargs)
agent = SelfAskWithSearchAgent.from_llm_and_tools(llm, [search_tool])
super().__init__(agent=agent, tools=[search_tool], **kwargs)

@ -1,75 +0,0 @@
"""Test input manipulating logic."""
import sys
from io import StringIO
from langchain.agents.input import ChainedInput
from langchain.input import get_color_mapping
def test_chained_input_not_verbose() -> None:
"""Test chained input logic."""
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input = ChainedInput("foo")
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == ""
assert chained_input.input == "foo"
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input.add_observation("bar", "1", "2", None)
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == ""
assert chained_input.input == "foo\n1bar\n2"
def test_chained_input_verbose() -> None:
"""Test chained input logic, making sure verbose doesn't mess it up."""
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input = ChainedInput("foo", verbose=True)
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == "foo"
assert chained_input.input == "foo"
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input.add_observation("bar", "1", "2", None)
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == "\n1bar\n2"
assert chained_input.input == "foo\n1bar\n2"
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
chained_input.add_observation("baz", "3", "4", "blue")
sys.stdout = old_stdout
output = mystdout.getvalue()
assert output == "\n3\x1b[36;1m\x1b[1;3mbaz\x1b[0m\n4"
assert chained_input.input == "foo\n1bar\n2\n3baz\n4"
def test_get_color_mapping() -> None:
"""Test getting of color mapping."""
# Test on few inputs.
items = ["foo", "bar"]
output = get_color_mapping(items)
expected_output = {"foo": "blue", "bar": "yellow"}
assert output == expected_output
# Test on a lot of inputs.
items = [f"foo-{i}" for i in range(20)]
output = get_color_mapping(items)
assert len(output) == 20
def test_get_color_mapping_excluded_colors() -> None:
"""Test getting of color mapping with excluded colors."""
items = ["foo", "bar"]
output = get_color_mapping(items, excluded_colors=["blue"])
expected_output = {"foo": "yellow", "bar": "pink"}
assert output == expected_output
Loading…
Cancel
Save