tools refactor (#2961)

Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>
fix_agent_callbacks
Harrison Chase 1 year ago committed by GitHub
parent 7a8c935b90
commit db968284f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,28 +9,29 @@
"\n",
"When constructing your own agent, you will need to provide it with a list of Tools that it can use. Besides the actual function that is called, the Tool consists of several components:\n",
"\n",
"- name (str), is required\n",
"- description (str), is optional\n",
"- name (str), is required and must be unique within a set of tools provided to an agent\n",
"- description (str), is optional but recommended, as it is used by an agent to determine tool use\n",
"- return_direct (bool), defaults to False\n",
"\n",
"The function that should be called when the tool is selected should take as input a single string and return a single string.\n",
"The function that should be called when the tool is selected should return a single string.\n",
"\n",
"There are two ways to define a tool, we will cover both in the example below."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "1aaba18c",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Import things that are needed generically\n",
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
"from langchain.tools import BaseTool\n",
"from langchain.llms import OpenAI\n",
"from langchain import LLMMathChain, SerpAPIWrapper"
"from langchain import LLMMathChain, SerpAPIWrapper\n",
"from langchain.agents import AgentType, Tool, initialize_agent, tool\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.tools import BaseTool"
]
},
{
@ -43,12 +44,14 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "36ed392e",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
"llm = ChatOpenAI(temperature=0)"
]
},
{
@ -74,7 +77,9 @@
"cell_type": "code",
"execution_count": 3,
"id": "56ff7670",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Load the tool configs that are needed.\n",
@ -98,7 +103,9 @@
"cell_type": "code",
"execution_count": 4,
"id": "5b93047d",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Construct the agent. We will use the default agent type here.\n",
@ -110,7 +117,9 @@
"cell_type": "code",
"execution_count": 5,
"id": "6f96a891",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
@ -119,29 +128,24 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
"Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Gigi Hadid's age\n",
"Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate her age raised to the 0.43 power\n",
"Action Input: \"Gigi Hadid age\"\u001b[0m\u001b[36;1m\u001b[1;3m27 years\u001b[0m\u001b[32;1m\u001b[1;3mI need to calculate her age raised to the 0.43 power\n",
"Action: Calculator\n",
"Action Input: 22^0.43\u001b[0m\n",
"Action Input: 27^(0.43)\u001b[0m\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"22^0.43\u001b[32;1m\u001b[1;3m\n",
"```python\n",
"import math\n",
"print(math.pow(22, 0.43))\n",
"27^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
"27**(0.43)\n",
"```\n",
"...numexpr.evaluate(\"27**(0.43)\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m4.125593352125936\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\u001b[0m\n",
"\u001b[33;1m\u001b[1;3mAnswer: 4.125593352125936\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
"Final Answer: 4.125593352125936\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -149,7 +153,7 @@
{
"data": {
"text/plain": [
"\"Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\""
"'4.125593352125936'"
]
},
"execution_count": 5,
@ -171,9 +175,11 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"id": "c58a7c40",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class CustomSearchTool(BaseTool):\n",
@ -203,9 +209,11 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"id": "3318a46f",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tools = [CustomSearchTool(), CustomCalculatorTool()]"
@ -213,9 +221,11 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"id": "ee2d0f3a",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
@ -223,9 +233,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"id": "6a2cebbf",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
@ -234,29 +246,24 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
"Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI now know Leo DiCaprio's girlfriend's name and that he's currently linked to Gigi Hadid. I need to find out Camila Morrone's age.\n",
"Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate her age raised to the 0.43 power\n",
"Action Input: \"Camila Morrone age\"\u001b[0m\u001b[36;1m\u001b[1;3m25 years\u001b[0m\u001b[32;1m\u001b[1;3mI have Camila Morrone's age. I need to calculate her age raised to the 0.43 power.\n",
"Action: Calculator\n",
"Action Input: 22^0.43\u001b[0m\n",
"Action Input: 25^(0.43)\u001b[0m\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"22^0.43\u001b[32;1m\u001b[1;3m\n",
"```python\n",
"import math\n",
"print(math.pow(22, 0.43))\n",
"25^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
"25**(0.43)\n",
"```\n",
"...numexpr.evaluate(\"25**(0.43)\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\u001b[0m\n",
"\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the answer to the original question.\n",
"Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -264,10 +271,10 @@
{
"data": {
"text/plain": [
"\"Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\""
"\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\""
]
},
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -288,9 +295,11 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"id": "8f15307d",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.agents import tool\n",
@ -298,22 +307,24 @@
"@tool\n",
"def search_api(query: str) -> str:\n",
" \"\"\"Searches the API for the query.\"\"\"\n",
" return \"Results\""
" return f\"Results for query {query}\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"id": "0a23b91b",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1184e0cd0>, func=<function search_api at 0x1635f8700>, coroutine=None)"
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.ArgsModel'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x124346f10>, func=<function search_api at 0x16ad6e020>, coroutine=None)"
]
},
"execution_count": 5,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@ -332,9 +343,11 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"id": "28cdf04d",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"@tool(\"search\", return_direct=True)\n",
@ -345,17 +358,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 13,
"id": "1085a4bd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1184e0cd0>, func=<function search_api at 0x1635f8670>, coroutine=None)"
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.ArgsModel'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x124346f10>, func=<function search_api at 0x16ad6d3a0>, coroutine=None)"
]
},
"execution_count": 7,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@ -376,7 +389,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 14,
"id": "79213f40",
"metadata": {},
"outputs": [],
@ -386,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 15,
"id": "e1067dcb",
"metadata": {},
"outputs": [],
@ -396,7 +409,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 16,
"id": "6c66ffe8",
"metadata": {},
"outputs": [],
@ -406,7 +419,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 17,
"id": "f45b5bc3",
"metadata": {},
"outputs": [],
@ -416,7 +429,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 18,
"id": "565e2b9b",
"metadata": {},
"outputs": [
@ -427,21 +440,12 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n",
"Action: Google Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Camila Morrone's age\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age.\n",
"Action: Google Search\n",
"Action Input: \"Camila Morrone age\"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m25 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 0.43 power\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mNow I need to find out Camila Morrone's current age.\n",
"Action: Calculator\n",
"Action Input: 25^0.43\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Camila Morrone is Leo DiCaprio's girlfriend and her current age raised to the 0.43 power is 3.991298452658078.\u001b[0m\n",
"Action Input: 25^0.43\u001b[0m\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
"Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -449,10 +453,10 @@
{
"data": {
"text/plain": [
"\"Camila Morrone is Leo DiCaprio's girlfriend and her current age raised to the 0.43 power is 3.991298452658078.\""
"\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\""
]
},
"execution_count": 12,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
@ -478,7 +482,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 19,
"id": "3450512e",
"metadata": {},
"outputs": [],
@ -507,7 +511,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 20,
"id": "4b9a7849",
"metadata": {},
"outputs": [
@ -520,9 +524,7 @@
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I should use a music search engine to find the answer\n",
"Action: Music Search\n",
"Action Input: most famous song of christmas\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Action Input: most famous song of christmas\u001b[0m\u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
@ -534,7 +536,7 @@
"\"'All I Want For Christmas Is You' by Mariah Carey.\""
]
},
"execution_count": 14,
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
@ -554,7 +556,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 21,
"id": "3bb6185f",
"metadata": {},
"outputs": [],
@ -572,7 +574,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 22,
"id": "113ddb84",
"metadata": {},
"outputs": [],
@ -583,9 +585,11 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 23,
"id": "582439a6",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
@ -596,9 +600,7 @@
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to calculate this\n",
"Action: Calculator\n",
"Action Input: 2**.12\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2599210498948732\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"Action Input: 2**.12\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 1.086734862526058\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -606,10 +608,10 @@
{
"data": {
"text/plain": [
"'Answer: 1.2599210498948732'"
"'Answer: 1.086734862526058'"
]
},
"execution_count": 17,
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
@ -618,10 +620,149 @@
"agent.run(\"whats 2**.12\")"
]
},
{
"cell_type": "markdown",
"id": "8aa3c353-bd89-467c-9c27-b83a90cd4daa",
"metadata": {},
"source": [
"## Multi-argument tools\n",
"\n",
"Many functions expect structured inputs. These can also be supported using the Tool decorator or by directly subclassing `BaseTool`! We have to modify the LLM's OutputParser to map its string output to a dictionary to pass to the action, however."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 24,
"id": "537bc628",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from typing import Optional, Union\n",
"\n",
"@tool\n",
"def custom_search(k: int, query: str, other_arg: Optional[str] = None):\n",
" \"\"\"The custom search function.\"\"\"\n",
" return f\"Here are the results for the custom search: k={k}, query={query}, other_arg={other_arg}\""
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "d5c992cf-776a-40cd-a6c4-e7cf65ea709e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import re\n",
"from langchain.schema import (\n",
" AgentAction,\n",
" AgentFinish,\n",
")\n",
"from langchain.agents import AgentOutputParser\n",
"\n",
"# We will add a custom parser to map the arguments to a dictionary\n",
"class CustomOutputParser(AgentOutputParser):\n",
" \n",
" def parse_tool_input(self, action_input: str) -> dict:\n",
" # Regex pattern to match arguments and their values\n",
" pattern = r\"(\\w+)\\s*=\\s*(None|\\\"[^\\\"]*\\\"|\\d+)\"\n",
" matches = re.findall(pattern, action_input)\n",
" \n",
" if not matches:\n",
" raise ValueError(f\"Could not parse action input: `{action_input}`\")\n",
"\n",
" # Create a dictionary with the parsed arguments and their values\n",
" parsed_input = {}\n",
" for arg, value in matches:\n",
" if value == \"None\":\n",
" parsed_value = None\n",
" elif value.isdigit():\n",
" parsed_value = int(value)\n",
" else:\n",
" parsed_value = value.strip('\"')\n",
" parsed_input[arg] = parsed_value\n",
"\n",
" return parsed_input\n",
" \n",
" def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
" # Check if agent should finish\n",
" if \"Final Answer:\" in llm_output:\n",
" return AgentFinish(\n",
" # Return values is generally always a dictionary with a single `output` key\n",
" # It is not recommended to try anything else at the moment :)\n",
" return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
" log=llm_output,\n",
" )\n",
" # Parse out the action and action input\n",
" regex = r\"Action\\s*\\d*\\s*:(.*?)\\nAction\\s*\\d*\\s*Input\\s*\\d*\\s*:[\\s]*(.*)\"\n",
" match = re.search(regex, llm_output, re.DOTALL)\n",
" if not match:\n",
" raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
" action = match.group(1).strip()\n",
" action_input = match.group(2)\n",
" tool_input = self.parse_tool_input(action_input)\n",
" # Return the action and action \n",
" return AgentAction(tool=action, tool_input=tool_input, log=llm_output)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "68269547-1482-4138-a6ea-58f00b4a9548",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"agent = initialize_agent([custom_search], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, agent_kwargs={\"output_parser\": CustomOutputParser()})"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "0947835a-691c-4f51-b8f4-6744e0e48ab1",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to use a search function to find the answer\n",
"Action: custom_search\n",
"Action Input: k=1, query=\"me\"\u001b[0m\u001b[36;1m\u001b[1;3mHere are the results for the custom search: k=1, query=me, other_arg=None\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The results of the custom search for k=1, query=me, other_arg=None.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The results of the custom search for k=1, query=me, other_arg=None.'"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"Search for me and tell me whatever it says\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caf39c66-102b-42c1-baf2-777a49886ce4",
"metadata": {},
"outputs": [],
"source": []
@ -643,7 +784,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.11.2"
},
"vscode": {
"interpreter": {

@ -0,0 +1,190 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Tool Input Schema\n",
"\n",
"By default, tools infer the argument schema by inspecting the function signature. For more strict requirements, custom input schema can be specified, along with custom validation logic."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
"from langchain.llms import OpenAI\n",
"from langchain.tools.python.tool import PythonREPLTool\n",
"from pydantic import BaseModel\n",
"from pydantic import Field, root_validator\n",
"from typing import Dict, Any"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class ToolInputModel(BaseModel):\n",
" query: str = Field(...)\n",
" \n",
" @root_validator\n",
" def validate_query(cls, values: Dict[str, Any]) -> Dict:\n",
" # Note: this is NOT a safe REPL! This is used for instructive purposes only\n",
" if \"import os\" in values[\"query\"]:\n",
" raise ValueError(\"'import os' not permitted in this python REPL.\")\n",
" return values\n",
" \n",
"tool = PythonREPLTool(args_schema=ToolInputModel)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"agent = initialize_agent([tool], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to define a function that adds two numbers\n",
"Action: Python REPL\n",
"Action Input: def add_two_numbers(a, b):\n",
" return a + b\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m I need to call the function\n",
"Action: Python REPL\n",
"Action Input: print(add_two_numbers(2, 2))\u001b[0m\u001b[36;1m\u001b[1;3m4\n",
"\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 4\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'4'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This will succeed, since there aren't any arguments that will be triggered during validation\n",
"agent.run(\"Run a python function that adds 2 and 2\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to import os and then list the dir\n",
"Action: Python REPL\n",
"Action Input: import os; print(os.listdir())\u001b[0m"
]
},
{
"ename": "ValidationError",
"evalue": "1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mRun a python function that imports os and lists the dir\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:213\u001b[0m, in \u001b[0;36mChain.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(args) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 212\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`run` supports only one positional argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m args:\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m(kwargs)[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_keys[\u001b[38;5;241m0\u001b[39m]]\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:792\u001b[0m, in \u001b[0;36mAgentExecutor._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# We now enter the agent loop (until it returns something).\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_continue(iterations, time_elapsed):\n\u001b[0;32m--> 792\u001b[0m next_step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_take_next_step\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 793\u001b[0m \u001b[43m \u001b[49m\u001b[43mname_to_tool_map\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolor_mapping\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mintermediate_steps\u001b[49m\n\u001b[1;32m 794\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(next_step_output, AgentFinish):\n\u001b[1;32m 796\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_return(next_step_output, intermediate_steps)\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/agents/agent.py:695\u001b[0m, in \u001b[0;36mAgentExecutor._take_next_step\u001b[0;34m(self, name_to_tool_map, color_mapping, inputs, intermediate_steps)\u001b[0m\n\u001b[1;32m 693\u001b[0m tool_run_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mllm_prefix\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 694\u001b[0m \u001b[38;5;66;03m# We then call the tool on the tool input to get an observation\u001b[39;00m\n\u001b[0;32m--> 695\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[43mtool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 696\u001b[0m \u001b[43m \u001b[49m\u001b[43magent_action\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtool_input\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 697\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 698\u001b[0m \u001b[43m \u001b[49m\u001b[43mcolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcolor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 699\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtool_run_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 700\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 701\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 702\u001b[0m tool_run_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39magent\u001b[38;5;241m.\u001b[39mtool_run_logging_kwargs()\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:146\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 139\u001b[0m tool_input: Union[\u001b[38;5;28mstr\u001b[39m, Dict],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[1;32m 144\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 145\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Run the tool.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 146\u001b[0m run_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_parse_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;129;01mand\u001b[39;00m verbose \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 148\u001b[0m verbose_ \u001b[38;5;241m=\u001b[39m verbose\n",
"File \u001b[0;32m~/code/lc/lckg/langchain/tools/base.py:112\u001b[0m, in \u001b[0;36mBaseTool._parse_input\u001b[0;34m(self, tool_input)\u001b[0m\n\u001b[1;32m 110\u001b[0m tool_input \u001b[38;5;241m=\u001b[39m {field_name: tool_input}\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pydantic_input_type \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 112\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpydantic_input_type\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtool_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs_schema required for tool \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in order to\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m accept input of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(tool_input)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 117\u001b[0m )\n",
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:526\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m~/code/lc/lckg/.venv/lib/python3.11/site-packages/pydantic/main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for ToolInputModel\n__root__\n 'import os' not (type=value_error)"
]
}
],
"source": [
"# This will fail, because the attempt to import os will trigger a validation error\n",
"agent.run(\"Run a python function that imports os and lists the dir\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -1,25 +1,38 @@
"""Interface for tools."""
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Union
from typing import Any, Awaitable, Callable, Optional, Type, Union
from langchain.tools.base import BaseTool
from pydantic import BaseModel
from langchain.tools.base import BaseTool, create_args_schema_model_from_signature
class Tool(BaseTool):
"""Tool that takes in function or coroutine directly."""
description: str = ""
func: Callable[[str], str]
coroutine: Optional[Callable[[str], Awaitable[str]]] = None
def _run(self, tool_input: str) -> str:
func: Callable[..., str]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function."""
@property
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
# Infer the schema directly from the function to add more structured
# arguments.
return create_args_schema_model_from_signature(self.func)
def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool."""
return self.func(tool_input)
return self.func(*args, **kwargs)
async def _arun(self, tool_input: str) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(tool_input)
return await self.coroutine(*args, **kwargs)
raise NotImplementedError("Tool does not support async")
# TODO: this is for backwards compatibility, remove in future
@ -69,14 +82,16 @@ def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable[[str], str]) -> Tool:
def _make_tool(func: Callable) -> Tool:
assert func.__doc__, "Function must have a docstring"
# Description example:
# search_api(query: str) - Searches the API for the query.
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
args_schema = create_args_schema_model_from_signature(func)
tool_ = Tool(
name=tool_name,
func=func,
args_schema=args_schema,
description=description,
return_direct=return_direct,
)

@ -371,7 +371,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.starts += 1
tool = action.tool
tool_input = action.tool_input
tool_input = str(action.tool_input)
log = action.log
resp = self._init_resp()

@ -2,7 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar
from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar, Union
from pydantic import BaseModel, Extra, Field, root_validator
@ -31,7 +31,7 @@ class AgentAction(NamedTuple):
"""Agent's action to take."""
tool: str
tool_input: str
tool_input: Union[str, dict]
log: str

@ -1,19 +1,84 @@
"""Base implementation for tools or skills."""
from abc import abstractmethod
from typing import Any, Optional
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from pydantic import BaseModel, Extra, Field, validator
from pydantic import BaseModel, Extra, Field, create_model, validator
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
class BaseTool(BaseModel):
"""Class responsible for defining a tool or skill for an LLM."""
def create_args_schema_model_from_signature(run_func: Callable) -> Type[BaseModel]:
"""Create a pydantic model type from a function's signature."""
signature_ = inspect.signature(run_func)
field_definitions: Dict[str, Any] = {}
for name, param in signature_.parameters.items():
if name == "self":
continue
default_value = (
param.default if param.default != inspect.Parameter.empty else None
)
annotation = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
# Handle functions with *args in the signature
if param.kind == inspect.Parameter.VAR_POSITIONAL:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_positional": True}),
)
# handle functions with **kwargs in the signature
elif param.kind == inspect.Parameter.VAR_KEYWORD:
field_definitions[name] = (
Any,
Field(default=None, extra={"is_var_keyword": True}),
)
# Handle all other named parameters
else:
is_keyword_only = param.kind == inspect.Parameter.KEYWORD_ONLY
field_definitions[name] = (
annotation,
Field(
default=default_value, extra={"is_keyword_only": is_keyword_only}
),
)
return create_model("ArgsModel", **field_definitions) # type: ignore
def _to_args_and_kwargs(model: BaseModel) -> Tuple[Sequence, dict]:
args = []
kwargs = {}
for name, field in model.__fields__.items():
value = getattr(model, name)
# Handle *args in the function signature
if field.field_info.extra.get("extra", {}).get("is_var_positional"):
if isinstance(value, str):
# Base case for backwards compatability
args.append(value)
elif value is not None:
args.extend(value)
# Handle **kwargs in the function signature
elif field.field_info.extra.get("extra", {}).get("is_var_keyword"):
if value is not None:
kwargs.update(value)
elif field.field_info.extra.get("extra", {}).get("is_keyword_only"):
kwargs[name] = value
else:
args.append(value)
return tuple(args), kwargs
class BaseTool(ABC, BaseModel):
"""Interface LangChain tools must implement."""
name: str
description: str
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
@ -24,6 +89,33 @@ class BaseTool(BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def args(self) -> Type[BaseModel]:
"""Generate an input pydantic model."""
if self.args_schema is not None:
return self.args_schema
return create_args_schema_model_from_signature(self._run)
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> BaseModel:
"""Convert tool input to pydantic model."""
pydantic_input_type = self.args
if isinstance(tool_input, str):
# For backwards compatibility, a tool that only takes
# a single string input will be converted to a dict.
# to be validated.
field_name = next(iter(pydantic_input_type.__fields__))
tool_input = {field_name: tool_input}
if pydantic_input_type is not None:
return pydantic_input_type.parse_obj(tool_input)
else:
raise ValueError(
f"args_schema required for tool {self.name} in order to"
f" accept input of type {type(tool_input)}"
)
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
@ -35,39 +127,37 @@ class BaseTool(BaseModel):
return callback_manager or get_callback_manager()
@abstractmethod
def _run(self, tool_input: str) -> str:
def _run(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool."""
@abstractmethod
async def _arun(self, tool_input: str) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool asynchronously."""
def __call__(self, tool_input: str) -> str:
"""Make tools callable with str input."""
return self.run(tool_input)
def run(
self,
tool_input: str,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any
**kwargs: Any,
) -> str:
"""Run the tool."""
run_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
str(run_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
observation = self._run(tool_input)
args, kwargs = _to_args_and_kwargs(run_input)
observation = self._run(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
@ -78,13 +168,14 @@ class BaseTool(BaseModel):
async def arun(
self,
tool_input: str,
tool_input: Union[str, Dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any
**kwargs: Any,
) -> str:
"""Run the tool asynchronously."""
run_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
@ -92,7 +183,7 @@ class BaseTool(BaseModel):
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
str(run_input.dict()),
verbose=verbose_,
color=start_color,
**kwargs,
@ -100,14 +191,15 @@ class BaseTool(BaseModel):
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input,
str(run_input.dict()),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
observation = await self._arun(tool_input)
args, kwargs = _to_args_and_kwargs(run_input)
observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
@ -123,3 +215,7 @@ class BaseTool(BaseModel):
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
def __call__(self, tool_input: str) -> str:
"""Make tool callable."""
return self.run(tool_input)

@ -0,0 +1,29 @@
from typing import Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
class ReadFileInput(BaseModel):
"""Input for ReadFileTool."""
file_path: str = Field(..., description="name of file")
class ReadFileTool(BaseTool):
name: str = "read_file"
tool_args: Type[BaseModel] = ReadFileInput
description: str = "Read file from disk"
def _run(self, file_path: str) -> str:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return content
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
# TODO: Add aiofiles method
raise NotImplementedError

@ -0,0 +1,34 @@
import os
from typing import Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
class WriteFileInput(BaseModel):
"""Input for WriteFileTool."""
file_path: str = Field(..., description="name of file")
text: str = Field(..., description="text to write to file")
class WriteFileTool(BaseTool):
name: str = "write_file"
tool_args: Type[BaseModel] = WriteFileInput
description: str = "Write file to disk"
def _run(self, file_path: str, text: str) -> str:
try:
directory = os.path.dirname(file_path)
if not os.path.exists(directory) and directory:
os.makedirs(directory)
with open(file_path, "w", encoding="utf-8") as f:
f.write(text)
return "File written to successfully."
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
# TODO: Add aiofiles method
raise NotImplementedError

@ -16,7 +16,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
def get_action_and_input(text: str) -> Tuple[str, str]:
output = MRKLOutputParser().parse(text)
if isinstance(output, AgentAction):
return output.tool, output.tool_input
return output.tool, str(output.tool_input)
else:
return "Final Answer", output.return_values["output"]

@ -1,7 +1,12 @@
"""Test tool utils."""
from datetime import datetime
from typing import Any, Optional, Type, Union
import pytest
from pydantic import BaseModel
from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool
def test_unnamed_decorator() -> None:
@ -18,6 +23,64 @@ def test_unnamed_decorator() -> None:
assert search_api("test") == "API result"
class _MockSchema(BaseModel):
arg1: int
arg2: bool
arg3: Optional[dict] = None
class _MockStructuredTool(BaseTool):
name = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
raise NotImplementedError
def test_structured_args() -> None:
"""Test functionality with structured arguments."""
structured_api = _MockStructuredTool()
assert isinstance(structured_api, BaseTool)
assert structured_api.name == "structured_api"
expected_result = "1 True {'foo': 'bar'}"
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
assert structured_api.run(args) == expected_result
def test_structured_args_decorator() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool
def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}"
assert isinstance(structured_tool_input, Tool)
assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}"
assert structured_tool_input.run(args) == expected_result
def test_empty_args_decorator() -> None:
"""Test functionality with no args parsed as a decorator."""
@tool
def empty_tool_input() -> str:
"""Return a constant."""
return "the empty result"
assert isinstance(empty_tool_input, Tool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.run({}) == "the empty result"
def test_named_tool_decorator() -> None:
"""Test functionality when arguments are provided as input to decorator."""
@ -57,6 +120,28 @@ def test_unnamed_tool_decorator_return_direct() -> None:
assert search_api.return_direct
def test_tool_with_kwargs() -> None:
"""Test functionality when only return direct is provided."""
@tool(return_direct=True)
def search_api(
arg_1: float, *args: Any, ping: Optional[str] = None, **kwargs: Any
) -> str:
"""Search the API for the query."""
return f"arg_1={arg_1}, foo={args}, ping={ping}, kwargs={kwargs}"
assert isinstance(search_api, Tool)
result = search_api.run(
tool_input={
"arg_1": 3.2,
"args": "fam",
"kwargs": {"bar": "baz"},
"ping": "pong",
}
)
assert result == "arg_1=3.2, foo=('fam',), ping=pong, kwargs={'bar': 'baz'}"
def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing."""
# expect to throw a value error if theres no docstring
@ -67,7 +152,7 @@ def test_missing_docstring() -> None:
return "API result"
def test_create_tool_posistional_args() -> None:
def test_create_tool_positional_args() -> None:
"""Test that positional arguments are allowed."""
test_tool = Tool("test_name", lambda x: x, "test_description")
assert test_tool("foo") == "foo"

@ -0,0 +1,100 @@
import asyncio
from typing import Any, Dict
import pytest
from langchain.requests import TextRequestsWrapper
from langchain.tools.requests.tool import (
RequestsDeleteTool,
RequestsGetTool,
RequestsPatchTool,
RequestsPostTool,
RequestsPutTool,
_parse_input,
)
class _MockTextRequestsWrapper(TextRequestsWrapper):
@staticmethod
def get(url: str, **kwargs: Any) -> str:
return "get_response"
@staticmethod
async def aget(url: str, **kwargs: Any) -> str:
return "aget_response"
@staticmethod
def post(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"post {str(data)}"
@staticmethod
async def apost(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"apost {str(data)}"
@staticmethod
def patch(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"patch {str(data)}"
@staticmethod
async def apatch(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"apatch {str(data)}"
@staticmethod
def put(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"put {str(data)}"
@staticmethod
async def aput(url: str, data: Dict[str, Any], **kwargs: Any) -> str:
return f"aput {str(data)}"
@staticmethod
def delete(url: str, **kwargs: Any) -> str:
return "delete_response"
@staticmethod
async def adelete(url: str, **kwargs: Any) -> str:
return "adelete_response"
@pytest.fixture
def mock_requests_wrapper() -> TextRequestsWrapper:
return _MockTextRequestsWrapper()
def test_parse_input() -> None:
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
expected_output = {"url": "https://example.com", "data": {"key": "value"}}
assert _parse_input(input_text) == expected_output
def test_requests_get_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsGetTool(requests_wrapper=mock_requests_wrapper)
assert tool.run("https://example.com") == "get_response"
assert asyncio.run(tool.arun("https://example.com")) == "aget_response"
def test_requests_post_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsPostTool(requests_wrapper=mock_requests_wrapper)
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
assert tool.run(input_text) == "post {'key': 'value'}"
assert asyncio.run(tool.arun(input_text)) == "apost {'key': 'value'}"
def test_requests_patch_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsPatchTool(requests_wrapper=mock_requests_wrapper)
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
assert tool.run(input_text) == "patch {'key': 'value'}"
assert asyncio.run(tool.arun(input_text)) == "apatch {'key': 'value'}"
def test_requests_put_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsPutTool(requests_wrapper=mock_requests_wrapper)
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
assert tool.run(input_text) == "put {'key': 'value'}"
assert asyncio.run(tool.arun(input_text)) == "aput {'key': 'value'}"
def test_requests_delete_tool(mock_requests_wrapper: TextRequestsWrapper) -> None:
tool = RequestsDeleteTool(requests_wrapper=mock_requests_wrapper)
assert tool.run("https://example.com") == "delete_response"
assert asyncio.run(tool.arun("https://example.com")) == "adelete_response"
Loading…
Cancel
Save