From f646c94bc1c4485e1dc36dfffb7282e60ecd3a99 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 18 Dec 2022 11:08:14 -0500 Subject: [PATCH] cr --- langchain/agents/__init__.py | 4 +- langchain/agents/agent.py | 28 +++---- langchain/agents/input.py | 61 --------------- langchain/agents/loading.py | 20 ++--- langchain/agents/mrkl/base.py | 17 ++--- langchain/agents/react/base.py | 21 ++---- langchain/agents/self_ask_with_search/base.py | 13 ++-- tests/unit_tests/test_input.py | 75 ------------------- 8 files changed, 46 insertions(+), 193 deletions(-) delete mode 100644 langchain/agents/input.py delete mode 100644 tests/unit_tests/test_input.py diff --git a/langchain/agents/__init__.py b/langchain/agents/__init__.py index 6c41e16b..cec78bfa 100644 --- a/langchain/agents/__init__.py +++ b/langchain/agents/__init__.py @@ -1,5 +1,5 @@ """Routing chains.""" -from langchain.agents.agent import Agent +from langchain.agents.agent import AgentWithTools from langchain.agents.loading import initialize_agent from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent from langchain.agents.react.base import ReActChain, ReActTextWorldAgent @@ -10,7 +10,7 @@ __all__ = [ "MRKLChain", "SelfAskWithSearchChain", "ReActChain", - "Agent", + "AgentWithTools", "Tool", "initialize_agent", "ZeroShotAgent", diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 538793b8..2de34c1e 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -1,13 +1,12 @@ """Chain that takes in an input and produces an action and action input.""" from __future__ import annotations -from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Union +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, root_validator import langchain -from langchain.agents.input import ChainedInput from langchain.agents.tools import Tool from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -17,7 +16,7 @@ from langchain.prompts.base import BasePromptTemplate from langchain.schema import AgentAction, AgentFinish -class Planner(BaseModel): +class Agent(BaseModel): """Class responsible for calling the language model and deciding the action. This is driven by an LLMChain. The prompt in the LLMChain MUST include @@ -72,6 +71,7 @@ class Planner(BaseModel): return AgentAction(tool, tool_input, full_output) def prepare_for_new_call(self) -> None: + """Prepare the agent for new call, if needed.""" pass @property @@ -107,8 +107,8 @@ class Planner(BaseModel): def llm_prefix(self) -> str: """Prefix to append the LLM call with.""" - @abstractmethod @classmethod + @abstractmethod def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate: """Create a prompt for this class.""" @@ -118,16 +118,17 @@ class Planner(BaseModel): pass @classmethod - def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Planner: + def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Agent: """Construct an agent from an LLM and tools.""" cls._validate_tools(tools) llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) return cls(llm_chain=llm_chain) -class Agent(Chain, BaseModel): +class AgentWithTools(Chain, BaseModel): + """Consists of an agent using tools.""" - planner: Planner + agent: Agent tools: List[Tool] return_intermediate_steps: bool = False @@ -137,7 +138,7 @@ class Agent(Chain, BaseModel): :meta private: """ - return self.planner.input_keys + return self.agent.input_keys @property def output_keys(self) -> List[str]: @@ -146,26 +147,25 @@ class Agent(Chain, BaseModel): :meta private: """ if self.return_intermediate_steps: - return self.planner.return_values + ["intermediate_steps"] + return self.agent.return_values + ["intermediate_steps"] else: - return self.planner.return_values + return self.agent.return_values def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: """Run text through and get agent response.""" # Do any preparation necessary when receiving a new input. - self.planner.prepare_for_new_call() + self.agent.prepare_for_new_call() # Construct a mapping of tool name to tool for easy lookup name_to_tool_map = {tool.name: tool.func for tool in self.tools} # We construct a mapping from each tool to a color, used for logging. color_mapping = get_color_mapping( [tool.name for tool in self.tools], excluded_colors=["green"] ) - planner_inputs = inputs.copy() intermediate_steps: List[Tuple[AgentAction, str]] = [] # We now enter the agent loop (until it returns something). while True: # Call the LLM to see what to do. - output = self.planner.plan(intermediate_steps, **planner_inputs) + output = self.agent.plan(intermediate_steps, **inputs) # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): if self.verbose: diff --git a/langchain/agents/input.py b/langchain/agents/input.py deleted file mode 100644 index 24a2e41d..00000000 --- a/langchain/agents/input.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Input manager for agents.""" -from typing import List, Optional - -import langchain -from langchain.schema import AgentAction - - -class ChainedInput: - """Class for working with input that is the result of chains.""" - - def __init__(self, text: str, verbose: bool = False): - """Initialize with verbose flag and initial text.""" - self._verbose = verbose - if self._verbose: - langchain.logger.log_agent_start(text) - self._input = text - self._intermediate_actions: List[AgentAction] = [] - self._intermediate_observations: List[str] = [] - - @property - def intermediate_steps(self) -> List: - """Return intermediate steps the agent took.""" - steps = [] - for i, action in enumerate(self._intermediate_actions): - step = { - "log": action.log, - "tool": action.tool, - "tool_input": action.tool_input, - "observation": self._intermediate_observations[i], - } - steps.append(step) - return steps - - def add_action(self, action: AgentAction, color: Optional[str] = None) -> None: - """Add text to input, print if in verbose mode.""" - - self._input += action.log - self._intermediate_actions.append(action) - - def add_observation( - self, - observation: str, - observation_prefix: str, - llm_prefix: str, - color: Optional[str], - ) -> None: - """Add observation to input, print if in verbose mode.""" - if self._verbose: - langchain.logger.log_agent_observation( - observation, - color=color, - observation_prefix=observation_prefix, - llm_prefix=llm_prefix, - ) - self._input += f"\n{observation_prefix}{observation}\n{llm_prefix}" - self._intermediate_observations.append(observation) - - @property - def input(self) -> str: - """Return the accumulated input.""" - return self._input diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index 3d1f9205..04c430a3 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,17 +1,17 @@ """Load agent.""" from typing import Any, List -from langchain.agents.agent import Agent, Planner -from langchain.agents.mrkl.base import ZeroShotPlanner -from langchain.agents.react.base import ReActDocstorePlanner -from langchain.agents.self_ask_with_search.base import SelfAskWithSearchPlanner +from langchain.agents.agent import AgentWithTools +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.react.base import ReActDocstoreAgent +from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool from langchain.llms.base import LLM AGENT_TO_CLASS = { - "zero-shot-react-description": ZeroShotPlanner, - "react-docstore": ReActDocstorePlanner, - "self-ask-with-search": SelfAskWithSearchPlanner, + "zero-shot-react-description": ZeroShotAgent, + "react-docstore": ReActDocstoreAgent, + "self-ask-with-search": SelfAskWithSearchAgent, } @@ -20,7 +20,7 @@ def initialize_agent( llm: LLM, agent: str = "zero-shot-react-description", **kwargs: Any, -) -> Agent: +) -> AgentWithTools: """Load agent given tools and LLM. Args: @@ -39,5 +39,5 @@ def initialize_agent( f"Valid types are: {AGENT_TO_CLASS.keys()}." ) agent_cls = AGENT_TO_CLASS[agent] - planner = agent_cls.from_llm_and_tools(llm, tools, **kwargs) - return Agent(planner=planner, tools=tools) + agent_obj = agent_cls.from_llm_and_tools(llm, tools) + return AgentWithTools(agent=agent_obj, tools=tools, **kwargs) diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 701aaf2d..4351cced 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Callable, List, NamedTuple, Optional, Tuple -from langchain.agents.agent import Agent, Planner +from langchain.agents.agent import Agent, AgentWithTools from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.tools import Tool from langchain.llms.base import LLM @@ -47,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]: return action, action_input.strip(" ").strip('"') -class ZeroShotPlanner(Planner): +class ZeroShotAgent(Agent): """Agent for the MRKL chain.""" @property @@ -101,10 +101,7 @@ class ZeroShotPlanner(Planner): return get_action_and_input(text) -ZeroShotAgent = ZeroShotPlanner - - -class MRKLChain(Agent): +class MRKLChain(AgentWithTools): """Chain that implements the MRKL system. Example: @@ -119,7 +116,9 @@ class MRKLChain(Agent): """ @classmethod - def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent: + def from_chains( + cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any + ) -> AgentWithTools: """User friendly way to initialize the MRKL chain. This is intended to be an easy way to get up and running with the @@ -159,5 +158,5 @@ class MRKLChain(Agent): Tool(name=c.action_name, func=c.action, description=c.action_description) for c in chains ] - planner = ZeroShotPlanner.from_llm_and_tools(llm, tools) - return cls(planner=planner, tools=tools, **kwargs) + agent = ZeroShotAgent.from_llm_and_tools(llm, tools) + return cls(agent=agent, tools=tools, **kwargs) diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index d30cc4fc..1e3a1acf 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -1,21 +1,20 @@ """Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf.""" import re -from typing import Any, ClassVar, List, Optional, Tuple +from typing import Any, List, Optional, Tuple from pydantic import BaseModel -from langchain.agents.agent import Agent, Planner +from langchain.agents.agent import Agent, AgentWithTools from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT from langchain.agents.react.wiki_prompt import WIKI_PROMPT from langchain.agents.tools import Tool -from langchain.chains.llm import LLMChain from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.llms.base import LLM from langchain.prompts.base import BasePromptTemplate -class ReActDocstorePlanner(Planner, BaseModel): +class ReActDocstoreAgent(Agent, BaseModel): """Agent for the ReAct chin.""" @classmethod @@ -75,9 +74,6 @@ class ReActDocstorePlanner(Planner, BaseModel): return f"Thought {self.i}:" -ReActDocstoreAgent = ReActDocstorePlanner - - class DocstoreExplorer: """Class to assist with exploration of a document store.""" @@ -103,7 +99,7 @@ class DocstoreExplorer: return self.document.lookup(term) -class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel): +class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel): """Agent for the ReAct TextWorld chain.""" @classmethod @@ -120,10 +116,7 @@ class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel): raise ValueError(f"Tool name should be Play, got {tool_names}") -ReActTextWorldAgent = ReActTextWorldPlanner - - -class ReActChain(Agent): +class ReActChain(AgentWithTools): """Chain that implements the ReAct paper. Example: @@ -140,5 +133,5 @@ class ReActChain(Agent): Tool(name="Search", func=docstore_explorer.search), Tool(name="Lookup", func=docstore_explorer.lookup), ] - planner = ReActDocstorePlanner.from_llm_and_tools(llm, tools) - super().__init__(planner=planner, tools=tools, **kwargs) + agent = ReActDocstoreAgent.from_llm_and_tools(llm, tools) + super().__init__(agent=agent, tools=tools, **kwargs) diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index 5ccbfad8..de7e17a0 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -1,7 +1,7 @@ """Chain that does self ask with search.""" from typing import Any, List, Optional, Tuple -from langchain.agents.agent import Agent, Planner +from langchain.agents.agent import Agent, AgentWithTools from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool from langchain.llms.base import LLM @@ -9,7 +9,7 @@ from langchain.prompts.base import BasePromptTemplate from langchain.serpapi import SerpAPIWrapper -class SelfAskWithSearchPlanner(Planner): +class SelfAskWithSearchAgent(Agent): """Agent for the self-ask-with-search paper.""" @classmethod @@ -63,10 +63,7 @@ class SelfAskWithSearchPlanner(Planner): return "Are follow up questions needed here:" -SelfAskWithSearchAgent = SelfAskWithSearchPlanner - - -class SelfAskWithSearchChain(Agent): +class SelfAskWithSearchChain(AgentWithTools): """Chain that does self ask with search. Example: @@ -80,5 +77,5 @@ class SelfAskWithSearchChain(Agent): def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any): """Initialize with just an LLM and a search chain.""" search_tool = Tool(name="Intermediate Answer", func=search_chain.run) - planner = SelfAskWithSearchPlanner.from_llm_and_tools(llm, [search_tool]) - super().__init__(planner=planner, tools=[search_tool], **kwargs) + agent = SelfAskWithSearchAgent.from_llm_and_tools(llm, [search_tool]) + super().__init__(agent=agent, tools=[search_tool], **kwargs) diff --git a/tests/unit_tests/test_input.py b/tests/unit_tests/test_input.py deleted file mode 100644 index 5d26535f..00000000 --- a/tests/unit_tests/test_input.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Test input manipulating logic.""" - -import sys -from io import StringIO - -from langchain.agents.input import ChainedInput -from langchain.input import get_color_mapping - - -def test_chained_input_not_verbose() -> None: - """Test chained input logic.""" - old_stdout = sys.stdout - sys.stdout = mystdout = StringIO() - chained_input = ChainedInput("foo") - sys.stdout = old_stdout - output = mystdout.getvalue() - assert output == "" - assert chained_input.input == "foo" - - old_stdout = sys.stdout - sys.stdout = mystdout = StringIO() - chained_input.add_observation("bar", "1", "2", None) - sys.stdout = old_stdout - output = mystdout.getvalue() - assert output == "" - assert chained_input.input == "foo\n1bar\n2" - - -def test_chained_input_verbose() -> None: - """Test chained input logic, making sure verbose doesn't mess it up.""" - old_stdout = sys.stdout - sys.stdout = mystdout = StringIO() - chained_input = ChainedInput("foo", verbose=True) - sys.stdout = old_stdout - output = mystdout.getvalue() - assert output == "foo" - assert chained_input.input == "foo" - - old_stdout = sys.stdout - sys.stdout = mystdout = StringIO() - chained_input.add_observation("bar", "1", "2", None) - sys.stdout = old_stdout - output = mystdout.getvalue() - assert output == "\n1bar\n2" - assert chained_input.input == "foo\n1bar\n2" - - old_stdout = sys.stdout - sys.stdout = mystdout = StringIO() - chained_input.add_observation("baz", "3", "4", "blue") - sys.stdout = old_stdout - output = mystdout.getvalue() - assert output == "\n3\x1b[36;1m\x1b[1;3mbaz\x1b[0m\n4" - assert chained_input.input == "foo\n1bar\n2\n3baz\n4" - - -def test_get_color_mapping() -> None: - """Test getting of color mapping.""" - # Test on few inputs. - items = ["foo", "bar"] - output = get_color_mapping(items) - expected_output = {"foo": "blue", "bar": "yellow"} - assert output == expected_output - - # Test on a lot of inputs. - items = [f"foo-{i}" for i in range(20)] - output = get_color_mapping(items) - assert len(output) == 20 - - -def test_get_color_mapping_excluded_colors() -> None: - """Test getting of color mapping with excluded colors.""" - items = ["foo", "bar"] - output = get_color_mapping(items, excluded_colors=["blue"]) - expected_output = {"foo": "yellow", "bar": "pink"} - assert output == expected_output