tools refactor (#2961)

Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>
pull/3060/head
Harrison Chase 1 year ago committed by GitHub
parent 7a8c935b90
commit db968284f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,28 +9,29 @@
"\n", "\n",
"When constructing your own agent, you will need to provide it with a list of Tools that it can use. Besides the actual function that is called, the Tool consists of several components:\n", "When constructing your own agent, you will need to provide it with a list of Tools that it can use. Besides the actual function that is called, the Tool consists of several components:\n",
"\n", "\n",
"- name (str), is required\n", "- name (str), is required and must be unique within a set of tools provided to an agent\n",
"- description (str), is optional\n", "- description (str), is optional but recommended, as it is used by an agent to determine tool use\n",
"- return_direct (bool), defaults to False\n", "- return_direct (bool), defaults to False\n",
"\n", "\n",
"The function that should be called when the tool is selected should take as input a single string and return a single string.\n", "The function that should be called when the tool is selected should return a single string.\n",
"\n", "\n",
"There are two ways to define a tool, we will cover both in the example below." "There are two ways to define a tool, we will cover both in the example below."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "1aaba18c", "id": "1aaba18c",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Import things that are needed generically\n", "# Import things that are needed generically\n",
"from langchain.agents import initialize_agent, Tool\n", "from langchain import LLMMathChain, SerpAPIWrapper\n",
"from langchain.agents import AgentType\n", "from langchain.agents import AgentType, Tool, initialize_agent, tool\n",
"from langchain.tools import BaseTool\n", "from langchain.chat_models import ChatOpenAI\n",
"from langchain.llms import OpenAI\n", "from langchain.tools import BaseTool"
"from langchain import LLMMathChain, SerpAPIWrapper"
] ]
}, },
{ {
@ -43,12 +44,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"id": "36ed392e", "id": "36ed392e",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"llm = OpenAI(temperature=0)" "llm = ChatOpenAI(temperature=0)"
] ]
}, },
{ {
@ -74,7 +77,9 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"id": "56ff7670", "id": "56ff7670",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Load the tool configs that are needed.\n", "# Load the tool configs that are needed.\n",
@ -98,7 +103,9 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,
"id": "5b93047d", "id": "5b93047d",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Construct the agent. We will use the default agent type here.\n", "# Construct the agent. We will use the default agent type here.\n",
@ -110,7 +117,9 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 5,
"id": "6f96a891", "id": "6f96a891",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -119,29 +128,24 @@
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n", "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
"Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Gigi Hadid's age\n",
"Action: Search\n", "Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", "Action Input: \"Gigi Hadid age\"\u001b[0m\u001b[36;1m\u001b[1;3m27 years\u001b[0m\u001b[32;1m\u001b[1;3mI need to calculate her age raised to the 0.43 power\n",
"Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate her age raised to the 0.43 power\n",
"Action: Calculator\n", "Action: Calculator\n",
"Action Input: 22^0.43\u001b[0m\n", "Action Input: 27^(0.43)\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"22^0.43\u001b[32;1m\u001b[1;3m\n", "27^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
"```python\n", "27**(0.43)\n",
"import math\n",
"print(math.pow(22, 0.43))\n",
"```\n", "```\n",
"...numexpr.evaluate(\"27**(0.43)\")...\n",
"\u001b[0m\n", "\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n", "Answer: \u001b[33;1m\u001b[1;3m4.125593352125936\u001b[0m\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n",
"\n", "\u001b[33;1m\u001b[1;3mAnswer: 4.125593352125936\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n", "Final Answer: 4.125593352125936\u001b[0m\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -149,7 +153,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"\"Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\"" "'4.125593352125936'"
] ]
}, },
"execution_count": 5, "execution_count": 5,
@ -171,9 +175,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 6,
"id": "c58a7c40", "id": "c58a7c40",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"class CustomSearchTool(BaseTool):\n", "class CustomSearchTool(BaseTool):\n",
@ -203,9 +209,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 7,
"id": "3318a46f", "id": "3318a46f",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"tools = [CustomSearchTool(), CustomCalculatorTool()]" "tools = [CustomSearchTool(), CustomCalculatorTool()]"
@ -213,9 +221,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 8,
"id": "ee2d0f3a", "id": "ee2d0f3a",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)" "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
@ -223,9 +233,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 9,
"id": "6a2cebbf", "id": "6a2cebbf",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -234,29 +246,24 @@
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n", "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
"Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI now know Leo DiCaprio's girlfriend's name and that he's currently linked to Gigi Hadid. I need to find out Camila Morrone's age.\n",
"Action: Search\n", "Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", "Action Input: \"Camila Morrone age\"\u001b[0m\u001b[36;1m\u001b[1;3m25 years\u001b[0m\u001b[32;1m\u001b[1;3mI have Camila Morrone's age. I need to calculate her age raised to the 0.43 power.\n",
"Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate her age raised to the 0.43 power\n",
"Action: Calculator\n", "Action: Calculator\n",
"Action Input: 22^0.43\u001b[0m\n", "Action Input: 25^(0.43)\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"22^0.43\u001b[32;1m\u001b[1;3m\n", "25^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
"```python\n", "25**(0.43)\n",
"import math\n",
"print(math.pow(22, 0.43))\n",
"```\n", "```\n",
"...numexpr.evaluate(\"25**(0.43)\")...\n",
"\u001b[0m\n", "\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n", "Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n",
"\n", "\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the answer to the original question.\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n", "Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -264,10 +271,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"\"Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\"" "\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\""
] ]
}, },
"execution_count": 11, "execution_count": 9,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -288,9 +295,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 10,
"id": "8f15307d", "id": "8f15307d",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.agents import tool\n", "from langchain.agents import tool\n",
@ -298,22 +307,24 @@
"@tool\n", "@tool\n",
"def search_api(query: str) -> str:\n", "def search_api(query: str) -> str:\n",
" \"\"\"Searches the API for the query.\"\"\"\n", " \"\"\"Searches the API for the query.\"\"\"\n",
" return \"Results\"" " return f\"Results for query {query}\""
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 11,
"id": "0a23b91b", "id": "0a23b91b",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1184e0cd0>, func=<function search_api at 0x1635f8700>, coroutine=None)" "Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.ArgsModel'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x124346f10>, func=<function search_api at 0x16ad6e020>, coroutine=None)"
] ]
}, },
"execution_count": 5, "execution_count": 11,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -332,9 +343,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 12,
"id": "28cdf04d", "id": "28cdf04d",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"@tool(\"search\", return_direct=True)\n", "@tool(\"search\", return_direct=True)\n",
@ -345,17 +358,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 13,
"id": "1085a4bd", "id": "1085a4bd",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1184e0cd0>, func=<function search_api at 0x1635f8670>, coroutine=None)" "Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.ArgsModel'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x124346f10>, func=<function search_api at 0x16ad6d3a0>, coroutine=None)"
] ]
}, },
"execution_count": 7, "execution_count": 13,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -376,7 +389,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 14,
"id": "79213f40", "id": "79213f40",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -386,7 +399,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 15,
"id": "e1067dcb", "id": "e1067dcb",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -396,7 +409,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 16,
"id": "6c66ffe8", "id": "6c66ffe8",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -406,7 +419,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 17,
"id": "f45b5bc3", "id": "f45b5bc3",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -416,7 +429,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 18,
"id": "565e2b9b", "id": "565e2b9b",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -427,21 +440,12 @@
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n", "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age.\n",
"Action: Google Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Camila Morrone's age\n",
"Action: Google Search\n", "Action: Google Search\n",
"Action Input: \"Camila Morrone age\"\u001b[0m\n", "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mNow I need to find out Camila Morrone's current age.\n",
"Observation: \u001b[36;1m\u001b[1;3m25 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 0.43 power\n",
"Action: Calculator\n", "Action: Calculator\n",
"Action Input: 25^0.43\u001b[0m\n", "Action Input: 25^0.43\u001b[0m\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\n", "Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Camila Morrone is Leo DiCaprio's girlfriend and her current age raised to the 0.43 power is 3.991298452658078.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -449,10 +453,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"\"Camila Morrone is Leo DiCaprio's girlfriend and her current age raised to the 0.43 power is 3.991298452658078.\"" "\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\""
] ]
}, },
"execution_count": 12, "execution_count": 18,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -478,7 +482,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 19,
"id": "3450512e", "id": "3450512e",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -507,7 +511,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 20,
"id": "4b9a7849", "id": "4b9a7849",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -520,9 +524,7 @@
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I should use a music search engine to find the answer\n", "\u001b[32;1m\u001b[1;3m I should use a music search engine to find the answer\n",
"Action: Music Search\n", "Action: Music Search\n",
"Action Input: most famous song of christmas\u001b[0m\n", "Action Input: most famous song of christmas\u001b[0m\u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Observation: \u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n", "Final Answer: 'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
@ -534,7 +536,7 @@
"\"'All I Want For Christmas Is You' by Mariah Carey.\"" "\"'All I Want For Christmas Is You' by Mariah Carey.\""
] ]
}, },
"execution_count": 14, "execution_count": 20,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -554,7 +556,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 21,
"id": "3bb6185f", "id": "3bb6185f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -572,7 +574,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 22,
"id": "113ddb84", "id": "113ddb84",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -583,9 +585,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 23,
"id": "582439a6", "id": "582439a6",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -596,9 +600,7 @@
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to calculate this\n", "\u001b[32;1m\u001b[1;3m I need to calculate this\n",
"Action: Calculator\n", "Action: Calculator\n",
"Action Input: 2**.12\u001b[0m\n", "Action Input: 2**.12\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 1.086734862526058\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2599210498948732\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -606,10 +608,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'Answer: 1.2599210498948732'" "'Answer: 1.086734862526058'"
] ]
}, },
"execution_count": 17, "execution_count": 23,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -618,10 +620,149 @@
"agent.run(\"whats 2**.12\")" "agent.run(\"whats 2**.12\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "8aa3c353-bd89-467c-9c27-b83a90cd4daa",
"metadata": {},
"source": [
"## Multi-argument tools\n",
"\n",
"Many functions expect structured inputs. These can also be supported using the Tool decorator or by directly subclassing `BaseTool`! We have to modify the LLM's OutputParser to map its string output to a dictionary to pass to the action, however."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 24,
"id": "537bc628", "id": "537bc628",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from typing import Optional, Union\n",
"\n",
"@tool\n",
"def custom_search(k: int, query: str, other_arg: Optional[str] = None):\n",
" \"\"\"The custom search function.\"\"\"\n",
" return f\"Here are the results for the custom search: k={k}, query={query}, other_arg={other_arg}\""
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "d5c992cf-776a-40cd-a6c4-e7cf65ea709e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import re\n",
"from langchain.schema import (\n",
" AgentAction,\n",
" AgentFinish,\n",
")\n",
"from langchain.agents import AgentOutputParser\n",
"\n",
"# We will add a custom parser to map the arguments to a dictionary\n",
"class CustomOutputParser(AgentOutputParser):\n",
" \n",
" def parse_tool_input(self, action_input: str) -> dict:\n",
" # Regex pattern to match arguments and their values\n",
" pattern = r\"(\\w+)\\s*=\\s*(None|\\\"[^\\\"]*\\\"|\\d+)\"\n",
" matches = re.findall(pattern, action_input)\n",
" \n",
" if not matches:\n",
" raise ValueError(f\"Could not parse action input: `{action_input}`\")\n",
"\n",
" # Create a dictionary with the parsed arguments and their values\n",
" parsed_input = {}\n",
" for arg, value in matches:\n",
" if value == \"None\":\n",
" parsed_value = None\n",
" elif value.isdigit():\n",
" parsed_value = int(value)\n",
" else:\n",
" parsed_value = value.strip('\"')\n",
" parsed_input[arg] = parsed_value\n",
"\n",
" return parsed_input\n",
" \n",
" def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
" # Check if agent should finish\n",
" if \"Final Answer:\" in llm_output:\n",
" return AgentFinish(\n",
" # Return values is generally always a dictionary with a single `output` key\n",
" # It is not recommended to try anything else at the moment :)\n",
" return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
" log=llm_output,\n",
" )\n",
" # Parse out the action and action input\n",
" regex = r\"Action\\s*\\d*\\s*:(.*?)\\nAction\\s*\\d*\\s*Input\\s*\\d*\\s*:[\\s]*(.*)\"\n",
" match = re.search(regex, llm_output, re.DOTALL)\n",
" if not match:\n",
" raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
" action = match.group(1).strip()\n",
" action_input = match.group(2)\n",
" tool_input = self.parse_tool_input(action_input)\n",
" # Return the action and action \n",
" return AgentAction(tool=action, tool_input=tool_input, log=llm_output)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "68269547-1482-4138-a6ea-58f00b4a9548",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"agent = initialize_agent([custom_search], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, agent_kwargs={\"output_parser\": CustomOutputParser()})"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "0947835a-691c-4f51-b8f4-6744e0e48ab1",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to use a search function to find the answer\n",
"Action: custom_search\n",
"Action Input: k=1, query=\"me\"\u001b[0m\u001b[36;1m\u001b[1;3mHere are the results for the custom search: k=1, query=me, other_arg=None\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The results of the custom search for k=1, query=me, other_arg=None.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The results of the custom search for k=1, query=me, other_arg=None.'"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"Search for me and tell me whatever it says\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caf39c66-102b-42c1-baf2-777a49886ce4",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []
@ -643,7 +784,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.1" "version": "3.11.2"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {

@ -0,0 +1,190 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Tool Input Schema\n",
"\n",
"By default, tools infer the argument schema by inspecting the function signature. For more strict requirements, custom input schema can be specified, along with custom validation logic."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
"from langchain.llms import OpenAI\n",
"from langchain.tools.python.tool import PythonREPLTool\n",
"from pydantic import BaseModel\n",
"from pydantic import Field, root_validator\n",
"from typing import Dict, Any"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class ToolInputModel(BaseModel):\n",
" query: str = Field(...)\n",
" \n",
" @root_validator\n",
" def validate_query(cls, values: Dict[str, Any]) -> Dict:\n",
" # Note: this is NOT a safe REPL! This is used for instructive purposes only\n",
" if \"import os\" in values[\"query\"]:\n",
" raise ValueError(\"'import os' not permitted in this python REPL.\")\n",
" return values\n",
" \n",
"tool = PythonREPLTool(args_schema=ToolInputModel)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"agent = initialize_agent([tool], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to define a function that adds two numbers\n",
"Action: Python REPL\n",
"Action Input: def add_two_numbers(a, b):\n",
" return a + b\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m I need to call the function\n",
"Action: Python REPL\n",
"Action Input: print(add_two_numbers(2, 2))\u001b[0m\u001b[36;1m\u001b[1;3m4\n",
"\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 4\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'4'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This will succeed, since there aren't any arguments that will be triggered during validation\n",
"agent.run(\"Run a python function that adds 2 and 2\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to import os and then list the dir\n",
"Action: Python REPL\n",
"Action Input: import os; print(os.listdir())\u001b[0m"
]
},
{
"ename": "ValidationError",
"evalue": "1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mRun a python function that imports os and lists the dir\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:213\u001b[0m, in \u001b[0;36mChain.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(args) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 212\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`run` supports only one positional argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m args:\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m(kwargs)[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:792\u001b[0m, in \u001b[0;36mAgentExecutor._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# We now enter the agent loop (until it returns something).\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_continue(iterations, time_elapsed):\n\u001b[0;32m--> 792\u001b[0m next_step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_take_next_step\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 793\u001b[0m \u001b[43m \u001b[49m\u001b[43mname_to_tool_map\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolor_mapping\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mintermediate_steps\u001b[49m\n\u001b[1;32m 794\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(next_step_output, AgentFinish):\n\u001b[1;32m 796\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_return(next_step_output, intermediate_steps)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:695\u001b[0m, in \u001b[0;36mAgentExecutor._take_next_step\u001b[0;34m(self, name_to_tool_map, color_mapping, inputs, intermediate_steps)\u001b[0m\n\u001b[1;32m 693\u001b[0m tool_run_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mllm_prefix\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 694\u001b[0m \u001b[38;5;66;03m# We then call the tool on the tool input to get an observation\u001b[39;00m\n\u001b[0;32m--> 695\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[43mtool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 696\u001b[0m \u001b[43m \u001b[49m\u001b[43magent_action\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtool_input\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 697\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 698\u001b[0m \u001b[43m \u001b[49m\u001b[43mcolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcolor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 699\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtool_run_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 700\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 701\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 702\u001b[0m tool_run_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39magent\u001b[38;5;241m.\u001b[39mtool_run_logging_kwargs()\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:146\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 139\u001b[0m tool_input: Union[\u001b[38;5;28mstr\u001b[39m, Dict],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[1;32m 144\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 145\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Run the tool.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 146\u001b[0m run_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_parse_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;129;01mand\u001b[39;00m verbose \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 148\u001b[0m verbose_ \u001b[38;5;241m=\u001b[39m verbose\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:112\u001b[0m, in \u001b[0;36mBaseTool._parse_input\u001b[0;34m(self, tool_input)\u001b[0m\n\u001b[1;32m 110\u001b[0m tool_input \u001b[38;5;241m=\u001b[39m {field_name: tool_input}\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pydantic_input_type \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 112\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpydantic_input_type\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs_schema required for tool \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in order to\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m accept input of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(tool_input)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 117\u001b[0m )\n",
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:526\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)"
]
}
],
"source": [
"# This will fail, because the attempt to import os will trigger a validation error\n",
"agent.run(\"Run a python function that imports os and lists the dir\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -1,25 +1,38 @@
"""Interface for tools.""" """Interface for tools."""
from inspect import signature from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Union from typing import Any, Awaitable, Callable, Optional, Type, Union
from langchain.tools.base import BaseTool from pydantic import BaseModel
from langchain.tools.base import BaseTool, create_args_schema_model_from_signature
class Tool(BaseTool): class Tool(BaseTool):
"""Tool that takes in function or coroutine directly.""" """Tool that takes in function or coroutine directly."""
description: str = "" description: str = ""
func: Callable[[str], str] func: Callable[..., str]
coroutine: Optional[Callable[[str], Awaitable[str]]] = None """The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[str]]] = None
def _run(self, tool_input: str) -> str: """The asynchronous version of the function."""
@property
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
# Infer the schema directly from the function to add more structured
# arguments.
return create_args_schema_model_from_signature(self.func)
def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool.""" """Use the tool."""
return self.func(tool_input) return self.func(*args, **kwargs)
async def _arun(self, tool_input: str) -> str: async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
if self.coroutine: if self.coroutine:
return await self.coroutine(tool_input) return await self.coroutine(*args, **kwargs)
raise NotImplementedError("Tool does not support async") raise NotImplementedError("Tool does not support async")
# TODO: this is for backwards compatibility, remove in future # TODO: this is for backwards compatibility, remove in future
@ -69,14 +82,16 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
""" """
def _make_with_name(tool_name: str) -> Callable: def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable[[str], str]) -> Tool: def _make_tool(func: Callable) -> Tool:
assert func.__doc__, "Function must have a docstring" assert func.__doc__, "Function must have a docstring"
# Description example: # Description example:
# search_api(query: str) - Searches the API for the query. # search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
args_schema = create_args_schema_model_from_signature(func)
tool_ = Tool( tool_ = Tool(
name=tool_name, name=tool_name,
func=func, func=func,
args_schema=args_schema,
description=description, description=description,
return_direct=return_direct, return_direct=return_direct,
) )

