Update Tool Input (#3103)

- Remove dynamic model creation in the `args()` property. _Only infer
for the decorator (and add an argument to NOT infer if someone wishes to
only pass as a string)_
- Update the validation example to make it less likely to be
misinterpreted as a "safe" way to run a repl


There is one example of "Multi-argument tools" in the custom_tools.ipynb
from yesterday, but we could add more. The output parsing for the base
MRKL agent hasn't been adapted to handle structured args at this point
in time

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Zander Chase 2023-04-18 18:18:33 -07:00 committed by GitHub
parent 19116010ee
commit 90ef705ced
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 216 additions and 213 deletions

View File

@ -12,6 +12,7 @@
"- name (str), is required and must be unique within a set of tools provided to an agent\n", "- name (str), is required and must be unique within a set of tools provided to an agent\n",
"- description (str), is optional but recommended, as it is used by an agent to determine tool use\n", "- description (str), is optional but recommended, as it is used by an agent to determine tool use\n",
"- return_direct (bool), defaults to False\n", "- return_direct (bool), defaults to False\n",
"- args_schema (Pydantic BaseModel), is optional but recommended, can be used to provide more information or validation for expected parameters.\n",
"\n", "\n",
"The function that should be called when the tool is selected should return a single string.\n", "The function that should be called when the tool is selected should return a single string.\n",
"\n", "\n",
@ -91,12 +92,22 @@
" func=search.run,\n", " func=search.run,\n",
" description=\"useful for when you need to answer questions about current events\"\n", " description=\"useful for when you need to answer questions about current events\"\n",
" ),\n", " ),\n",
"]\n",
"# You can also define an args_schema to provide more information about inputs\n",
"from pydantic import BaseModel, Field\n",
"\n",
"class CalculatorInput(BaseModel):\n",
" query: str = Field(description=\"should be a math expression\")\n",
" \n",
"\n",
"tools.append(\n",
" Tool(\n", " Tool(\n",
" name=\"Calculator\",\n", " name=\"Calculator\",\n",
" func=llm_math_chain.run,\n", " func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\"\n", " description=\"useful for when you need to answer questions about math\",\n",
" args_schema=CalculatorInput\n",
" )\n", " )\n",
"]" ")"
] ]
}, },
{ {
@ -130,22 +141,20 @@
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n", "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
"Action: Search\n", "Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Gigi Hadid's age\n", "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n",
"Action: Search\n",
"Action Input: \"Gigi Hadid age\"\u001b[0m\u001b[36;1m\u001b[1;3m27 years\u001b[0m\u001b[32;1m\u001b[1;3mI need to calculate her age raised to the 0.43 power\n",
"Action: Calculator\n", "Action: Calculator\n",
"Action Input: 27^(0.43)\u001b[0m\n", "Action Input: 25^(0.43)\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"27^(0.43)\u001b[32;1m\u001b[1;3m```text\n", "25^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
"27**(0.43)\n", "25**(0.43)\n",
"```\n", "```\n",
"...numexpr.evaluate(\"27**(0.43)\")...\n", "...numexpr.evaluate(\"25**(0.43)\")...\n",
"\u001b[0m\n", "\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m4.125593352125936\u001b[0m\n", "Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[33;1m\u001b[1;3mAnswer: 4.125593352125936\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n", "\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
"Final Answer: 4.125593352125936\u001b[0m\n", "Final Answer: 3.991298452658078\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -153,7 +162,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'4.125593352125936'" "'3.991298452658078'"
] ]
}, },
"execution_count": 5, "execution_count": 5,
@ -197,6 +206,7 @@
"class CustomCalculatorTool(BaseTool):\n", "class CustomCalculatorTool(BaseTool):\n",
" name = \"Calculator\"\n", " name = \"Calculator\"\n",
" description = \"useful for when you need to answer questions about math\"\n", " description = \"useful for when you need to answer questions about math\"\n",
" args_schema=CalculatorInput\n",
"\n", "\n",
" def _run(self, query: str) -> str:\n", " def _run(self, query: str) -> str:\n",
" \"\"\"Use the tool.\"\"\"\n", " \"\"\"Use the tool.\"\"\"\n",
@ -248,9 +258,7 @@
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n", "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
"Action: Search\n", "Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI now know Leo DiCaprio's girlfriend's name and that he's currently linked to Gigi Hadid. I need to find out Camila Morrone's age.\n", "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n",
"Action: Search\n",
"Action Input: \"Camila Morrone age\"\u001b[0m\u001b[36;1m\u001b[1;3m25 years\u001b[0m\u001b[32;1m\u001b[1;3mI have Camila Morrone's age. I need to calculate her age raised to the 0.43 power.\n",
"Action: Calculator\n", "Action: Calculator\n",
"Action Input: 25^(0.43)\u001b[0m\n", "Action Input: 25^(0.43)\u001b[0m\n",
"\n", "\n",
@ -262,8 +270,8 @@
"\u001b[0m\n", "\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n", "Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the answer to the original question.\n", "\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
"Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n", "Final Answer: 3.991298452658078\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -271,7 +279,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\"" "'3.991298452658078'"
] ]
}, },
"execution_count": 9, "execution_count": 9,
@ -321,7 +329,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.ArgsModel'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x124346f10>, func=<function search_api at 0x16ad6e020>, coroutine=None)" "Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.SearchApi'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bd664c0>, coroutine=None)"
] ]
}, },
"execution_count": 11, "execution_count": 11,
@ -365,7 +373,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.ArgsModel'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x124346f10>, func=<function search_api at 0x16ad6d3a0>, coroutine=None)" "Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.SearchApi'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bd66310>, coroutine=None)"
] ]
}, },
"execution_count": 13, "execution_count": 13,
@ -377,6 +385,51 @@
"search_api" "search_api"
] ]
}, },
{
"cell_type": "markdown",
"id": "de34a6a3",
"metadata": {},
"source": [
"You can also provide `args_schema` to provide more information about the argument"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f3a5c106",
"metadata": {},
"outputs": [],
"source": [
"class SearchInput(BaseModel):\n",
" query: str = Field(description=\"should be a search query\")\n",
" \n",
"@tool(\"search\", return_direct=True, args_schema=SearchInput)\n",
"def search_api(query: str) -> str:\n",
" \"\"\"Searches the API for the query.\"\"\"\n",
" return \"Results\""
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "7914ba6b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class '__main__.SearchInput'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bcf0ee0>, coroutine=None)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"search_api"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "1d0430d6", "id": "1d0430d6",
@ -784,7 +837,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.2" "version": "3.9.1"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {

View File

@ -19,13 +19,12 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.agents import initialize_agent, Tool\n", "from typing import Any, Dict\n",
"from langchain.agents import AgentType\n", "\n",
"from langchain.agents import AgentType, initialize_agent\n",
"from langchain.llms import OpenAI\n", "from langchain.llms import OpenAI\n",
"from langchain.tools.python.tool import PythonREPLTool\n", "from langchain.tools.requests.tool import RequestsGetTool, TextRequestsWrapper\n",
"from pydantic import BaseModel\n", "from pydantic import BaseModel, Field, root_validator\n"
"from pydantic import Field, root_validator\n",
"from typing import Dict, Any"
] ]
}, },
{ {
@ -42,22 +41,20 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"metadata": { "metadata": {},
"tags": [] "outputs": [
}, {
"outputs": [], "name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [ "source": [
"class ToolInputModel(BaseModel):\n", "!pip install tldextract > /dev/null"
" query: str = Field(...)\n",
" \n",
" @root_validator\n",
" def validate_query(cls, values: Dict[str, Any]) -> Dict:\n",
" # Note: this is NOT a safe REPL! This is used for instructive purposes only\n",
" if \"import os\" in values[\"query\"]:\n",
" raise ValueError(\"'import os' not permitted in this python REPL.\")\n",
" return values\n",
" \n",
"tool = PythonREPLTool(args_schema=ToolInputModel)"
] ]
}, },
{ {
@ -68,7 +65,27 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"agent = initialize_agent([tool], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)" "import tldextract\n",
"\n",
"_APPROVED_DOMAINS = {\n",
" \"langchain\",\n",
" \"wikipedia\",\n",
"}\n",
"\n",
"class ToolInputSchema(BaseModel):\n",
"\n",
" url: str = Field(...)\n",
" \n",
" @root_validator\n",
" def validate_query(cls, values: Dict[str, Any]) -> Dict:\n",
" url = values[\"url\"]\n",
" domain = tldextract.extract(url).domain\n",
" if domain not in _APPROVED_DOMAINS:\n",
" raise ValueError(f\"Domain {domain} is not on the approved list:\"\n",
" f\" {sorted(_APPROVED_DOMAINS)}\")\n",
" return values\n",
" \n",
"tool = RequestsGetTool(args_schema=ToolInputSchema, requests_wrapper=TextRequestsWrapper())"
] ]
}, },
{ {
@ -77,40 +94,9 @@
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to define a function that adds two numbers\n",
"Action: Python REPL\n",
"Action Input: def add_two_numbers(a, b):\n",
" return a + b\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m I need to call the function\n",
"Action: Python REPL\n",
"Action Input: print(add_two_numbers(2, 2))\u001b[0m\u001b[36;1m\u001b[1;3m4\n",
"\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 4\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'4'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"# This will succeed, since there aren't any arguments that will be triggered during validation\n", "agent = initialize_agent([tool], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)"
"agent.run(\"Run a python function that adds 2 and 2\")"
] ]
}, },
{ {
@ -124,38 +110,46 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\n", "The main title of langchain.com is \"LANG CHAIN 🦜️🔗 Official Home Page\"\n"
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to import os and then list the dir\n",
"Action: Python REPL\n",
"Action Input: import os; print(os.listdir())\u001b[0m"
]
},
{
"ename": "ValidationError",
"evalue": "1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mRun a python function that imports os and lists the dir\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:213\u001b[0m, in \u001b[0;36mChain.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(args) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 212\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`run` supports only one positional argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m args:\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m(kwargs)[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:792\u001b[0m, in \u001b[0;36mAgentExecutor._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# We now enter the agent loop (until it returns something).\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_continue(iterations, time_elapsed):\n\u001b[0;32m--> 792\u001b[0m next_step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_take_next_step\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 793\u001b[0m \u001b[43m \u001b[49m\u001b[43mname_to_tool_map\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolor_mapping\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mintermediate_steps\u001b[49m\n\u001b[1;32m 794\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(next_step_output, AgentFinish):\n\u001b[1;32m 796\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_return(next_step_output, intermediate_steps)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:695\u001b[0m, in \u001b[0;36mAgentExecutor._take_next_step\u001b[0;34m(self, name_to_tool_map, color_mapping, inputs, intermediate_steps)\u001b[0m\n\u001b[1;32m 693\u001b[0m tool_run_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mllm_prefix\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 694\u001b[0m \u001b[38;5;66;03m# We then call the tool on the tool input to get an observation\u001b[39;00m\n\u001b[0;32m--> 695\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[43mtool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 696\u001b[0m \u001b[43m \u001b[49m\u001b[43magent_action\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtool_input\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 697\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 698\u001b[0m \u001b[43m \u001b[49m\u001b[43mcolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcolor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 699\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtool_run_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 700\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 701\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 702\u001b[0m tool_run_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39magent\u001b[38;5;241m.\u001b[39mtool_run_logging_kwargs()\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:146\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 139\u001b[0m tool_input: Union[\u001b[38;5;28mstr\u001b[39m, Dict],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[1;32m 144\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 145\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Run the tool.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 146\u001b[0m run_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_parse_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;129;01mand\u001b[39;00m verbose \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 148\u001b[0m verbose_ \u001b[38;5;241m=\u001b[39m verbose\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:112\u001b[0m, in \u001b[0;36mBaseTool._parse_input\u001b[0;34m(self, tool_input)\u001b[0m\n\u001b[1;32m 110\u001b[0m tool_input \u001b[38;5;241m=\u001b[39m {field_name: tool_input}\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pydantic_input_type \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 112\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpydantic_input_type\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs_schema required for tool \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in order to\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m accept input of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(tool_input)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 117\u001b[0m )\n",
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:526\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)"
] ]
} }
], ],
"source": [ "source": [
"# This will fail, because the attempt to import os will trigger a validation error\n", "# This will succeed, since there aren't any arguments that will be triggered during validation\n",
"agent.run(\"Run a python function that imports os and lists the dir\")" "answer = agent.run(\"What's the main title on langchain.com?\")\n",
"print(answer)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "ValidationError",
"evalue": "1 validation error for ToolInputSchema\n__root__\n Domain google is not on the approved list: ['langchain', 'wikipedia'] (type=value_error)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m agent\u001b[39m.\u001b[39;49mrun(\u001b[39m\"\u001b[39;49m\u001b[39mWhat\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39ms the main title on google.com?\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:213\u001b[0m, in \u001b[0;36mChain.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(args) \u001b[39m!=\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[1;32m 212\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m`run` supports only one positional argument.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 213\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m(args[\u001b[39m0\u001b[39;49m])[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_keys[\u001b[39m0\u001b[39m]]\n\u001b[1;32m 215\u001b[0m \u001b[39mif\u001b[39;00m kwargs \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m args:\n\u001b[1;32m 216\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m(kwargs)[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_keys[\u001b[39m0\u001b[39m]]\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[39mexcept\u001b[39;00m (\u001b[39mKeyboardInterrupt\u001b[39;00m, \u001b[39mException\u001b[39;00m) \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_manager\u001b[39m.\u001b[39mon_chain_error(e, verbose\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[39mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_manager\u001b[39m.\u001b[39mon_chain_end(outputs, verbose\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_manager\u001b[39m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[39m\"\u001b[39m\u001b[39mname\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call(inputs)\n\u001b[1;32m 114\u001b[0m \u001b[39mexcept\u001b[39;00m (\u001b[39mKeyboardInterrupt\u001b[39;00m, \u001b[39mException\u001b[39;00m) \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_manager\u001b[39m.\u001b[39mon_chain_error(e, verbose\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:792\u001b[0m, in \u001b[0;36mAgentExecutor._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[39m# We now enter the agent loop (until it returns something).\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_should_continue(iterations, time_elapsed):\n\u001b[0;32m--> 792\u001b[0m next_step_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_take_next_step(\n\u001b[1;32m 793\u001b[0m name_to_tool_map, color_mapping, inputs, intermediate_steps\n\u001b[1;32m 794\u001b[0m )\n\u001b[1;32m 795\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(next_step_output, AgentFinish):\n\u001b[1;32m 796\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_return(next_step_output, intermediate_steps)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:695\u001b[0m, in \u001b[0;36mAgentExecutor._take_next_step\u001b[0;34m(self, name_to_tool_map, color_mapping, inputs, intermediate_steps)\u001b[0m\n\u001b[1;32m 693\u001b[0m tool_run_kwargs[\u001b[39m\"\u001b[39m\u001b[39mllm_prefix\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 694\u001b[0m \u001b[39m# We then call the tool on the tool input to get an observation\u001b[39;00m\n\u001b[0;32m--> 695\u001b[0m observation \u001b[39m=\u001b[39m tool\u001b[39m.\u001b[39;49mrun(\n\u001b[1;32m 696\u001b[0m agent_action\u001b[39m.\u001b[39;49mtool_input,\n\u001b[1;32m 697\u001b[0m verbose\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mverbose,\n\u001b[1;32m 698\u001b[0m color\u001b[39m=\u001b[39;49mcolor,\n\u001b[1;32m 699\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mtool_run_kwargs,\n\u001b[1;32m 700\u001b[0m )\n\u001b[1;32m 701\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 702\u001b[0m tool_run_kwargs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39magent\u001b[39m.\u001b[39mtool_run_logging_kwargs()\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:110\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, **kwargs)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrun\u001b[39m(\n\u001b[1;32m 102\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 103\u001b[0m tool_input: Union[\u001b[39mstr\u001b[39m, Dict],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any,\n\u001b[1;32m 108\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n\u001b[1;32m 109\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Run the tool.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 110\u001b[0m run_input \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_parse_input(tool_input)\n\u001b[1;32m 111\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose \u001b[39mand\u001b[39;00m verbose \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 112\u001b[0m verbose_ \u001b[39m=\u001b[39m verbose\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:71\u001b[0m, in \u001b[0;36mBaseTool._parse_input\u001b[0;34m(self, tool_input)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39missubclass\u001b[39m(input_args, BaseModel):\n\u001b[1;32m 70\u001b[0m key_ \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39m(\u001b[39miter\u001b[39m(input_args\u001b[39m.\u001b[39m__fields__\u001b[39m.\u001b[39mkeys()))\n\u001b[0;32m---> 71\u001b[0m input_args\u001b[39m.\u001b[39;49mparse_obj({key_: tool_input})\n\u001b[1;32m 72\u001b[0m \u001b[39m# Passing as a positional argument is more straightforward for\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[39m# backwards compatability\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39mreturn\u001b[39;00m tool_input\n",
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:526\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for ToolInputSchema\n__root__\n Domain google is not on the approved list: ['langchain', 'wikipedia'] (type=value_error)"
]
}
],
"source": [
"agent.run(\"What's the main title on google.com?\")"
] ]
}, },
{ {

View File

@ -2,9 +2,9 @@
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union from typing import Any, Awaitable, Callable, Optional, Type, Union
from pydantic import BaseModel from pydantic import BaseModel, validate_arguments
from langchain.tools.base import BaseTool, create_args_schema_model_from_signature from langchain.tools.base import BaseTool
class Tool(BaseTool): class Tool(BaseTool):
@ -16,15 +16,6 @@ class Tool(BaseTool):
coroutine: Optional[Callable[..., Awaitable[str]]] = None coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function.""" """The asynchronous version of the function."""
@property
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
# Infer the schema directly from the function to add more structured
# arguments.
return create_args_schema_model_from_signature(self.func)
def _run(self, *args: Any, **kwargs: Any) -> str: def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool.""" """Use the tool."""
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
@ -60,9 +51,23 @@ class InvalidTool(BaseTool):
return f"{tool_name} is not a valid tool, try another one." return f"{tool_name} is not a valid tool, try another one."
def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable: def tool(
*args: Union[str, Callable],
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments. """Make tools out of functions, can be used with or without arguments.
Args:
*args: The arguments to the tool.
return_direct: Whether to return directly from the tool rather
than continuing the agent loop.
args_schema: optional argument schema for user to specify
infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function.
Requires: Requires:
- Function must be of type (str) -> str - Function must be of type (str) -> str
- Function must have a docstring - Function must have a docstring
@ -87,11 +92,13 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
# Description example: # Description example:
# search_api(query: str) - Searches the API for the query. # search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
args_schema = create_args_schema_model_from_signature(func) _args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = validate_arguments(func).model # type: ignore
tool_ = Tool( tool_ = Tool(
name=tool_name, name=tool_name,
func=func, func=func,
args_schema=args_schema, args_schema=_args_schema,
description=description, description=description,
return_direct=return_direct, return_direct=return_direct,
) )

View File

@ -1,75 +1,21 @@
"""Base implementation for tools or skills.""" """Base implementation for tools or skills."""
import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
from pydantic import BaseModel, Extra, Field, create_model, validator from pydantic import BaseModel, Extra, Field, validator
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
def create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]: def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
"""Create a pydantic model type from a function's signature.""" # For backwards compatability, if run_input is a string,
signature_ = inspect.signature(run_func) # pass as a positional argument.
field_definitions: Dict[str, Any] = {} if isinstance(run_input, str):
return (run_input,), {}
for name, param in signature_.parameters.items(): else:
if name == "self": return [], run_input
continue
default_value = (
param.default if param.default != inspect.Parameter.empty else None
)
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
# Handle functions with *args in the signature
if param.kind == inspect.Parameter.VAR_POSITIONAL:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_positional": True}),
)
# handle functions with **kwargs in the signature
elif param.kind == inspect.Parameter.VAR_KEYWORD:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_keyword": True}),
)
# Handle all other named parameters
else:
is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY
field_definitions[name] = (
annotation,
Field(
default=default_value, extra={"is_keyword_only": is_keyword_only}
),
)
return create_model("ArgsModel", **field_definitions) # type: ignore
def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]:
args = []
kwargs = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle *args in the function signature
if field.field_info.extra.get("extra", {}).get("is_var_positional"):
if isinstance(value, str):
# Base case for backwards compatability
args.append(value)
elif value is not None:
args.extend(value)
# Handle **kwargs in the function signature
elif field.field_info.extra.get("extra", {}).get("is_var_keyword"):
if value is not None:
kwargs.update(value)
elif field.field_info.extra.get("extra", {}).get("is_keyword_only"):
kwargs[name] = value
else:
args.append(value)
return tuple(args), kwargs
class BaseTool(ABC, BaseModel): class BaseTool(ABC, BaseModel):
@ -90,31 +36,28 @@ class BaseTool(ABC, BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property @property
def args(self) -> Type[BaseModel]: def args(self) -> Union[Type[BaseModel], Type[str]]:
"""Generate an input pydantic model.""" """Generate an input pydantic model."""
if self.args_schema is not None: return str if self.args_schema is None else self.args_schema
return self.args_schema
return create_args_schema_model_from_signature(self._run)
def _parse_input( def _parse_input(
self, self,
tool_input: Union[str, Dict], tool_input: Union[str, Dict],
) -> BaseModel: ) -> None:
"""Convert tool input to pydantic model.""" """Convert tool input to pydantic model."""
pydantic_input_type = self.args input_args = self.args
if isinstance(tool_input, str): if isinstance(tool_input, str):
# For backwards compatibility, a tool that only takes if issubclass(input_args, BaseModel):
# a single string input will be converted to a dict. key_ = next(iter(input_args.__fields__.keys()))
# to be validated. input_args.validate({key_: tool_input})
field_name = next(iter(pydantic_input_type.__fields__))
tool_input = {field_name: tool_input}
if pydantic_input_type is not None:
return pydantic_input_type.parse_obj(tool_input)
else: else:
raise ValueError( if issubclass(input_args, BaseModel):
f"args_schema required for tool {self.name} in order to" input_args.validate(tool_input)
f" accept input of type {type(tool_input)}" else:
) raise ValueError(
f"args_schema required for tool {self.name} in order to"
f" accept input of type {type(tool_input)}"
)
@validator("callback_manager", pre=True, always=True) @validator("callback_manager", pre=True, always=True)
def set_callback_manager( def set_callback_manager(
@ -143,20 +86,20 @@ class BaseTool(ABC, BaseModel):
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Run the tool.""" """Run the tool."""
run_input = self._parse_input(tool_input) self._parse_input(tool_input)
if not self.verbose and verbose is not None: if not self.verbose and verbose is not None:
verbose_ = verbose verbose_ = verbose
else: else:
verbose_ = self.verbose verbose_ = self.verbose
self.callback_manager.on_tool_start( self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
str(run_input), tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
) )
try: try:
args, kwargs = _to_args_and_kwargs(run_input) args, kwargs = _to_args_and_kwargs(tool_input)
observation = self._run(*args, **kwargs) observation = self._run(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_) self.callback_manager.on_tool_error(e, verbose=verbose_)
@ -175,7 +118,7 @@ class BaseTool(ABC, BaseModel):
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
run_input = self._parse_input(tool_input) self._parse_input(tool_input)
if not self.verbose and verbose is not None: if not self.verbose and verbose is not None:
verbose_ = verbose verbose_ = verbose
else: else:
@ -183,7 +126,7 @@ class BaseTool(ABC, BaseModel):
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_start( await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
str(run_input.dict()), tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
@ -191,14 +134,14 @@ class BaseTool(ABC, BaseModel):
else: else:
self.callback_manager.on_tool_start( self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
str(run_input.dict()), tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
) )
try: try:
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(run_input) args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs) observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async: if self.callback_manager.is_async:

View File

@ -29,6 +29,6 @@ class WriteFileTool(BaseTool):
except Exception as e: except Exception as e:
return "Error: " + str(e) return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str: async def _arun(self, file_path: str, text: str) -> str:
# TODO: Add aiofiles method # TODO: Add aiofiles method
raise NotImplementedError raise NotImplementedError

View File

@ -1,6 +1,6 @@
"""Test tool utils.""" """Test tool utils."""
from datetime import datetime from datetime import datetime
from typing import Any, Optional, Type, Union from typing import Optional, Type, Union
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
@ -125,21 +125,27 @@ def test_tool_with_kwargs() -> None:
@tool(return_direct=True) @tool(return_direct=True)
def search_api( def search_api(
arg_1: float, *args: Any, ping: Optional[str] = None, **kwargs: Any arg_1: float,
ping: str = "hi",
) -> str: ) -> str:
"""Search the API for the query.""" """Search the API for the query."""
return f"arg_1={arg_1}, foo={args}, ping={ping}, kwargs={kwargs}" return f"arg_1={arg_1}, ping={ping}"
assert isinstance(search_api, Tool) assert isinstance(search_api, Tool)
result = search_api.run( result = search_api.run(
tool_input={ tool_input={
"arg_1": 3.2, "arg_1": 3.2,
"args": "fam",
"kwargs": {"bar": "baz"},
"ping": "pong", "ping": "pong",
} }
) )
assert result == "arg_1=3.2, foo=('fam',), ping=pong, kwargs={'bar': 'baz'}" assert result == "arg_1=3.2, ping=pong"
result = search_api.run(
tool_input={
"arg_1": 3.2,
}
)
assert result == "arg_1=3.2, ping=hi"
def test_missing_docstring() -> None: def test_missing_docstring() -> None: