From a9e637b8f5a8873a8316ef6bf809591dc1f57c09 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 5 Apr 2023 15:28:48 -0700 Subject: [PATCH] rfc: multi action agent (#2362) --- .../agents/custom_multi_action_agent.ipynb | 217 ++++++++++++++ langchain/agents/__init__.py | 2 + langchain/agents/agent.py | 275 ++++++++++++++---- 3 files changed, 435 insertions(+), 59 deletions(-) create mode 100644 docs/modules/agents/agents/custom_multi_action_agent.ipynb diff --git a/docs/modules/agents/agents/custom_multi_action_agent.ipynb b/docs/modules/agents/agents/custom_multi_action_agent.ipynb new file mode 100644 index 00000000..ef6e9eda --- /dev/null +++ b/docs/modules/agents/agents/custom_multi_action_agent.ipynb @@ -0,0 +1,217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ba5f8741", + "metadata": {}, + "source": [ + "# Custom MultiAction Agent\n", + "\n", + "This notebook goes through how to create your own custom agent.\n", + "\n", + "An agent consists of three parts:\n", + " \n", + " - Tools: The tools the agent has available to use.\n", + " - The agent class itself: this decides which action to take.\n", + " \n", + " \n", + "In this notebook we walk through how to create a custom agent that predicts/takes multiple steps at a time." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9af9734e", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents import Tool, AgentExecutor, BaseMultiActionAgent\n", + "from langchain import OpenAI, SerpAPIWrapper" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "d7c4ebdc", + "metadata": {}, + "outputs": [], + "source": [ + "def random_word(query: str) -> str:\n", + " print(\"\\nNow I'm doing this!\")\n", + " return \"foo\"" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "becda2a1", + "metadata": {}, + "outputs": [], + "source": [ + "search = SerpAPIWrapper()\n", + "tools = [\n", + " Tool(\n", + " name = \"Search\",\n", + " func=search.run,\n", + " description=\"useful for when you need to answer questions about current events\"\n", + " ),\n", + " Tool(\n", + " name = \"RandomWord\",\n", + " func=random_word,\n", + " description=\"call this to get a random word.\"\n", + " \n", + " )\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a33e2f7e", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List, Tuple, Any, Union\n", + "from langchain.schema import AgentAction, AgentFinish\n", + "\n", + "class FakeAgent(BaseMultiActionAgent):\n", + " \"\"\"Fake Custom Agent.\"\"\"\n", + " \n", + " @property\n", + " def input_keys(self):\n", + " return [\"input\"]\n", + " \n", + " def plan(\n", + " self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n", + " ) -> Union[List[AgentAction], AgentFinish]:\n", + " \"\"\"Given input, decided what to do.\n", + "\n", + " Args:\n", + " intermediate_steps: Steps the LLM has taken to date,\n", + " along with observations\n", + " **kwargs: User inputs.\n", + "\n", + " Returns:\n", + " Action specifying what tool to use.\n", + " \"\"\"\n", + " if len(intermediate_steps) == 0:\n", + " return [\n", + " AgentAction(tool=\"Search\", tool_input=\"foo\", log=\"\"),\n", + " AgentAction(tool=\"RandomWord\", tool_input=\"foo\", log=\"\"),\n", + " ]\n", + " else:\n", + " return AgentFinish(return_values={\"output\": \"bar\"}, log=\"\")\n", + "\n", + " async def aplan(\n", + " self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n", + " ) -> Union[List[AgentAction], AgentFinish]:\n", + " \"\"\"Given input, decided what to do.\n", + "\n", + " Args:\n", + " intermediate_steps: Steps the LLM has taken to date,\n", + " along with observations\n", + " **kwargs: User inputs.\n", + "\n", + " Returns:\n", + " Action specifying what tool to use.\n", + " \"\"\"\n", + " if len(intermediate_steps) == 0:\n", + " return [\n", + " AgentAction(tool=\"Search\", tool_input=\"foo\", log=\"\"),\n", + " AgentAction(tool=\"RandomWord\", tool_input=\"foo\", log=\"\"),\n", + " ]\n", + " else:\n", + " return AgentFinish(return_values={\"output\": \"bar\"}, log=\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "655d72f6", + "metadata": {}, + "outputs": [], + "source": [ + "agent = FakeAgent()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "490604e9", + "metadata": {}, + "outputs": [], + "source": [ + "agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "653b1617", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[36;1m\u001b[1;3mFoo Fighters is an American rock band formed in Seattle in 1994. Foo Fighters was initially formed as a one-man project by former Nirvana drummer Dave Grohl. Following the success of the 1995 eponymous debut album, Grohl recruited a band consisting of Nate Mendel, William Goldsmith, and Pat Smear.\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "Now I'm doing this!\n", + "\u001b[33;1m\u001b[1;3mfoo\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'bar'" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.run(\"How many people live in canada as of 2023?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adefb4c2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + }, + "vscode": { + "interpreter": { + "hash": "18784188d7ecd866c0586ac068b02361a6896dc3a29b64f5cc957f09c590acef" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/agents/__init__.py b/langchain/agents/__init__.py index e5f12e07..bb5b71e4 100644 --- a/langchain/agents/__init__.py +++ b/langchain/agents/__init__.py @@ -3,6 +3,7 @@ from langchain.agents.agent import ( Agent, AgentExecutor, AgentOutputParser, + BaseMultiActionAgent, BaseSingleActionAgent, LLMSingleActionAgent, ) @@ -53,4 +54,5 @@ __all__ = [ "AgentOutputParser", "BaseSingleActionAgent", "AgentType", + "BaseMultiActionAgent", ] diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index aed79dba..e6c5ce58 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -142,6 +142,118 @@ class BaseSingleActionAgent(BaseModel): return {} +class BaseMultiActionAgent(BaseModel): + """Base Agent class.""" + + @property + def return_values(self) -> List[str]: + """Return values of the agent.""" + return ["output"] + + def get_allowed_tools(self) -> Optional[List[str]]: + return None + + @abstractmethod + def plan( + self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + ) -> Union[List[AgentAction], AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + **kwargs: User inputs. + + Returns: + Actions specifying what tool to use. + """ + + @abstractmethod + async def aplan( + self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + ) -> Union[List[AgentAction], AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + **kwargs: User inputs. + + Returns: + Actions specifying what tool to use. + """ + + @property + @abstractmethod + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + + def return_stopped_response( + self, + early_stopping_method: str, + intermediate_steps: List[Tuple[AgentAction, str]], + **kwargs: Any, + ) -> AgentFinish: + """Return response when agent has been stopped due to max iterations.""" + if early_stopping_method == "force": + # `force` just returns a constant string + return AgentFinish({"output": "Agent stopped due to max iterations."}, "") + else: + raise ValueError( + f"Got unsupported early_stopping_method `{early_stopping_method}`" + ) + + @property + def _agent_type(self) -> str: + """Return Identifier of agent type.""" + raise NotImplementedError + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of agent.""" + _dict = super().dict() + _dict["_type"] = self._agent_type + return _dict + + def save(self, file_path: Union[Path, str]) -> None: + """Save the agent. + + Args: + file_path: Path to file to save the agent to. + + Example: + .. code-block:: python + + # If working with agent executor + agent.agent.save(file_path="path/agent.yaml") + """ + # Convert file to Path object. + if isinstance(file_path, str): + save_path = Path(file_path) + else: + save_path = file_path + + directory_path = save_path.parent + directory_path.mkdir(parents=True, exist_ok=True) + + # Fetch dictionary to save + agent_dict = self.dict() + + if save_path.suffix == ".json": + with open(file_path, "w") as f: + json.dump(agent_dict, f, indent=4) + elif save_path.suffix == ".yaml": + with open(file_path, "w") as f: + yaml.dump(agent_dict, f, default_flow_style=False) + else: + raise ValueError(f"{save_path} must be json or yaml") + + def tool_run_logging_kwargs(self) -> Dict: + return {} + + class AgentOutputParser(BaseOutputParser): @abstractmethod def parse(self, text: str) -> Union[AgentAction, AgentFinish]: @@ -439,7 +551,7 @@ class Agent(BaseSingleActionAgent): class AgentExecutor(Chain, BaseModel): """Consists of an agent using tools.""" - agent: BaseSingleActionAgent + agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] tools: Sequence[BaseTool] return_intermediate_steps: bool = False max_iterations: Optional[int] = 15 @@ -448,7 +560,7 @@ class AgentExecutor(Chain, BaseModel): @classmethod def from_agent_and_tools( cls, - agent: BaseSingleActionAgent, + agent: Union[BaseSingleActionAgent, BaseMultiActionAgent], tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, @@ -472,6 +584,20 @@ class AgentExecutor(Chain, BaseModel): ) return values + @root_validator() + def validate_return_direct_tool(cls, values: Dict) -> Dict: + """Validate that tools are compatible with agent.""" + agent = values["agent"] + tools = values["tools"] + if isinstance(agent, BaseMultiActionAgent): + for tool in tools: + if tool.return_direct: + raise ValueError( + "Tools that have `return_direct=True` are not allowed " + "in multi-action agents" + ) + return values + def save(self, file_path: Union[Path, str]) -> None: """Raise error - saving not supported for Agent Executors.""" raise ValueError( @@ -544,7 +670,7 @@ class AgentExecutor(Chain, BaseModel): color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]], - ) -> Union[AgentFinish, Tuple[AgentAction, str]]: + ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: """Take a single step in the thought-action-observation loop. Override this to take control of how the agent makes and acts on choices. @@ -554,27 +680,41 @@ class AgentExecutor(Chain, BaseModel): # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): return output - self.callback_manager.on_agent_action( - output, verbose=self.verbose, color="green" - ) - # Otherwise we lookup the tool - if output.tool in name_to_tool_map: - tool = name_to_tool_map[output.tool] - return_direct = tool.return_direct - color = color_mapping[output.tool] - tool_run_kwargs = self.agent.tool_run_logging_kwargs() - if return_direct: - tool_run_kwargs["llm_prefix"] = "" - # We then call the tool on the tool input to get an observation - observation = tool.run( - output.tool_input, verbose=self.verbose, color=color, **tool_run_kwargs - ) + actions: List[AgentAction] + if isinstance(output, AgentAction): + actions = [output] else: - tool_run_kwargs = self.agent.tool_run_logging_kwargs() - observation = InvalidTool().run( - output.tool, verbose=self.verbose, color=None, **tool_run_kwargs + actions = output + result = [] + for agent_action in actions: + self.callback_manager.on_agent_action( + agent_action, verbose=self.verbose, color="green" ) - return output, observation + # Otherwise we lookup the tool + if agent_action.tool in name_to_tool_map: + tool = name_to_tool_map[agent_action.tool] + return_direct = tool.return_direct + color = color_mapping[agent_action.tool] + tool_run_kwargs = self.agent.tool_run_logging_kwargs() + if return_direct: + tool_run_kwargs["llm_prefix"] = "" + # We then call the tool on the tool input to get an observation + observation = tool.run( + agent_action.tool_input, + verbose=self.verbose, + color=color, + **tool_run_kwargs, + ) + else: + tool_run_kwargs = self.agent.tool_run_logging_kwargs() + observation = InvalidTool().run( + agent_action.tool, + verbose=self.verbose, + color=None, + **tool_run_kwargs, + ) + result.append((agent_action, observation)) + return result async def _atake_next_step( self, @@ -582,7 +722,7 @@ class AgentExecutor(Chain, BaseModel): color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]], - ) -> Union[AgentFinish, Tuple[AgentAction, str]]: + ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: """Take a single step in the thought-action-observation loop. Override this to take control of how the agent makes and acts on choices. @@ -592,34 +732,47 @@ class AgentExecutor(Chain, BaseModel): # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): return output - if self.callback_manager.is_async: - await self.callback_manager.on_agent_action( - output, verbose=self.verbose, color="green" - ) + actions: List[AgentAction] + if isinstance(output, AgentAction): + actions = [output] else: - self.callback_manager.on_agent_action( - output, verbose=self.verbose, color="green" - ) + actions = output + result = [] + for agent_action in actions: + if self.callback_manager.is_async: + await self.callback_manager.on_agent_action( + agent_action, verbose=self.verbose, color="green" + ) + else: + self.callback_manager.on_agent_action( + agent_action, verbose=self.verbose, color="green" + ) + # Otherwise we lookup the tool + if agent_action.tool in name_to_tool_map: + tool = name_to_tool_map[agent_action.tool] + return_direct = tool.return_direct + color = color_mapping[agent_action.tool] + tool_run_kwargs = self.agent.tool_run_logging_kwargs() + if return_direct: + tool_run_kwargs["llm_prefix"] = "" + # We then call the tool on the tool input to get an observation + observation = await tool.arun( + agent_action.tool_input, + verbose=self.verbose, + color=color, + **tool_run_kwargs, + ) + else: + tool_run_kwargs = self.agent.tool_run_logging_kwargs() + observation = await InvalidTool().arun( + agent_action.tool, + verbose=self.verbose, + color=None, + **tool_run_kwargs, + ) + result.append((agent_action, observation)) - # Otherwise we lookup the tool - if output.tool in name_to_tool_map: - tool = name_to_tool_map[output.tool] - return_direct = tool.return_direct - color = color_mapping[output.tool] - tool_run_kwargs = self.agent.tool_run_logging_kwargs() - if return_direct: - tool_run_kwargs["llm_prefix"] = "" - # We then call the tool on the tool input to get an observation - observation = await tool.arun( - output.tool_input, verbose=self.verbose, color=color, **tool_run_kwargs - ) - else: - tool_run_kwargs = self.agent.tool_run_logging_kwargs() - observation = await InvalidTool().arun( - output.tool, verbose=self.verbose, color=None, **tool_run_kwargs - ) - return_direct = False - return output, observation + return result def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: """Run text through and get agent response.""" @@ -640,11 +793,13 @@ class AgentExecutor(Chain, BaseModel): if isinstance(next_step_output, AgentFinish): return self._return(next_step_output, intermediate_steps) - intermediate_steps.append(next_step_output) - # See if tool should return directly - tool_return = self._get_tool_return(next_step_output) - if tool_return is not None: - return self._return(tool_return, intermediate_steps) + intermediate_steps.extend(next_step_output) + if len(next_step_output) == 1: + next_step_action = next_step_output[0] + # See if tool should return directly + tool_return = self._get_tool_return(next_step_action) + if tool_return is not None: + return self._return(tool_return, intermediate_steps) iterations += 1 output = self.agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs @@ -670,11 +825,13 @@ class AgentExecutor(Chain, BaseModel): if isinstance(next_step_output, AgentFinish): return await self._areturn(next_step_output, intermediate_steps) - intermediate_steps.append(next_step_output) - # See if tool should return directly - tool_return = self._get_tool_return(next_step_output) - if tool_return is not None: - return await self._areturn(tool_return, intermediate_steps) + intermediate_steps.extend(next_step_output) + if len(next_step_output) == 1: + next_step_action = next_step_output[0] + # See if tool should return directly + tool_return = self._get_tool_return(next_step_action) + if tool_return is not None: + return await self._areturn(tool_return, intermediate_steps) iterations += 1 output = self.agent.return_stopped_response(