@ -371,7 +371,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.starts += 1 self.starts += 1
tool = action.tool tool = action.tool
tool_input = action.tool_input tool_input = str(action.tool_input)
log = action.log log = action.log
resp = self._init_resp() resp = self._init_resp()

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar, Union
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
@ -31,7 +31,7 @@ class AgentAction(NamedTuple):
"""Agent's action to take.""" """Agent's action to take."""
tool: str tool: str
tool_input: str tool_input: Union[str, dict]
log: str log: str

@ -1,19 +1,84 @@
"""Base implementation for tools or skills.""" """Base implementation for tools or skills."""
from abc import abstractmethod import inspect
from typing import Any, Optional from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from pydantic import BaseModel, Extra, Field, validator from pydantic import BaseModel, Extra, Field, create_model, validator
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
class BaseTool(BaseModel): def create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]:
"""Class responsible for defining a tool or skill for an LLM.""" """Create a pydantic model type from a function's signature."""
signature_ = inspect.signature(run_func)
field_definitions: Dict[str, Any] = {}
for name, param in signature_.parameters.items():
if name == "self":
continue
default_value = (
param.default if param.default != inspect.Parameter.empty else None
)
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
# Handle functions with *args in the signature
if param.kind == inspect.Parameter.VAR_POSITIONAL:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_positional": True}),
)
# handle functions with **kwargs in the signature
elif param.kind == inspect.Parameter.VAR_KEYWORD:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_keyword": True}),
)
# Handle all other named parameters
else:
is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY
field_definitions[name] = (
annotation,
Field(
default=default_value, extra={"is_keyword_only": is_keyword_only}
),
)
return create_model("ArgsModel", **field_definitions) # type: ignore
def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]:
args = []
kwargs = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle *args in the function signature
if field.field_info.extra.get("extra", {}).get("is_var_positional"):
if isinstance(value, str):
# Base case for backwards compatability
args.append(value)
elif value is not None:
args.extend(value)
# Handle **kwargs in the function signature
elif field.field_info.extra.get("extra", {}).get("is_var_keyword"):
if value is not None:
kwargs.update(value)
elif field.field_info.extra.get("extra", {}).get("is_keyword_only"):
kwargs[name] = value
else:
args.append(value)
return tuple(args), kwargs
class BaseTool(ABC, BaseModel):
"""Interface LangChain tools must implement."""
name: str name: str
description: str description: str
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False return_direct: bool = False
verbose: bool = False verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
@ -24,6 +89,33 @@ class BaseTool(BaseModel):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
return create_args_schema_model_from_signature(self._run)
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> BaseModel:
"""Convert tool input to pydantic model."""
pydantic_input_type = self.args
if isinstance(tool_input, str):
# For backwards compatibility, a tool that only takes
# a single string input will be converted to a dict.
# to be validated.
field_name = next(iter(pydantic_input_type.__fields__))
tool_input = {field_name: tool_input}
if pydantic_input_type is not None:
return pydantic_input_type.parse_obj(tool_input)
else:
raise ValueError(
f"args_schema required for tool {self.name} in order to"
f" accept input of type {type(tool_input)}"
)
@validator("callback_manager", pre=True, always=True) @validator("callback_manager", pre=True, always=True)
def set_callback_manager( def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager] cls, callback_manager: Optional[BaseCallbackManager]
@ -35,39 +127,37 @@ class BaseTool(BaseModel):
return callback_manager or get_callback_manager() return callback_manager or get_callback_manager()
@abstractmethod @abstractmethod
def _run(self, tool_input: str) -> str: def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool.""" """Use the tool."""
@abstractmethod @abstractmethod
async def _arun(self, tool_input: str) -> str: async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
def __call__(self, tool_input: str) -> str:
"""Make tools callable with str input."""
return self.run(tool_input)
def run( def run(
self, self,
tool_input: str, tool_input: Union[str, Dict],
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
**kwargs: Any **kwargs: Any,
) -> str: ) -> str:
"""Run the tool.""" """Run the tool."""
run_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None: if not self.verbose and verbose is not None:
verbose_ = verbose verbose_ = verbose
else: else:
verbose_ = self.verbose verbose_ = self.verbose
self.callback_manager.on_tool_start( self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input, str(run_input),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
) )
try: try:
observation = self._run(tool_input) args, kwargs = _to_args_and_kwargs(run_input)
observation = self._run(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_) self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e raise e
@ -78,13 +168,14 @@ class BaseTool(BaseModel):
async def arun( async def arun(
self, self,
tool_input: str, tool_input: Union[str, Dict],
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
**kwargs: Any **kwargs: Any,
) -> str: ) -> str:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
run_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None: if not self.verbose and verbose is not None:
verbose_ = verbose verbose_ = verbose
else: else:
@ -92,7 +183,7 @@ class BaseTool(BaseModel):
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_start( await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input, str(run_input.dict()),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
@ -100,14 +191,15 @@ class BaseTool(BaseModel):
else: else:
self.callback_manager.on_tool_start( self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {"name": self.name, "description": self.description},
tool_input, str(run_input.dict()),
verbose=verbose_, verbose=verbose_,
color=start_color, color=start_color,
**kwargs, **kwargs,
) )
try: try:
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
observation = await self._arun(tool_input) args, kwargs = _to_args_and_kwargs(run_input)
observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async: if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_) await self.callback_manager.on_tool_error(e, verbose=verbose_)
@ -123,3 +215,7 @@ class BaseTool(BaseModel):
observation, verbose=verbose_, color=color, name=self.name, **kwargs observation, verbose=verbose_, color=color, name=self.name, **kwargs
) )
return observation return observation
def __call__(self, tool_input: str) -> str:
"""Make tool callable."""
return self.run(tool_input)

@ -0,0 +1,29 @@
from typing import Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
class ReadFileInput(BaseModel):
"""Input for ReadFileTool."""
file_path: str = Field(..., description="name of file")
class ReadFileTool(BaseTool):
name: str = "read_file"
tool_args: Type[BaseModel] = ReadFileInput
description: str = "Read file from disk"
def _run(self, file_path: str) -> str:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return content
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
# TODO: Add aiofiles method
raise NotImplementedError

@ -0,0 +1,34 @@
import os
from typing import Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
class WriteFileInput(BaseModel):
"""Input for WriteFileTool."""
file_path: str = Field(..., description="name of file")
text: str = Field(..., description="text to write to file")
class WriteFileTool(BaseTool):
name: str = "write_file"
tool_args: Type[BaseModel] = WriteFileInput
description: str = "Write file to disk"
def _run(self, file_path: str, text: str) -> str:
try:
directory = os.path.dirname(file_path)
if not os.path.exists(directory) and directory:
os.makedirs(directory)
with open(file_path, "w", encoding="utf-8") as f:
f.write(text)
return "File written to successfully."
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
# TODO: Add aiofiles method
raise NotImplementedError

@ -16,7 +16,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
def get_action_and_input(text: str) -> Tuple[str, str]: def get_action_and_input(text: str) -> Tuple[str, str]:
output = MRKLOutputParser().parse(text) output = MRKLOutputParser().parse(text)
if isinstance(output, AgentAction): if isinstance(output, AgentAction):
return output.tool, output.tool_input return output.tool, str(output.tool_input)
else: else:
return "Final Answer", output.return_values["output"] return "Final Answer", output.return_values["output"]

@ -1,7 +1,12 @@
"""Test tool utils.""" """Test tool utils."""
from datetime import datetime
from typing import Any, Optional, Type, Union
import pytest import pytest
from pydantic import BaseModel
from langchain.agents.tools import Tool, tool from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool
def test_unnamed_decorator() -> None: def test_unnamed_decorator() -> None:
@ -18,6 +23,64 @@ def test_unnamed_decorator() -> None:
assert search_api("test") == "API result" assert search_api("test") == "API result"
class _MockSchema(BaseModel):
arg1: int
arg2: bool
arg3: Optional[dict] = None
class _MockStructuredTool(BaseTool):
name = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
raise NotImplementedError
def test_structured_args() -> None:
"""Test functionality with structured arguments."""
structured_api = _MockStructuredTool()
assert isinstance(structured_api, BaseTool)
assert structured_api.name == "structured_api"
expected_result = "1 True {'foo': 'bar'}"
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
assert structured_api.run(args) == expected_result
def test_structured_args_decorator() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool
def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}"
assert isinstance(structured_tool_input, Tool)
assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}"
assert structured_tool_input.run(args) == expected_result
def test_empty_args_decorator() -> None:
"""Test functionality with no args parsed as a decorator."""
@tool
def empty_tool_input() -> str:
"""Return a constant."""
return "the empty result"
assert isinstance(empty_tool_input, Tool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.run({}) == "the empty result"
def test_named_tool_decorator() -> None: def test_named_tool_decorator() -> None:
"""Test functionality when arguments are provided as input to decorator.""" """Test functionality when arguments are provided as input to decorator."""
@ -57,6 +120,28 @@ def test_unnamed_tool_decorator_return_direct() -> None:
assert search_api.return_direct assert search_api.return_direct
def test_tool_with_kwargs() -> None:
"""Test functionality when only return direct is provided."""
@tool(return_direct=True)
def search_api(
arg_1: float, *args: Any, ping: Optional[str] = None, **kwargs: Any
) -> str:
"""Search the API for the query."""
return f"arg_1={arg_1}, foo={args}, ping={ping}, kwargs={kwargs}"
assert isinstance(search_api, Tool)
result = search_api.run(
tool_input={
"arg_1": 3.2,
"args": "fam",
"kwargs": {"bar": "baz"},
"ping": "pong",
}
)
assert result == "arg_1=3.2, foo=('fam',), ping=pong, kwargs={'bar': 'baz'}"
def test_missing_docstring() -> None: def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing.""" """Test error is raised when docstring is missing."""
# expect to throw a value error if theres no docstring # expect to throw a value error if theres no docstring
@ -67,7 +152,7 @@ def test_missing_docstring() -> None:
return "API result" return "API result"
def test_create_tool_posistional_args() -> None: def test_create_tool_positional_args() -> None:
"""Test that positional arguments are allowed.""" """Test that positional arguments are allowed."""
test_tool = Tool("test_name", lambda x: x, "test_description") test_tool = Tool("test_name", lambda x: x, "test_description")
assert test_tool("foo") == "foo" assert test_tool("foo") == "foo"

