From 90ef705ced0f5983e94890305e05a014cb8b9cc2 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Tue, 18 Apr 2023 18:18:33 -0700 Subject: [PATCH] Update Tool Input (#3103) - Remove dynamic model creation in the `args()` property. _Only infer for the decorator (and add an argument to NOT infer if someone wishes to only pass as a string)_ - Update the validation example to make it less likely to be misinterpreted as a "safe" way to run a repl There is one example of "Multi-argument tools" in the custom_tools.ipynb from yesterday, but we could add more. The output parsing for the base MRKL agent hasn't been adapted to handle structured args at this point in time --------- Co-authored-by: Harrison Chase --- docs/modules/agents/tools/custom_tools.ipynb | 97 ++++++++--- .../agents/tools/tool_input_validation.ipynb | 160 +++++++++--------- langchain/agents/tools.py | 35 ++-- langchain/tools/base.py | 117 ++++--------- langchain/tools/file_management/write.py | 2 +- tests/unit_tests/agents/test_tools.py | 18 +- 6 files changed, 216 insertions(+), 213 deletions(-) diff --git a/docs/modules/agents/tools/custom_tools.ipynb b/docs/modules/agents/tools/custom_tools.ipynb index 4f57597e..841cee6f 100644 --- a/docs/modules/agents/tools/custom_tools.ipynb +++ b/docs/modules/agents/tools/custom_tools.ipynb @@ -12,6 +12,7 @@ "- name (str), is required and must be unique within a set of tools provided to an agent\n", "- description (str), is optional but recommended, as it is used by an agent to determine tool use\n", "- return_direct (bool), defaults to False\n", + "- args_schema (Pydantic BaseModel), is optional but recommended, can be used to provide more information or validation for expected parameters.\n", "\n", "The function that should be called when the tool is selected should return a single string.\n", "\n", @@ -91,12 +92,22 @@ " func=search.run,\n", " description=\"useful for when you need to answer questions about current events\"\n", " ),\n", + "]\n", + "# You can also define an args_schema to provide more information about inputs\n", + "from pydantic import BaseModel, Field\n", + "\n", + "class CalculatorInput(BaseModel):\n", + " query: str = Field(description=\"should be a math expression\")\n", + " \n", + "\n", + "tools.append(\n", " Tool(\n", " name=\"Calculator\",\n", " func=llm_math_chain.run,\n", - " description=\"useful for when you need to answer questions about math\"\n", + " description=\"useful for when you need to answer questions about math\",\n", + " args_schema=CalculatorInput\n", " )\n", - "]" + ")" ] }, { @@ -130,22 +141,20 @@ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n", "Action: Search\n", - "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Gigi Hadid's age\n", - "Action: Search\n", - "Action Input: \"Gigi Hadid age\"\u001b[0m\u001b[36;1m\u001b[1;3m27 years\u001b[0m\u001b[32;1m\u001b[1;3mI need to calculate her age raised to the 0.43 power\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n", "Action: Calculator\n", - "Action Input: 27^(0.43)\u001b[0m\n", + "Action Input: 25^(0.43)\u001b[0m\n", "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", - "27^(0.43)\u001b[32;1m\u001b[1;3m```text\n", - "27**(0.43)\n", + "25^(0.43)\u001b[32;1m\u001b[1;3m```text\n", + "25**(0.43)\n", "```\n", - "...numexpr.evaluate(\"27**(0.43)\")...\n", + "...numexpr.evaluate(\"25**(0.43)\")...\n", "\u001b[0m\n", - "Answer: \u001b[33;1m\u001b[1;3m4.125593352125936\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\u001b[33;1m\u001b[1;3mAnswer: 4.125593352125936\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n", - "Final Answer: 4.125593352125936\u001b[0m\n", + "\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n", + "Final Answer: 3.991298452658078\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -153,7 +162,7 @@ { "data": { "text/plain": [ - "'4.125593352125936'" + "'3.991298452658078'" ] }, "execution_count": 5, @@ -197,6 +206,7 @@ "class CustomCalculatorTool(BaseTool):\n", " name = \"Calculator\"\n", " description = \"useful for when you need to answer questions about math\"\n", + " args_schema=CalculatorInput\n", "\n", " def _run(self, query: str) -> str:\n", " \"\"\"Use the tool.\"\"\"\n", @@ -248,9 +258,7 @@ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n", "Action: Search\n", - "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI now know Leo DiCaprio's girlfriend's name and that he's currently linked to Gigi Hadid. I need to find out Camila Morrone's age.\n", - "Action: Search\n", - "Action Input: \"Camila Morrone age\"\u001b[0m\u001b[36;1m\u001b[1;3m25 years\u001b[0m\u001b[32;1m\u001b[1;3mI have Camila Morrone's age. I need to calculate her age raised to the 0.43 power.\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n", "Action: Calculator\n", "Action Input: 25^(0.43)\u001b[0m\n", "\n", @@ -262,8 +270,8 @@ "\u001b[0m\n", "Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the answer to the original question.\n", - "Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n", + "\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n", + "Final Answer: 3.991298452658078\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -271,7 +279,7 @@ { "data": { "text/plain": [ - "\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\"" + "'3.991298452658078'" ] }, "execution_count": 9, @@ -321,7 +329,7 @@ { "data": { "text/plain": [ - "Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None)" + "Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None)" ] }, "execution_count": 11, @@ -365,7 +373,7 @@ { "data": { "text/plain": [ - "Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=, return_direct=True, verbose=False, callback_manager=, func=, coroutine=None)" + "Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=, return_direct=True, verbose=False, callback_manager=, func=, coroutine=None)" ] }, "execution_count": 13, @@ -377,6 +385,51 @@ "search_api" ] }, + { + "cell_type": "markdown", + "id": "de34a6a3", + "metadata": {}, + "source": [ + "You can also provide `args_schema` to provide more information about the argument" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f3a5c106", + "metadata": {}, + "outputs": [], + "source": [ + "class SearchInput(BaseModel):\n", + " query: str = Field(description=\"should be a search query\")\n", + " \n", + "@tool(\"search\", return_direct=True, args_schema=SearchInput)\n", + "def search_api(query: str) -> str:\n", + " \"\"\"Searches the API for the query.\"\"\"\n", + " return \"Results\"" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "7914ba6b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=, return_direct=True, verbose=False, callback_manager=, func=, coroutine=None)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "search_api" + ] + }, { "cell_type": "markdown", "id": "1d0430d6", @@ -784,7 +837,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.9.1" }, "vscode": { "interpreter": { diff --git a/docs/modules/agents/tools/tool_input_validation.ipynb b/docs/modules/agents/tools/tool_input_validation.ipynb index 3fe9aed4..f18e5929 100644 --- a/docs/modules/agents/tools/tool_input_validation.ipynb +++ b/docs/modules/agents/tools/tool_input_validation.ipynb @@ -19,13 +19,12 @@ }, "outputs": [], "source": [ - "from langchain.agents import initialize_agent, Tool\n", - "from langchain.agents import AgentType\n", + "from typing import Any, Dict\n", + "\n", + "from langchain.agents import AgentType, initialize_agent\n", "from langchain.llms import OpenAI\n", - "from langchain.tools.python.tool import PythonREPLTool\n", - "from pydantic import BaseModel\n", - "from pydantic import Field, root_validator\n", - "from typing import Dict, Any" + "from langchain.tools.requests.tool import RequestsGetTool, TextRequestsWrapper\n", + "from pydantic import BaseModel, Field, root_validator\n" ] }, { @@ -42,22 +41,20 @@ { "cell_type": "code", "execution_count": 3, - "metadata": { - "tags": [] - }, - "outputs": [], + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], "source": [ - "class ToolInputModel(BaseModel):\n", - " query: str = Field(...)\n", - " \n", - " @root_validator\n", - " def validate_query(cls, values: Dict[str, Any]) -> Dict:\n", - " # Note: this is NOT a safe REPL! This is used for instructive purposes only\n", - " if \"import os\" in values[\"query\"]:\n", - " raise ValueError(\"'import os' not permitted in this python REPL.\")\n", - " return values\n", - " \n", - "tool = PythonREPLTool(args_schema=ToolInputModel)" + "!pip install tldextract > /dev/null" ] }, { @@ -68,7 +65,27 @@ }, "outputs": [], "source": [ - "agent = initialize_agent([tool], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)" + "import tldextract\n", + "\n", + "_APPROVED_DOMAINS = {\n", + " \"langchain\",\n", + " \"wikipedia\",\n", + "}\n", + "\n", + "class ToolInputSchema(BaseModel):\n", + "\n", + " url: str = Field(...)\n", + " \n", + " @root_validator\n", + " def validate_query(cls, values: Dict[str, Any]) -> Dict:\n", + " url = values[\"url\"]\n", + " domain = tldextract.extract(url).domain\n", + " if domain not in _APPROVED_DOMAINS:\n", + " raise ValueError(f\"Domain {domain} is not on the approved list:\"\n", + " f\" {sorted(_APPROVED_DOMAINS)}\")\n", + " return values\n", + " \n", + "tool = RequestsGetTool(args_schema=ToolInputSchema, requests_wrapper=TextRequestsWrapper())" ] }, { @@ -77,40 +94,9 @@ "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to define a function that adds two numbers\n", - "Action: Python REPL\n", - "Action Input: def add_two_numbers(a, b):\n", - " return a + b\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m I need to call the function\n", - "Action: Python REPL\n", - "Action Input: print(add_two_numbers(2, 2))\u001b[0m\u001b[36;1m\u001b[1;3m4\n", - "\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: 4\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "'4'" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "# This will succeed, since there aren't any arguments that will be triggered during validation\n", - "agent.run(\"Run a python function that adds 2 and 2\")" + "agent = initialize_agent([tool], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)" ] }, { @@ -124,38 +110,46 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to import os and then list the dir\n", - "Action: Python REPL\n", - "Action Input: import os; print(os.listdir())\u001b[0m" - ] - }, - { - "ename": "ValidationError", - "evalue": "1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mRun a python function that imports os and lists the dir\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:213\u001b[0m, in \u001b[0;36mChain.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(args) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 212\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`run` supports only one positional argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m args:\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m(kwargs)[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n", - "File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n", - "File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n", - "File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:792\u001b[0m, in \u001b[0;36mAgentExecutor._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# We now enter the agent loop (until it returns something).\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_continue(iterations, time_elapsed):\n\u001b[0;32m--> 792\u001b[0m next_step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_take_next_step\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 793\u001b[0m \u001b[43m \u001b[49m\u001b[43mname_to_tool_map\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolor_mapping\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mintermediate_steps\u001b[49m\n\u001b[1;32m 794\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(next_step_output, AgentFinish):\n\u001b[1;32m 796\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_return(next_step_output, intermediate_steps)\n", - "File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:695\u001b[0m, in \u001b[0;36mAgentExecutor._take_next_step\u001b[0;34m(self, name_to_tool_map, color_mapping, inputs, intermediate_steps)\u001b[0m\n\u001b[1;32m 693\u001b[0m tool_run_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mllm_prefix\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 694\u001b[0m \u001b[38;5;66;03m# We then call the tool on the tool input to get an observation\u001b[39;00m\n\u001b[0;32m--> 695\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[43mtool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 696\u001b[0m \u001b[43m \u001b[49m\u001b[43magent_action\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtool_input\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 697\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 698\u001b[0m \u001b[43m \u001b[49m\u001b[43mcolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcolor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 699\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtool_run_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 700\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 701\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 702\u001b[0m tool_run_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39magent\u001b[38;5;241m.\u001b[39mtool_run_logging_kwargs()\n", - "File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:146\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 139\u001b[0m tool_input: Union[\u001b[38;5;28mstr\u001b[39m, Dict],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[1;32m 144\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 145\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Run the tool.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 146\u001b[0m run_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_parse_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;129;01mand\u001b[39;00m verbose \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 148\u001b[0m verbose_ \u001b[38;5;241m=\u001b[39m verbose\n", - "File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:112\u001b[0m, in \u001b[0;36mBaseTool._parse_input\u001b[0;34m(self, tool_input)\u001b[0m\n\u001b[1;32m 110\u001b[0m tool_input \u001b[38;5;241m=\u001b[39m {field_name: tool_input}\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pydantic_input_type \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 112\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpydantic_input_type\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs_schema required for tool \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in order to\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m accept input of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(tool_input)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 117\u001b[0m )\n", - "File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:526\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n", - "File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n", - "\u001b[0;31mValidationError\u001b[0m: 1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)" + "The main title of langchain.com is \"LANG CHAIN 🦜️🔗 Official Home Page\"\n" ] } ], "source": [ - "# This will fail, because the attempt to import os will trigger a validation error\n", - "agent.run(\"Run a python function that imports os and lists the dir\")" + "# This will succeed, since there aren't any arguments that will be triggered during validation\n", + "answer = agent.run(\"What's the main title on langchain.com?\")\n", + "print(answer)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "ename": "ValidationError", + "evalue": "1 validation error for ToolInputSchema\n__root__\n Domain google is not on the approved list: ['langchain', 'wikipedia'] (type=value_error)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m agent\u001b[39m.\u001b[39;49mrun(\u001b[39m\"\u001b[39;49m\u001b[39mWhat\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39ms the main title on google.com?\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:213\u001b[0m, in \u001b[0;36mChain.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(args) \u001b[39m!=\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[1;32m 212\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m`run` supports only one positional argument.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 213\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m(args[\u001b[39m0\u001b[39;49m])[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_keys[\u001b[39m0\u001b[39m]]\n\u001b[1;32m 215\u001b[0m \u001b[39mif\u001b[39;00m kwargs \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m args:\n\u001b[1;32m 216\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m(kwargs)[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_keys[\u001b[39m0\u001b[39m]]\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[39mexcept\u001b[39;00m (\u001b[39mKeyboardInterrupt\u001b[39;00m, \u001b[39mException\u001b[39;00m) \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_manager\u001b[39m.\u001b[39mon_chain_error(e, verbose\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[39mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_manager\u001b[39m.\u001b[39mon_chain_end(outputs, verbose\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_manager\u001b[39m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[39m\"\u001b[39m\u001b[39mname\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call(inputs)\n\u001b[1;32m 114\u001b[0m \u001b[39mexcept\u001b[39;00m (\u001b[39mKeyboardInterrupt\u001b[39;00m, \u001b[39mException\u001b[39;00m) \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_manager\u001b[39m.\u001b[39mon_chain_error(e, verbose\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose)\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:792\u001b[0m, in \u001b[0;36mAgentExecutor._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[39m# We now enter the agent loop (until it returns something).\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_should_continue(iterations, time_elapsed):\n\u001b[0;32m--> 792\u001b[0m next_step_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_take_next_step(\n\u001b[1;32m 793\u001b[0m name_to_tool_map, color_mapping, inputs, intermediate_steps\n\u001b[1;32m 794\u001b[0m )\n\u001b[1;32m 795\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(next_step_output, AgentFinish):\n\u001b[1;32m 796\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_return(next_step_output, intermediate_steps)\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:695\u001b[0m, in \u001b[0;36mAgentExecutor._take_next_step\u001b[0;34m(self, name_to_tool_map, color_mapping, inputs, intermediate_steps)\u001b[0m\n\u001b[1;32m 693\u001b[0m tool_run_kwargs[\u001b[39m\"\u001b[39m\u001b[39mllm_prefix\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 694\u001b[0m \u001b[39m# We then call the tool on the tool input to get an observation\u001b[39;00m\n\u001b[0;32m--> 695\u001b[0m observation \u001b[39m=\u001b[39m tool\u001b[39m.\u001b[39;49mrun(\n\u001b[1;32m 696\u001b[0m agent_action\u001b[39m.\u001b[39;49mtool_input,\n\u001b[1;32m 697\u001b[0m verbose\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mverbose,\n\u001b[1;32m 698\u001b[0m color\u001b[39m=\u001b[39;49mcolor,\n\u001b[1;32m 699\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mtool_run_kwargs,\n\u001b[1;32m 700\u001b[0m )\n\u001b[1;32m 701\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 702\u001b[0m tool_run_kwargs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39magent\u001b[39m.\u001b[39mtool_run_logging_kwargs()\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:110\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, **kwargs)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrun\u001b[39m(\n\u001b[1;32m 102\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 103\u001b[0m tool_input: Union[\u001b[39mstr\u001b[39m, Dict],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any,\n\u001b[1;32m 108\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n\u001b[1;32m 109\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Run the tool.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 110\u001b[0m run_input \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_parse_input(tool_input)\n\u001b[1;32m 111\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose \u001b[39mand\u001b[39;00m verbose \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 112\u001b[0m verbose_ \u001b[39m=\u001b[39m verbose\n", + "File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:71\u001b[0m, in \u001b[0;36mBaseTool._parse_input\u001b[0;34m(self, tool_input)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39missubclass\u001b[39m(input_args, BaseModel):\n\u001b[1;32m 70\u001b[0m key_ \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39m(\u001b[39miter\u001b[39m(input_args\u001b[39m.\u001b[39m__fields__\u001b[39m.\u001b[39mkeys()))\n\u001b[0;32m---> 71\u001b[0m input_args\u001b[39m.\u001b[39;49mparse_obj({key_: tool_input})\n\u001b[1;32m 72\u001b[0m \u001b[39m# Passing as a positional argument is more straightforward for\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[39m# backwards compatability\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39mreturn\u001b[39;00m tool_input\n", + "File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:526\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n", + "File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n", + "\u001b[0;31mValidationError\u001b[0m: 1 validation error for ToolInputSchema\n__root__\n Domain google is not on the approved list: ['langchain', 'wikipedia'] (type=value_error)" + ] + } + ], + "source": [ + "agent.run(\"What's the main title on google.com?\")" ] }, { diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 9ae8f698..fa40bce8 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -2,9 +2,9 @@ from inspect import signature from typing import Any, Awaitable, Callable, Optional, Type, Union -from pydantic import BaseModel +from pydantic import BaseModel, validate_arguments -from langchain.tools.base import BaseTool, create_args_schema_model_from_signature +from langchain.tools.base import BaseTool class Tool(BaseTool): @@ -16,15 +16,6 @@ class Tool(BaseTool): coroutine: Optional[Callable[..., Awaitable[str]]] = None """The asynchronous version of the function.""" - @property - def args(self) -> Type[BaseModel]: - """Generate an input pydantic model.""" - if self.args_schema is not None: - return self.args_schema - # Infer the schema directly from the function to add more structured - # arguments. - return create_args_schema_model_from_signature(self.func) - def _run(self, *args: Any, **kwargs: Any) -> str: """Use the tool.""" return self.func(*args, **kwargs) @@ -60,9 +51,23 @@ class InvalidTool(BaseTool): return f"{tool_name} is not a valid tool, try another one." -def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable: +def tool( + *args: Union[str, Callable], + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + infer_schema: bool = True, +) -> Callable: """Make tools out of functions, can be used with or without arguments. + Args: + *args: The arguments to the tool. + return_direct: Whether to return directly from the tool rather + than continuing the agent loop. + args_schema: optional argument schema for user to specify + infer_schema: Whether to infer the schema of the arguments from + the function's signature. This also makes the resultant tool + accept a dictionary input to its `run()` function. + Requires: - Function must be of type (str) -> str - Function must have a docstring @@ -87,11 +92,13 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable: # Description example: # search_api(query: str) - Searches the API for the query. description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" - args_schema = create_args_schema_model_from_signature(func) + _args_schema = args_schema + if _args_schema is None and infer_schema: + _args_schema = validate_arguments(func).model # type: ignore tool_ = Tool( name=tool_name, func=func, - args_schema=args_schema, + args_schema=_args_schema, description=description, return_direct=return_direct, ) diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 6c554beb..eb03dfdd 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -1,75 +1,21 @@ """Base implementation for tools or skills.""" -import inspect from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union -from pydantic import BaseModel, Extra, Field, create_model, validator +from pydantic import BaseModel, Extra, Field, validator from langchain.callbacks import get_callback_manager from langchain.callbacks.base import BaseCallbackManager -def create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]: - """Create a pydantic model type from a function's signature.""" - signature_ = inspect.signature(run_func) - field_definitions: Dict[str, Any] = {} - - for name, param in signature_.parameters.items(): - if name == "self": - continue - default_value = ( - param.default if param.default != inspect.Parameter.empty else None - ) - annotation = ( - param.annotation if param.annotation != inspect.Parameter.empty else Any - ) - # Handle functions with *args in the signature - if param.kind == inspect.Parameter.VAR_POSITIONAL: - field_definitions[name] = ( - Any, - Field(default=None, extra={"is_var_positional": True}), - ) - # handle functions with **kwargs in the signature - elif param.kind == inspect.Parameter.VAR_KEYWORD: - field_definitions[name] = ( - Any, - Field(default=None, extra={"is_var_keyword": True}), - ) - # Handle all other named parameters - else: - is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY - field_definitions[name] = ( - annotation, - Field( - default=default_value, extra={"is_keyword_only": is_keyword_only} - ), - ) - return create_model("ArgsModel", **field_definitions) # type: ignore - - -def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]: - args = [] - kwargs = {} - for name, field in model.__fields__.items(): - value = getattr(model, name) - # Handle *args in the function signature - if field.field_info.extra.get("extra", {}).get("is_var_positional"): - if isinstance(value, str): - # Base case for backwards compatability - args.append(value) - elif value is not None: - args.extend(value) - # Handle **kwargs in the function signature - elif field.field_info.extra.get("extra", {}).get("is_var_keyword"): - if value is not None: - kwargs.update(value) - elif field.field_info.extra.get("extra", {}).get("is_keyword_only"): - kwargs[name] = value - else: - args.append(value) - - return tuple(args), kwargs +def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]: + # For backwards compatability, if run_input is a string, + # pass as a positional argument. + if isinstance(run_input, str): + return (run_input,), {} + else: + return [], run_input class BaseTool(ABC, BaseModel): @@ -90,31 +36,28 @@ class BaseTool(ABC, BaseModel): arbitrary_types_allowed = True @property - def args(self) -> Type[BaseModel]: + def args(self) -> Union[Type[BaseModel], Type[str]]: """Generate an input pydantic model.""" - if self.args_schema is not None: - return self.args_schema - return create_args_schema_model_from_signature(self._run) + return str if self.args_schema is None else self.args_schema def _parse_input( self, tool_input: Union[str, Dict], - ) -> BaseModel: + ) -> None: """Convert tool input to pydantic model.""" - pydantic_input_type = self.args + input_args = self.args if isinstance(tool_input, str): - # For backwards compatibility, a tool that only takes - # a single string input will be converted to a dict. - # to be validated. - field_name = next(iter(pydantic_input_type.__fields__)) - tool_input = {field_name: tool_input} - if pydantic_input_type is not None: - return pydantic_input_type.parse_obj(tool_input) + if issubclass(input_args, BaseModel): + key_ = next(iter(input_args.__fields__.keys())) + input_args.validate({key_: tool_input}) else: - raise ValueError( - f"args_schema required for tool {self.name} in order to" - f" accept input of type {type(tool_input)}" - ) + if issubclass(input_args, BaseModel): + input_args.validate(tool_input) + else: + raise ValueError( + f"args_schema required for tool {self.name} in order to" + f" accept input of type {type(tool_input)}" + ) @validator("callback_manager", pre=True, always=True) def set_callback_manager( @@ -143,20 +86,20 @@ class BaseTool(ABC, BaseModel): **kwargs: Any, ) -> str: """Run the tool.""" - run_input = self._parse_input(tool_input) + self._parse_input(tool_input) if not self.verbose and verbose is not None: verbose_ = verbose else: verbose_ = self.verbose self.callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - str(run_input), + tool_input if isinstance(tool_input, str) else str(tool_input), verbose=verbose_, color=start_color, **kwargs, ) try: - args, kwargs = _to_args_and_kwargs(run_input) + args, kwargs = _to_args_and_kwargs(tool_input) observation = self._run(*args, **kwargs) except (Exception, KeyboardInterrupt) as e: self.callback_manager.on_tool_error(e, verbose=verbose_) @@ -175,7 +118,7 @@ class BaseTool(ABC, BaseModel): **kwargs: Any, ) -> str: """Run the tool asynchronously.""" - run_input = self._parse_input(tool_input) + self._parse_input(tool_input) if not self.verbose and verbose is not None: verbose_ = verbose else: @@ -183,7 +126,7 @@ class BaseTool(ABC, BaseModel): if self.callback_manager.is_async: await self.callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - str(run_input.dict()), + tool_input if isinstance(tool_input, str) else str(tool_input), verbose=verbose_, color=start_color, **kwargs, @@ -191,14 +134,14 @@ class BaseTool(ABC, BaseModel): else: self.callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - str(run_input.dict()), + tool_input if isinstance(tool_input, str) else str(tool_input), verbose=verbose_, color=start_color, **kwargs, ) try: # We then call the tool on the tool input to get an observation - args, kwargs = _to_args_and_kwargs(run_input) + args, kwargs = _to_args_and_kwargs(tool_input) observation = await self._arun(*args, **kwargs) except (Exception, KeyboardInterrupt) as e: if self.callback_manager.is_async: diff --git a/langchain/tools/file_management/write.py b/langchain/tools/file_management/write.py index 969d2a6f..0748ee82 100644 --- a/langchain/tools/file_management/write.py +++ b/langchain/tools/file_management/write.py @@ -29,6 +29,6 @@ class WriteFileTool(BaseTool): except Exception as e: return "Error: " + str(e) - async def _arun(self, tool_input: str) -> str: + async def _arun(self, file_path: str, text: str) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 09dda518..cdb7929d 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -1,6 +1,6 @@ """Test tool utils.""" from datetime import datetime -from typing import Any, Optional, Type, Union +from typing import Optional, Type, Union import pytest from pydantic import BaseModel @@ -125,21 +125,27 @@ def test_tool_with_kwargs() -> None: @tool(return_direct=True) def search_api( - arg_1: float, *args: Any, ping: Optional[str] = None, **kwargs: Any + arg_1: float, + ping: str = "hi", ) -> str: """Search the API for the query.""" - return f"arg_1={arg_1}, foo={args}, ping={ping}, kwargs={kwargs}" + return f"arg_1={arg_1}, ping={ping}" assert isinstance(search_api, Tool) result = search_api.run( tool_input={ "arg_1": 3.2, - "args": "fam", - "kwargs": {"bar": "baz"}, "ping": "pong", } ) - assert result == "arg_1=3.2, foo=('fam',), ping=pong, kwargs={'bar': 'baz'}" + assert result == "arg_1=3.2, ping=pong" + + result = search_api.run( + tool_input={ + "arg_1": 3.2, + } + ) + assert result == "arg_1=3.2, ping=hi" def test_missing_docstring() -> None: