{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ba5f8741",
   "metadata": {},
   "source": [
    "# Custom multi-action agent\n",
    "\n",
    "This notebook shows how to create your own custom agent.\n",
    "\n",
    "An agent consists of two parts:\n",
    "\n",
    "- Tools: the tools the agent has available to use.\n",
    "- The agent class itself: this decides which action(s) to take.\n",
    "\n",
    "Here we build a custom agent that predicts (and takes) multiple actions at a time. The example `FakeAgent` below does not use an LLM: on its first step it requests both tools at once, and once it has observations it finishes with the fixed answer `bar`.\n",
    "\n",
    "The `Search` tool is backed by SerpAPI, so running this notebook requires a SerpAPI key (set the `SERPAPI_API_KEY` environment variable)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9af9734e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, List, Tuple, Union\n",
    "\n",
    "from langchain.agents import AgentExecutor, BaseMultiActionAgent, Tool\n",
    "from langchain_community.utilities import SerpAPIWrapper\n",
    "from langchain_core.agents import AgentAction, AgentFinish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d7c4ebdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_word(query: str) -> str:\n",
    "    \"\"\"Print a marker line and return the fixed word \"foo\".\n",
    "\n",
    "    `query` is ignored; it is accepted only so the function fits the\n",
    "    single-string signature a `Tool` calls it with.\n",
    "    \"\"\"\n",
    "    print(\"\\nNow I'm doing this!\")\n",
    "    return \"foo\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "becda2a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "search = SerpAPIWrapper()\n",
    "tools = [\n",
    "    Tool(\n",
    "        name=\"Search\",\n",
    "        func=search.run,\n",
    "        description=\"useful for when you need to answer questions about current events\",\n",
    "    ),\n",
    "    Tool(\n",
    "        name=\"RandomWord\",\n",
    "        func=random_word,\n",
    "        description=\"call this to get a random word.\",\n",
    "    ),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a33e2f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FakeAgent(BaseMultiActionAgent):\n",
    "    \"\"\"Fake custom agent that requests several tool calls in one step.\n",
    "\n",
    "    On the first step it asks for both the `Search` and `RandomWord` tools,\n",
    "    passing each the user's input. Once any observations exist it finishes\n",
    "    with the fixed output \"bar\". No LLM is involved.\n",
    "    \"\"\"\n",
    "\n",
    "    @property\n",
    "    def input_keys(self) -> List[str]:\n",
    "        \"\"\"Input keys this agent reads from the caller.\"\"\"\n",
    "        return [\"input\"]\n",
    "\n",
    "    def plan(\n",
    "        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
    "    ) -> Union[List[AgentAction], AgentFinish]:\n",
    "        \"\"\"Given input, decide what to do.\n",
    "\n",
    "        Args:\n",
    "            intermediate_steps: Steps the agent has taken to date,\n",
    "                along with observations.\n",
    "            **kwargs: User inputs; must include the \"input\" key.\n",
    "\n",
    "        Returns:\n",
    "            On the first step, a list of actions to run (one per tool);\n",
    "            afterwards, an AgentFinish whose output is \"bar\".\n",
    "        \"\"\"\n",
    "        if not intermediate_steps:\n",
    "            # First step: request both tools at once -- returning a list of\n",
    "            # actions is what makes this a multi-action agent.\n",
    "            return [\n",
    "                AgentAction(tool=\"Search\", tool_input=kwargs[\"input\"], log=\"\"),\n",
    "                AgentAction(tool=\"RandomWord\", tool_input=kwargs[\"input\"], log=\"\"),\n",
    "            ]\n",
    "        return AgentFinish(return_values={\"output\": \"bar\"}, log=\"\")\n",
    "\n",
    "    async def aplan(\n",
    "        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
    "    ) -> Union[List[AgentAction], AgentFinish]:\n",
    "        \"\"\"Async version of `plan`.\n",
    "\n",
    "        Planning does no I/O here, so this delegates to the sync\n",
    "        implementation rather than duplicating its logic.\n",
    "        \"\"\"\n",
    "        return self.plan(intermediate_steps, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "655d72f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = FakeAgent()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "490604e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_executor = AgentExecutor.from_agent_and_tools(\n",
    "    agent=agent, tools=tools, verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "653b1617",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[36;1m\u001b[1;3mThe current population of Canada is 38,669,152 as of Monday, April 24, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
      "Now I'm doing this!\n",
      "\u001b[33;1m\u001b[1;3mfoo\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'bar'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.invoke(\n",
    "    {\"input\": \"How many people live in canada as of 2023?\"}\n",
    ")[\"output\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "vscode": {
   "interpreter": {
    "hash": "18784188d7ecd866c0586ac068b02361a6896dc3a29b64f5cc957f09c590acef"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}