@ -0,0 +1,100 @@
import asyncio
from typing import Any, Dict
import pytest
from langchain.requests import TextRequestsWrapper
from langchain.tools.requests.tool import (
RequestsDeleteTool,
RequestsGetTool,
RequestsPatchTool,
RequestsPostTool,
RequestsPutTool,
_parse_input,
)
class _MockTextRequestsWrapper(TextRequestsWrapper):
@staticmethod
def get(url: str, **kwargs: Any) -> str:
return "get_response"
@staticmethod
async def aget(url: str, **kwargs: Any) -> str:
return "aget_response"
@staticmethod
def post(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"post {str(data)}"
@staticmethod
async def apost(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"apost {str(data)}"
@staticmethod
def patch(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"patch {str(data)}"
@staticmethod
async def apatch(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"apatch {str(data)}"
@staticmethod
def put(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"put {str(data)}"
@staticmethod
async def aput(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"aput {str(data)}"
@staticmethod
def delete(url: str, **kwargs: Any) -> str:
return "delete_response"
@staticmethod
async def adelete(url: str, **kwargs: Any) -> str:
return "adelete_response"
@pytest.fixture
def mock_requests_wrapper() -> TextRequestsWrapper:
return _MockTextRequestsWrapper()
def test_parse_input() -> None:
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
expected_output = {"url": "https://example.com", "data": {"key": "value"}}
assert _parse_input(input_text) == expected_output
def test_requests_get_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsGetTool(requests_wrapper=mock_requests_wrapper)
assert tool.run("https://example.com") == "get_response"
assert asyncio.run(tool.arun("https://example.com")) == "aget_response"
def test_requests_post_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsPostTool(requests_wrapper=mock_requests_wrapper)
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
assert tool.run(input_text) == "post {'key': 'value'}"
assert asyncio.run(tool.arun(input_text)) == "apost {'key': 'value'}"
def test_requests_patch_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsPatchTool(requests_wrapper=mock_requests_wrapper)
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
assert tool.run(input_text) == "patch {'key': 'value'}"
assert asyncio.run(tool.arun(input_text)) == "apatch {'key': 'value'}"
def test_requests_put_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsPutTool(requests_wrapper=mock_requests_wrapper)
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
assert tool.run(input_text) == "put {'key': 'value'}"
assert asyncio.run(tool.arun(input_text)) == "aput {'key': 'value'}"
def test_requests_delete_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsDeleteTool(requests_wrapper=mock_requests_wrapper)
assert tool.run("https://example.com") == "delete_response"
assert asyncio.run(tool.arun("https://example.com")) == "adelete_response"
Loading…
Cancel
Save