diff --git a/docs/modules/agents/tools/custom_tools.ipynb b/docs/modules/agents/tools/custom_tools.ipynb index f69e2224..4f57597e 100644 --- a/docs/modules/agents/tools/custom_tools.ipynb +++ b/docs/modules/agents/tools/custom_tools.ipynb @@ -9,28 +9,29 @@ "\n", "When constructing your own agent, you will need to provide it with a list of Tools that it can use. Besides the actual function that is called, the Tool consists of several components:\n", "\n", - "- name (str), is required\n", - "- description (str), is optional\n", + "- name (str), is required and must be unique within a set of tools provided to an agent\n", + "- description (str), is optional but recommended, as it is used by an agent to determine tool use\n", "- return_direct (bool), defaults to False\n", "\n", - "The function that should be called when the tool is selected should take as input a single string and return a single string.\n", + "The function that should be called when the tool is selected should return a single string.\n", "\n", "There are two ways to define a tool, we will cover both in the example below." 
] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "1aaba18c", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# Import things that are needed generically\n", - "from langchain.agents import initialize_agent, Tool\n", - "from langchain.agents import AgentType\n", - "from langchain.tools import BaseTool\n", - "from langchain.llms import OpenAI\n", - "from langchain import LLMMathChain, SerpAPIWrapper" + "from langchain import LLMMathChain, SerpAPIWrapper\n", + "from langchain.agents import AgentType, Tool, initialize_agent, tool\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.tools import BaseTool" ] }, { @@ -43,12 +44,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "36ed392e", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "llm = OpenAI(temperature=0)" + "llm = ChatOpenAI(temperature=0)" ] }, { @@ -74,7 +77,9 @@ "cell_type": "code", "execution_count": 3, "id": "56ff7670", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# Load the tool configs that are needed.\n", @@ -98,7 +103,9 @@ "cell_type": "code", "execution_count": 4, "id": "5b93047d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# Construct the agent. 
We will use the default agent type here.\n", @@ -110,7 +117,9 @@ "cell_type": "code", "execution_count": 5, "id": "6f96a891", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", @@ -119,29 +128,24 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n", + "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n", + "Action: Search\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Gigi Hadid's age\n", "Action: Search\n", - "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate her age raised to the 0.43 power\n", + "Action Input: \"Gigi Hadid age\"\u001b[0m\u001b[36;1m\u001b[1;3m27 years\u001b[0m\u001b[32;1m\u001b[1;3mI need to calculate her age raised to the 0.43 power\n", "Action: Calculator\n", - "Action Input: 22^0.43\u001b[0m\n", + "Action Input: 27^(0.43)\u001b[0m\n", "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", - "22^0.43\u001b[32;1m\u001b[1;3m\n", - "```python\n", - "import math\n", - "print(math.pow(22, 0.43))\n", + "27^(0.43)\u001b[32;1m\u001b[1;3m```text\n", + "27**(0.43)\n", "```\n", + "...numexpr.evaluate(\"27**(0.43)\")...\n", "\u001b[0m\n", - "Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n", - "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m4.125593352125936\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\n", - "Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n", 
- "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\u001b[0m\n", + "\u001b[33;1m\u001b[1;3mAnswer: 4.125593352125936\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n", + "Final Answer: 4.125593352125936\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -149,7 +153,7 @@ { "data": { "text/plain": [ - "\"Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\"" + "'4.125593352125936'" ] }, "execution_count": 5, @@ -171,9 +175,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "c58a7c40", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "class CustomSearchTool(BaseTool):\n", @@ -203,9 +209,11 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "id": "3318a46f", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "tools = [CustomSearchTool(), CustomCalculatorTool()]" @@ -213,9 +221,11 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "id": "ee2d0f3a", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)" @@ -223,9 +233,11 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "6a2cebbf", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", @@ -234,29 +246,24 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n", + "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n", + "Action: Search\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a 
Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI now know Leo DiCaprio's girlfriend's name and that he's currently linked to Gigi Hadid. I need to find out Camila Morrone's age.\n", "Action: Search\n", - "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate her age raised to the 0.43 power\n", + "Action Input: \"Camila Morrone age\"\u001b[0m\u001b[36;1m\u001b[1;3m25 years\u001b[0m\u001b[32;1m\u001b[1;3mI have Camila Morrone's age. I need to calculate her age raised to the 0.43 power.\n", "Action: Calculator\n", - "Action Input: 22^0.43\u001b[0m\n", + "Action Input: 25^(0.43)\u001b[0m\n", "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", - "22^0.43\u001b[32;1m\u001b[1;3m\n", - "```python\n", - "import math\n", - "print(math.pow(22, 0.43))\n", + "25^(0.43)\u001b[32;1m\u001b[1;3m```text\n", + "25**(0.43)\n", "```\n", + "...numexpr.evaluate(\"25**(0.43)\")...\n", "\u001b[0m\n", - "Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n", - "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\n", - "Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\u001b[0m\n", + "\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the answer to the original question.\n", + "Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -264,10 +271,10 @@ { "data": { "text/plain": [ - "\"Camila 
Morrone's age raised to the 0.43 power is 3.777824273683966.\"" + "\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\"" ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -288,9 +295,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 10, "id": "8f15307d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "from langchain.agents import tool\n", @@ -298,22 +307,24 @@ "@tool\n", "def search_api(query: str) -> str:\n", " \"\"\"Searches the API for the query.\"\"\"\n", - " return \"Results\"" + " return f\"Results for query {query}\"" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "id": "0a23b91b", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "data": { "text/plain": [ - "Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False, verbose=False, callback_manager=, func=, coroutine=None)" + "Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None)" ] }, - "execution_count": 5, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -332,9 +343,11 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 12, "id": "28cdf04d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "@tool(\"search\", return_direct=True)\n", @@ -345,17 +358,17 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 13, "id": "1085a4bd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', return_direct=True, verbose=False, callback_manager=, func=, coroutine=None)" + "Tool(name='search', description='search(query: str) -> str - 
Searches the API for the query.', args_schema=, return_direct=True, verbose=False, callback_manager=, func=, coroutine=None)" ] }, - "execution_count": 7, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -376,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 14, "id": "79213f40", "metadata": {}, "outputs": [], @@ -386,7 +399,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 15, "id": "e1067dcb", "metadata": {}, "outputs": [], @@ -396,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 16, "id": "6c66ffe8", "metadata": {}, "outputs": [], @@ -406,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 17, "id": "f45b5bc3", "metadata": {}, "outputs": [], @@ -416,7 +429,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 18, "id": "565e2b9b", "metadata": {}, "outputs": [ @@ -427,21 +440,12 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n", - "Action: Google Search\n", - "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Camila Morrone's age\n", + "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age.\n", "Action: Google Search\n", - "Action Input: \"Camila Morrone age\"\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m25 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 0.43 power\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. 
He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mNow I need to find out Camila Morrone's current age.\n", "Action: Calculator\n", - "Action Input: 25^0.43\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Camila Morrone is Leo DiCaprio's girlfriend and her current age raised to the 0.43 power is 3.991298452658078.\u001b[0m\n", + "Action Input: 25^0.43\u001b[0m\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n", + "Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -449,10 +453,10 @@ { "data": { "text/plain": [ - "\"Camila Morrone is Leo DiCaprio's girlfriend and her current age raised to the 0.43 power is 3.991298452658078.\"" + "\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\"" ] }, - "execution_count": 12, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -478,7 +482,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 19, "id": "3450512e", "metadata": {}, "outputs": [], @@ -507,7 +511,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 20, "id": "4b9a7849", "metadata": {}, "outputs": [ @@ -520,9 +524,7 @@ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3m I should use a music search engine to find the answer\n", "Action: Music Search\n", - "Action Input: most famous song of christmas\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Action Input: most famous song of christmas\u001b[0m\u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah 
Carey.\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: 'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -534,7 +536,7 @@ "\"'All I Want For Christmas Is You' by Mariah Carey.\"" ] }, - "execution_count": 14, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -554,7 +556,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 21, "id": "3bb6185f", "metadata": {}, "outputs": [], @@ -572,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 22, "id": "113ddb84", "metadata": {}, "outputs": [], @@ -583,9 +585,11 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 23, "id": "582439a6", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", @@ -596,9 +600,7 @@ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3m I need to calculate this\n", "Action: Calculator\n", - "Action Input: 2**.12\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2599210498948732\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "Action Input: 2**.12\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 1.086734862526058\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -606,10 +608,10 @@ { "data": { "text/plain": [ - "'Answer: 1.2599210498948732'" + "'Answer: 1.086734862526058'" ] }, - "execution_count": 17, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -618,10 +620,149 @@ "agent.run(\"whats 2**.12\")" ] }, + { + "cell_type": "markdown", + "id": "8aa3c353-bd89-467c-9c27-b83a90cd4daa", + "metadata": {}, + "source": [ + "## Multi-argument tools\n", + "\n", + "Many functions expect structured inputs. These can also be supported using the Tool decorator or by directly subclassing `BaseTool`! 
We have to modify the LLM's OutputParser to map its string output to a dictionary to pass to the action, however." + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "id": "537bc628", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from typing import Optional, Union\n", + "\n", + "@tool\n", + "def custom_search(k: int, query: str, other_arg: Optional[str] = None):\n", + " \"\"\"The custom search function.\"\"\"\n", + " return f\"Here are the results for the custom search: k={k}, query={query}, other_arg={other_arg}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "d5c992cf-776a-40cd-a6c4-e7cf65ea709e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import re\n", + "from langchain.schema import (\n", + " AgentAction,\n", + " AgentFinish,\n", + ")\n", + "from langchain.agents import AgentOutputParser\n", + "\n", + "# We will add a custom parser to map the arguments to a dictionary\n", + "class CustomOutputParser(AgentOutputParser):\n", + " \n", + " def parse_tool_input(self, action_input: str) -> dict:\n", + " # Regex pattern to match arguments and their values\n", + " pattern = r\"(\\w+)\\s*=\\s*(None|\\\"[^\\\"]*\\\"|\\d+)\"\n", + " matches = re.findall(pattern, action_input)\n", + " \n", + " if not matches:\n", + " raise ValueError(f\"Could not parse action input: `{action_input}`\")\n", + "\n", + " # Create a dictionary with the parsed arguments and their values\n", + " parsed_input = {}\n", + " for arg, value in matches:\n", + " if value == \"None\":\n", + " parsed_value = None\n", + " elif value.isdigit():\n", + " parsed_value = int(value)\n", + " else:\n", + " parsed_value = value.strip('\"')\n", + " parsed_input[arg] = parsed_value\n", + "\n", + " return parsed_input\n", + " \n", + " def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n", + " # Check if agent should finish\n", + " if \"Final Answer:\" in llm_output:\n", + " return AgentFinish(\n", 
+ " # Return values is generally always a dictionary with a single `output` key\n", + " # It is not recommended to try anything else at the moment :)\n", + " return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n", + " log=llm_output,\n", + " )\n", + " # Parse out the action and action input\n", + " regex = r\"Action\\s*\\d*\\s*:(.*?)\\nAction\\s*\\d*\\s*Input\\s*\\d*\\s*:[\\s]*(.*)\"\n", + " match = re.search(regex, llm_output, re.DOTALL)\n", + " if not match:\n", + " raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n", + " action = match.group(1).strip()\n", + " action_input = match.group(2)\n", + " tool_input = self.parse_tool_input(action_input)\n", + " # Return the action and action input\n", + " return AgentAction(tool=action, tool_input=tool_input, log=llm_output)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "68269547-1482-4138-a6ea-58f00b4a9548", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "llm = OpenAI(temperature=0)\n", + "agent = initialize_agent([custom_search], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, agent_kwargs={\"output_parser\": CustomOutputParser()})" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "0947835a-691c-4f51-b8f4-6744e0e48ab1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to use a search function to find the answer\n", + "Action: custom_search\n", + "Action Input: k=1, query=\"me\"\u001b[0m\u001b[36;1m\u001b[1;3mHere are the results for the custom search: k=1, query=me, other_arg=None\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: The results of the custom search for k=1, query=me, other_arg=None.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ 
+ "'The results of the custom search for k=1, query=me, other_arg=None.'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"Search for me and tell me whatever it says\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "caf39c66-102b-42c1-baf2-777a49886ce4", "metadata": {}, "outputs": [], "source": [] @@ -643,7 +784,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.11.2" }, "vscode": { "interpreter": { diff --git a/docs/modules/agents/tools/tool_input_validation.ipynb b/docs/modules/agents/tools/tool_input_validation.ipynb new file mode 100644 index 00000000..3fe9aed4 --- /dev/null +++ b/docs/modules/agents/tools/tool_input_validation.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "# Tool Input Schema\n", + "\n", + "By default, tools infer the argument schema by inspecting the function signature. For stricter requirements, a custom input schema can be specified, along with custom validation logic." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.agents import initialize_agent, Tool\n", + "from langchain.agents import AgentType\n", + "from langchain.llms import OpenAI\n", + "from langchain.tools.python.tool import PythonREPLTool\n", + "from pydantic import BaseModel\n", + "from pydantic import Field, root_validator\n", + "from typing import Dict, Any" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "llm = OpenAI(temperature=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class ToolInputModel(BaseModel):\n", + " query: str = Field(...)\n", + " \n", + " @root_validator\n", + " def validate_query(cls, values: Dict[str, Any]) -> Dict:\n", + " # Note: this is NOT a safe REPL! This is used for instructive purposes only\n", + " if \"import os\" in values[\"query\"]:\n", + " raise ValueError(\"'import os' not permitted in this python REPL.\")\n", + " return values\n", + " \n", + "tool = PythonREPLTool(args_schema=ToolInputModel)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "agent = initialize_agent([tool], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to define a function that adds two numbers\n", + "Action: Python REPL\n", + "Action Input: def add_two_numbers(a, b):\n", + " return a + b\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m I need to call the function\n", + "Action: Python REPL\n", + "Action Input: 
print(add_two_numbers(2, 2))\u001b[0m\u001b[36;1m\u001b[1;3m4\n", + "\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: 4\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'4'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This will succeed, since the tool input never contains 'import os' and so passes validation\n", + "agent.run(\"Run a python function that adds 2 and 2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to import os and then list the dir\n", + "Action: Python REPL\n", + "Action Input: import os; print(os.listdir())\u001b[0m" + ] + }, + { + "ename": "ValidationError", + "evalue": "1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mRun a python function that imports os and lists the dir\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:213\u001b[0m, in \u001b[0;36mChain.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(args) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 212\u001b[0m 
\u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`run` supports only one positional argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m args:\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m(kwargs)[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n", + 
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:792\u001b[0m, in \u001b[0;36mAgentExecutor._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# We now enter the agent loop (until it returns something).\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_continue(iterations, time_elapsed):\n\u001b[0;32m--> 792\u001b[0m next_step_output \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_take_next_step\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 793\u001b[0m \u001b[43m \u001b[49m\u001b[43mname_to_tool_map\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolor_mapping\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mintermediate_steps\u001b[49m\n\u001b[1;32m 794\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(next_step_output, AgentFinish):\n\u001b[1;32m 796\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_return(next_step_output, intermediate_steps)\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:695\u001b[0m, in \u001b[0;36mAgentExecutor._take_next_step\u001b[0;34m(self, name_to_tool_map, color_mapping, inputs, intermediate_steps)\u001b[0m\n\u001b[1;32m 693\u001b[0m tool_run_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mllm_prefix\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 694\u001b[0m \u001b[38;5;66;03m# We then call the tool on the tool input to get an observation\u001b[39;00m\n\u001b[0;32m--> 695\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[43mtool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 696\u001b[0m \u001b[43m \u001b[49m\u001b[43magent_action\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtool_input\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 697\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 698\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcolor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 699\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtool_run_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 700\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 701\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 702\u001b[0m tool_run_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39magent\u001b[38;5;241m.\u001b[39mtool_run_logging_kwargs()\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:146\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 139\u001b[0m tool_input: Union[\u001b[38;5;28mstr\u001b[39m, Dict],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[1;32m 144\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 145\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Run the tool.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 146\u001b[0m run_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_parse_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;129;01mand\u001b[39;00m verbose \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 148\u001b[0m verbose_ \u001b[38;5;241m=\u001b[39m verbose\n", 
+ "File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:112\u001b[0m, in \u001b[0;36mBaseTool._parse_input\u001b[0;34m(self, tool_input)\u001b[0m\n\u001b[1;32m 110\u001b[0m tool_input \u001b[38;5;241m=\u001b[39m {field_name: tool_input}\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pydantic_input_type \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 112\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpydantic_input_type\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs_schema required for tool \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in order to\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m accept input of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(tool_input)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 117\u001b[0m )\n", + "File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:526\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n", + "File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n", + "\u001b[0;31mValidationError\u001b[0m: 1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)" + ] + } + ], + "source": [ + "# This will 
fail, because the attempt to import os will trigger a validation error\n", + "agent.run(\"Run a python function that imports os and lists the dir\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 99e9b741..9ae8f698 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -1,25 +1,38 @@ """Interface for tools.""" from inspect import signature -from typing import Any, Awaitable, Callable, Optional, Union +from typing import Any, Awaitable, Callable, Optional, Type, Union -from langchain.tools.base import BaseTool +from pydantic import BaseModel + +from langchain.tools.base import BaseTool, create_args_schema_model_from_signature class Tool(BaseTool): """Tool that takes in function or coroutine directly.""" description: str = "" - func: Callable[[str], str] - coroutine: Optional[Callable[[str], Awaitable[str]]] = None - - def _run(self, tool_input: str) -> str: + func: Callable[..., str] + """The function to run when the tool is called.""" + coroutine: Optional[Callable[..., Awaitable[str]]] = None + """The asynchronous version of the function.""" + + @property + def args(self) -> Type[BaseModel]: + """Generate an input pydantic model.""" + if self.args_schema is not None: + return self.args_schema + # Infer the schema directly from the function to add more structured + # arguments. 
+ return create_args_schema_model_from_signature(self.func) + + def _run(self, *args: Any, **kwargs: Any) -> str: """Use the tool.""" - return self.func(tool_input) + return self.func(*args, **kwargs) - async def _arun(self, tool_input: str) -> str: + async def _arun(self, *args: Any, **kwargs: Any) -> str: """Use the tool asynchronously.""" if self.coroutine: - return await self.coroutine(tool_input) + return await self.coroutine(*args, **kwargs) raise NotImplementedError("Tool does not support async") # TODO: this is for backwards compatibility, remove in future @@ -69,14 +82,16 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable: """ def _make_with_name(tool_name: str) -> Callable: - def _make_tool(func: Callable[[str], str]) -> Tool: + def _make_tool(func: Callable) -> Tool: assert func.__doc__, "Function must have a docstring" # Description example: - # search_api(query: str) - Searches the API for the query. + # search_api(query: str) - Searches the API for the query. 
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" + args_schema = create_args_schema_model_from_signature(func) tool_ = Tool( name=tool_name, func=func, + args_schema=args_schema, description=description, return_direct=return_direct, ) diff --git a/langchain/callbacks/comet_ml_callback.py b/langchain/callbacks/comet_ml_callback.py index 6f061f14..f057cea8 100644 --- a/langchain/callbacks/comet_ml_callback.py +++ b/langchain/callbacks/comet_ml_callback.py @@ -371,7 +371,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.starts += 1 tool = action.tool - tool_input = action.tool_input + tool_input = str(action.tool_input) log = action.log resp = self._init_resp() diff --git a/langchain/schema.py b/langchain/schema.py index 8678b1a8..a2b709f1 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar +from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar, Union from pydantic import BaseModel, Extra, Field, root_validator @@ -31,7 +31,7 @@ class AgentAction(NamedTuple): """Agent's action to take.""" tool: str - tool_input: str + tool_input: Union[str, dict] log: str diff --git a/langchain/tools/base.py b/langchain/tools/base.py index e6045a27..6c554beb 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -1,19 +1,84 @@ """Base implementation for tools or skills.""" -from abc import abstractmethod -from typing import Any, Optional +import inspect +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union -from pydantic import BaseModel, Extra, Field, validator +from pydantic import BaseModel, Extra, Field, create_model, validator from langchain.callbacks import get_callback_manager from langchain.callbacks.base import BaseCallbackManager -class 
BaseTool(BaseModel): - """Class responsible for defining a tool or skill for an LLM.""" +def create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]: + """Create a pydantic model type from a function's signature.""" + signature_ = inspect.signature(run_func) + field_definitions: Dict[str, Any] = {} + + for name, param in signature_.parameters.items(): + if name == "self": + continue + default_value = ( + param.default if param.default != inspect.Parameter.empty else None + ) + annotation = ( + param.annotation if param.annotation != inspect.Parameter.empty else Any + ) + # Handle functions with *args in the signature + if param.kind == inspect.Parameter.VAR_POSITIONAL: + field_definitions[name] = ( + Any, + Field(default=None, extra={"is_var_positional": True}), + ) + # handle functions with **kwargs in the signature + elif param.kind == inspect.Parameter.VAR_KEYWORD: + field_definitions[name] = ( + Any, + Field(default=None, extra={"is_var_keyword": True}), + ) + # Handle all other named parameters + else: + is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY + field_definitions[name] = ( + annotation, + Field( + default=default_value, extra={"is_keyword_only": is_keyword_only} + ), + ) + return create_model("ArgsModel", **field_definitions) # type: ignore + + +def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]: + args = [] + kwargs = {} + for name, field in model.__fields__.items(): + value = getattr(model, name) + # Handle *args in the function signature + if field.field_info.extra.get("extra", {}).get("is_var_positional"): + if isinstance(value, str): + # Base case for backwards compatability + args.append(value) + elif value is not None: + args.extend(value) + # Handle **kwargs in the function signature + elif field.field_info.extra.get("extra", {}).get("is_var_keyword"): + if value is not None: + kwargs.update(value) + elif field.field_info.extra.get("extra", {}).get("is_keyword_only"): + kwargs[name] = value 
+ else: + args.append(value) + + return tuple(args), kwargs + + +class BaseTool(ABC, BaseModel): + """Interface LangChain tools must implement.""" name: str description: str + args_schema: Optional[Type[BaseModel]] = None + """Pydantic model class to validate and parse the tool's input arguments.""" return_direct: bool = False verbose: bool = False callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) @@ -24,6 +89,33 @@ class BaseTool(BaseModel): extra = Extra.forbid arbitrary_types_allowed = True + @property + def args(self) -> Type[BaseModel]: + """Generate an input pydantic model.""" + if self.args_schema is not None: + return self.args_schema + return create_args_schema_model_from_signature(self._run) + + def _parse_input( + self, + tool_input: Union[str, Dict], + ) -> BaseModel: + """Convert tool input to pydantic model.""" + pydantic_input_type = self.args + if isinstance(tool_input, str): + # For backwards compatibility, a tool that only takes + # a single string input will be converted to a dict. + # to be validated. 
+ field_name = next(iter(pydantic_input_type.__fields__)) + tool_input = {field_name: tool_input} + if pydantic_input_type is not None: + return pydantic_input_type.parse_obj(tool_input) + else: + raise ValueError( + f"args_schema required for tool {self.name} in order to" + f" accept input of type {type(tool_input)}" + ) + @validator("callback_manager", pre=True, always=True) def set_callback_manager( cls, callback_manager: Optional[BaseCallbackManager] @@ -35,39 +127,37 @@ class BaseTool(BaseModel): return callback_manager or get_callback_manager() @abstractmethod - def _run(self, tool_input: str) -> str: + def _run(self, *args: Any, **kwargs: Any) -> str: """Use the tool.""" @abstractmethod - async def _arun(self, tool_input: str) -> str: + async def _arun(self, *args: Any, **kwargs: Any) -> str: """Use the tool asynchronously.""" - def __call__(self, tool_input: str) -> str: - """Make tools callable with str input.""" - return self.run(tool_input) - def run( self, - tool_input: str, + tool_input: Union[str, Dict], verbose: Optional[bool] = None, start_color: Optional[str] = "green", color: Optional[str] = "green", - **kwargs: Any + **kwargs: Any, ) -> str: """Run the tool.""" + run_input = self._parse_input(tool_input) if not self.verbose and verbose is not None: verbose_ = verbose else: verbose_ = self.verbose self.callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - tool_input, + str(run_input), verbose=verbose_, color=start_color, **kwargs, ) try: - observation = self._run(tool_input) + args, kwargs = _to_args_and_kwargs(run_input) + observation = self._run(*args, **kwargs) except (Exception, KeyboardInterrupt) as e: self.callback_manager.on_tool_error(e, verbose=verbose_) raise e @@ -78,13 +168,14 @@ class BaseTool(BaseModel): async def arun( self, - tool_input: str, + tool_input: Union[str, Dict], verbose: Optional[bool] = None, start_color: Optional[str] = "green", color: Optional[str] = "green", - **kwargs: Any + 
**kwargs: Any, ) -> str: """Run the tool asynchronously.""" + run_input = self._parse_input(tool_input) if not self.verbose and verbose is not None: verbose_ = verbose else: @@ -92,7 +183,7 @@ class BaseTool(BaseModel): if self.callback_manager.is_async: await self.callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - tool_input, + str(run_input.dict()), verbose=verbose_, color=start_color, **kwargs, @@ -100,14 +191,15 @@ class BaseTool(BaseModel): else: self.callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - tool_input, + str(run_input.dict()), verbose=verbose_, color=start_color, **kwargs, ) try: # We then call the tool on the tool input to get an observation - observation = await self._arun(tool_input) + args, kwargs = _to_args_and_kwargs(run_input) + observation = await self._arun(*args, **kwargs) except (Exception, KeyboardInterrupt) as e: if self.callback_manager.is_async: await self.callback_manager.on_tool_error(e, verbose=verbose_) @@ -123,3 +215,7 @@ class BaseTool(BaseModel): observation, verbose=verbose_, color=color, name=self.name, **kwargs ) return observation + + def __call__(self, tool_input: str) -> str: + """Make tool callable.""" + return self.run(tool_input) diff --git a/langchain/tools/file_management/__init__.py b/langchain/tools/file_management/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/langchain/tools/file_management/read.py b/langchain/tools/file_management/read.py new file mode 100644 index 00000000..a42df67e --- /dev/null +++ b/langchain/tools/file_management/read.py @@ -0,0 +1,29 @@ +from typing import Type + +from pydantic import BaseModel, Field + +from langchain.tools.base import BaseTool + + +class ReadFileInput(BaseModel): + """Input for ReadFileTool.""" + + file_path: str = Field(..., description="name of file") + + +class ReadFileTool(BaseTool): + name: str = "read_file" + tool_args: Type[BaseModel] = ReadFileInput + description: 
str = "Read file from disk" + + def _run(self, file_path: str) -> str: + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + return content + except Exception as e: + return "Error: " + str(e) + + async def _arun(self, tool_input: str) -> str: + # TODO: Add aiofiles method + raise NotImplementedError diff --git a/langchain/tools/file_management/write.py b/langchain/tools/file_management/write.py new file mode 100644 index 00000000..969d2a6f --- /dev/null +++ b/langchain/tools/file_management/write.py @@ -0,0 +1,34 @@ +import os +from typing import Type + +from pydantic import BaseModel, Field + +from langchain.tools.base import BaseTool + + +class WriteFileInput(BaseModel): + """Input for WriteFileTool.""" + + file_path: str = Field(..., description="name of file") + text: str = Field(..., description="text to write to file") + + +class WriteFileTool(BaseTool): + name: str = "write_file" + tool_args: Type[BaseModel] = WriteFileInput + description: str = "Write file to disk" + + def _run(self, file_path: str, text: str) -> str: + try: + directory = os.path.dirname(file_path) + if not os.path.exists(directory) and directory: + os.makedirs(directory) + with open(file_path, "w", encoding="utf-8") as f: + f.write(text) + return "File written to successfully." 
+ except Exception as e: + return "Error: " + str(e) + + async def _arun(self, tool_input: str) -> str: + # TODO: Add aiofiles method + raise NotImplementedError diff --git a/tests/unit_tests/agents/test_mrkl.py b/tests/unit_tests/agents/test_mrkl.py index d88fbf02..8cf4f6f8 100644 --- a/tests/unit_tests/agents/test_mrkl.py +++ b/tests/unit_tests/agents/test_mrkl.py @@ -16,7 +16,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM def get_action_and_input(text: str) -> Tuple[str, str]: output = MRKLOutputParser().parse(text) if isinstance(output, AgentAction): - return output.tool, output.tool_input + return output.tool, str(output.tool_input) else: return "Final Answer", output.return_values["output"] diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index f01bfa33..09dda518 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -1,7 +1,12 @@ """Test tool utils.""" +from datetime import datetime +from typing import Any, Optional, Type, Union + import pytest +from pydantic import BaseModel from langchain.agents.tools import Tool, tool +from langchain.tools.base import BaseTool def test_unnamed_decorator() -> None: @@ -18,6 +23,64 @@ def test_unnamed_decorator() -> None: assert search_api("test") == "API result" +class _MockSchema(BaseModel): + arg1: int + arg2: bool + arg3: Optional[dict] = None + + +class _MockStructuredTool(BaseTool): + name = "structured_api" + args_schema: Type[BaseModel] = _MockSchema + description = "A Structured Tool" + + def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + return f"{arg1} {arg2} {arg3}" + + async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + raise NotImplementedError + + +def test_structured_args() -> None: + """Test functionality with structured arguments.""" + structured_api = _MockStructuredTool() + assert isinstance(structured_api, BaseTool) + assert structured_api.name == 
"structured_api" + expected_result = "1 True {'foo': 'bar'}" + args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}} + assert structured_api.run(args) == expected_result + + +def test_structured_args_decorator() -> None: + """Test functionality with structured arguments parsed as a decorator.""" + + @tool + def structured_tool_input( + arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None + ) -> str: + """Return the arguments directly.""" + return f"{arg1}, {arg2}, {opt_arg}" + + assert isinstance(structured_tool_input, Tool) + assert structured_tool_input.name == "structured_tool_input" + args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}} + expected_result = "1, 0.001, {'foo': 'bar'}" + assert structured_tool_input.run(args) == expected_result + + +def test_empty_args_decorator() -> None: + """Test functionality with no args parsed as a decorator.""" + + @tool + def empty_tool_input() -> str: + """Return a constant.""" + return "the empty result" + + assert isinstance(empty_tool_input, Tool) + assert empty_tool_input.name == "empty_tool_input" + assert empty_tool_input.run({}) == "the empty result" + + def test_named_tool_decorator() -> None: """Test functionality when arguments are provided as input to decorator.""" @@ -57,6 +120,28 @@ def test_unnamed_tool_decorator_return_direct() -> None: assert search_api.return_direct +def test_tool_with_kwargs() -> None: + """Test functionality when only return direct is provided.""" + + @tool(return_direct=True) + def search_api( + arg_1: float, *args: Any, ping: Optional[str] = None, **kwargs: Any + ) -> str: + """Search the API for the query.""" + return f"arg_1={arg_1}, foo={args}, ping={ping}, kwargs={kwargs}" + + assert isinstance(search_api, Tool) + result = search_api.run( + tool_input={ + "arg_1": 3.2, + "args": "fam", + "kwargs": {"bar": "baz"}, + "ping": "pong", + } + ) + assert result == "arg_1=3.2, foo=('fam',), ping=pong, kwargs={'bar': 'baz'}" + + def test_missing_docstring() -> 
None: """Test error is raised when docstring is missing.""" # expect to throw a value error if theres no docstring @@ -67,7 +152,7 @@ def test_missing_docstring() -> None: return "API result" -def test_create_tool_posistional_args() -> None: +def test_create_tool_positional_args() -> None: """Test that positional arguments are allowed.""" test_tool = Tool("test_name", lambda x: x, "test_description") assert test_tool("foo") == "foo" diff --git a/tests/unit_tests/tools/requests/__init__.py b/tests/unit_tests/tools/requests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/tools/requests/test_tool.py b/tests/unit_tests/tools/requests/test_tool.py new file mode 100644 index 00000000..427ff611 --- /dev/null +++ b/tests/unit_tests/tools/requests/test_tool.py @@ -0,0 +1,100 @@ +import asyncio +from typing import Any, Dict + +import pytest + +from langchain.requests import TextRequestsWrapper +from langchain.tools.requests.tool import ( + RequestsDeleteTool, + RequestsGetTool, + RequestsPatchTool, + RequestsPostTool, + RequestsPutTool, + _parse_input, +) + + +class _MockTextRequestsWrapper(TextRequestsWrapper): + @staticmethod + def get(url: str, **kwargs: Any) -> str: + return "get_response" + + @staticmethod + async def aget(url: str, **kwargs: Any) -> str: + return "aget_response" + + @staticmethod + def post(url: str, data: Dict[str, Any], **kwargs: Any) -> str: + return f"post {str(data)}" + + @staticmethod + async def apost(url: str, data: Dict[str, Any], **kwargs: Any) -> str: + return f"apost {str(data)}" + + @staticmethod + def patch(url: str, data: Dict[str, Any], **kwargs: Any) -> str: + return f"patch {str(data)}" + + @staticmethod + async def apatch(url: str, data: Dict[str, Any], **kwargs: Any) -> str: + return f"apatch {str(data)}" + + @staticmethod + def put(url: str, data: Dict[str, Any], **kwargs: Any) -> str: + return f"put {str(data)}" + + @staticmethod + async def aput(url: str, data: Dict[str, Any], **kwargs: 
Any) -> str: + return f"aput {str(data)}" + + @staticmethod + def delete(url: str, **kwargs: Any) -> str: + return "delete_response" + + @staticmethod + async def adelete(url: str, **kwargs: Any) -> str: + return "adelete_response" + + +@pytest.fixture +def mock_requests_wrapper() -> TextRequestsWrapper: + return _MockTextRequestsWrapper() + + +def test_parse_input() -> None: + input_text = '{"url": "https://example.com", "data": {"key": "value"}}' + expected_output = {"url": "https://example.com", "data": {"key": "value"}} + assert _parse_input(input_text) == expected_output + + +def test_requests_get_tool(mock_requests_wrapper: TextRequestsWrapper) -> None: + tool = RequestsGetTool(requests_wrapper=mock_requests_wrapper) + assert tool.run("https://example.com") == "get_response" + assert asyncio.run(tool.arun("https://example.com")) == "aget_response" + + +def test_requests_post_tool(mock_requests_wrapper: TextRequestsWrapper) -> None: + tool = RequestsPostTool(requests_wrapper=mock_requests_wrapper) + input_text = '{"url": "https://example.com", "data": {"key": "value"}}' + assert tool.run(input_text) == "post {'key': 'value'}" + assert asyncio.run(tool.arun(input_text)) == "apost {'key': 'value'}" + + +def test_requests_patch_tool(mock_requests_wrapper: TextRequestsWrapper) -> None: + tool = RequestsPatchTool(requests_wrapper=mock_requests_wrapper) + input_text = '{"url": "https://example.com", "data": {"key": "value"}}' + assert tool.run(input_text) == "patch {'key': 'value'}" + assert asyncio.run(tool.arun(input_text)) == "apatch {'key': 'value'}" + + +def test_requests_put_tool(mock_requests_wrapper: TextRequestsWrapper) -> None: + tool = RequestsPutTool(requests_wrapper=mock_requests_wrapper) + input_text = '{"url": "https://example.com", "data": {"key": "value"}}' + assert tool.run(input_text) == "put {'key': 'value'}" + assert asyncio.run(tool.arun(input_text)) == "aput {'key': 'value'}" + + +def test_requests_delete_tool(mock_requests_wrapper: 
TextRequestsWrapper) -> None: + tool = RequestsDeleteTool(requests_wrapper=mock_requests_wrapper) + assert tool.run("https://example.com") == "delete_response" + assert asyncio.run(tool.arun("https://example.com")) == "adelete_response"