forked from Archives/langchain
8a050ba4bf
The required arg is `question` not `query`
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "5436020b",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Defining Custom Tools\n",
|
||
"\n",
|
||
"When constructing your own agent, you will need to provide it with a list of Tools that it can use. Besides the actual function that is called, the Tool consists of several components:\n",
|
||
"\n",
|
"- name (str): required; must be unique within the set of tools provided to an agent\n",
"- description (str): optional but recommended, since the agent uses it to decide when to use the tool\n",
"- return_direct (bool): defaults to False; when True, the tool's output is returned directly to the user\n",
"- args_schema (Pydantic BaseModel): optional but recommended; can be used to provide more information about, or validation for, the expected parameters\n",
|
||
"\n",
|
"The function that is called when the tool is selected should return a single string.\n",
"\n",
"There are two ways to define a tool; we cover both in the examples below."
|
||
]
|
||
},
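{
"cell_type": "markdown",
"id": "sketch-tool-components",
"metadata": {},
"source": [
"As a rough illustration of how the components listed above fit together, here is a minimal plain-Python sketch. `SimpleTool` and `check_unique_names` are hypothetical names used only for this illustration; they are assumptions, not part of LangChain.\n",
"\n",
"```python\n",
"from dataclasses import dataclass\n",
"from typing import Callable, List\n",
"\n",
"\n",
"@dataclass\n",
"class SimpleTool:\n",
"    # Hypothetical stand-in for a tool; not LangChain's Tool class.\n",
"    name: str\n",
"    func: Callable[[str], str]\n",
"    description: str = ''\n",
"    return_direct: bool = False\n",
"\n",
"\n",
"def check_unique_names(tools: List[SimpleTool]) -> None:\n",
"    # Tool names must be unique within the set given to an agent.\n",
"    names = [t.name for t in tools]\n",
"    duplicates = sorted({n for n in names if names.count(n) > 1})\n",
"    if duplicates:\n",
"        raise ValueError(f'Duplicate tool names: {duplicates}')\n",
"\n",
"\n",
"echo = SimpleTool(name='Echo', func=lambda text: text, description='repeats its input')\n",
"check_unique_names([echo])\n",
"print(echo.func('hello'))\n",
"```\n",
"\n",
"The uniqueness check mirrors the requirement on `name`: the agent refers to a tool by its name, so two tools sharing a name would be ambiguous."
]
},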
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "1aaba18c",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Import things that are needed generically\n",
|
||
"from langchain import LLMMathChain, SerpAPIWrapper\n",
|
||
"from langchain.agents import AgentType, Tool, initialize_agent, tool\n",
|
||
"from langchain.chat_models import ChatOpenAI\n",
|
||
"from langchain.tools import BaseTool"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "8e2c3874",
|
||
"metadata": {},
|
||
"source": [
|
||
"Initialize the LLM to use for the agent."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "36ed392e",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"llm = ChatOpenAI(temperature=0)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "f8bc72c2",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Completely New Tools \n",
|
||
"First, we show how to create completely new tools from scratch.\n",
|
||
"\n",
|
||
"There are two ways to do this: either by using the Tool dataclass, or by subclassing the BaseTool class."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b63fcc3b",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Tool dataclass"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "56ff7670",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
"# Load the tool configs that are needed.\n",
"search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
"tools = [\n",
"    Tool(\n",
"        name=\"Search\",\n",
"        func=search.run,\n",
"        description=\"useful for when you need to answer questions about current events\",\n",
"    ),\n",
"]\n",
"\n",
"# You can also define an args_schema to provide more information about inputs\n",
"from pydantic import BaseModel, Field\n",
"\n",
"\n",
"class CalculatorInput(BaseModel):\n",
"    question: str = Field()\n",
"\n",
"\n",
"tools.append(\n",
"    Tool(\n",
"        name=\"Calculator\",\n",
"        func=llm_math_chain.run,\n",
"        description=\"useful for when you need to answer questions about math\",\n",
"        args_schema=CalculatorInput,\n",
"    )\n",
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "5b93047d",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Construct the agent. We will use the default agent type here.\n",
|
||
"# See documentation for a full list of options.\n",
|
||
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "6f96a891",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
|
||
"Action: Search\n",
|
||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n",
|
||
"Action: Calculator\n",
|
||
"Action Input: 25^(0.43)\u001b[0m\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
|
||
"25^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
|
||
"25**(0.43)\n",
|
||
"```\n",
|
||
"...numexpr.evaluate(\"25**(0.43)\")...\n",
|
||
"\u001b[0m\n",
|
||
"Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||
"\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
|
||
"Final Answer: 3.991298452658078\u001b[0m\n",
|
||
"\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'3.991298452658078'"
|
||
]
|
||
},
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"agent.run(\"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "6f12eaf0",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Subclassing the BaseTool class"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"id": "c58a7c40",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
"class CustomSearchTool(BaseTool):\n",
"    name = \"Search\"\n",
"    description = \"useful for when you need to answer questions about current events\"\n",
"\n",
"    def _run(self, query: str) -> str:\n",
"        \"\"\"Use the tool.\"\"\"\n",
"        return search.run(query)\n",
"\n",
"    async def _arun(self, query: str) -> str:\n",
"        \"\"\"Use the tool asynchronously.\"\"\"\n",
"        raise NotImplementedError(\"CustomSearchTool does not support async\")\n",
"\n",
"\n",
"class CustomCalculatorTool(BaseTool):\n",
"    name = \"Calculator\"\n",
"    description = \"useful for when you need to answer questions about math\"\n",
"    args_schema = CalculatorInput\n",
"\n",
"    def _run(self, query: str) -> str:\n",
"        \"\"\"Use the tool.\"\"\"\n",
"        return llm_math_chain.run(query)\n",
"\n",
"    async def _arun(self, query: str) -> str:\n",
"        \"\"\"Use the tool asynchronously.\"\"\"\n",
"        raise NotImplementedError(\"CustomCalculatorTool does not support async\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"id": "3318a46f",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"tools = [CustomSearchTool(), CustomCalculatorTool()]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"id": "ee2d0f3a",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"id": "6a2cebbf",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
|
||
"Action: Search\n",
|
||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n",
|
||
"Action: Calculator\n",
|
||
"Action Input: 25^(0.43)\u001b[0m\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
|
||
"25^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
|
||
"25**(0.43)\n",
|
||
"```\n",
|
||
"...numexpr.evaluate(\"25**(0.43)\")...\n",
|
||
"\u001b[0m\n",
|
||
"Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||
"\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
|
||
"Final Answer: 3.991298452658078\u001b[0m\n",
|
||
"\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'3.991298452658078'"
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"agent.run(\"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "824eaf74",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Using the `tool` decorator\n",
|
||
"\n",
|
||
"To make it easier to define custom tools, a `@tool` decorator is provided. This decorator can be used to quickly create a `Tool` from a simple function. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. Additionally, the decorator will use the function's docstring as the tool's description."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"id": "8f15307d",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from langchain.agents import tool\n",
|
||
"\n",
|
||
"@tool\n",
|
"def search_api(query: str) -> str:\n",
"    \"\"\"Searches the API for the query.\"\"\"\n",
"    return f\"Results for query {query}\""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"id": "0a23b91b",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.SearchApi'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bd664c0>, coroutine=None)"
|
||
]
|
||
},
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"search_api"
|
||
]
|
||
},
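{
"cell_type": "markdown",
"id": "sketch-tool-decorator",
"metadata": {},
"source": [
"The description in the output above combines the tool name, the function signature, and the docstring. The following is a minimal sketch of a decorator that builds a name and description the same way. `simple_tool` is a hypothetical helper that returns a plain dictionary; it is an assumption made for illustration and not how LangChain implements `@tool`.\n",
"\n",
"```python\n",
"import inspect\n",
"from typing import Callable, Optional\n",
"\n",
"\n",
"def simple_tool(name: Optional[str] = None) -> Callable:\n",
"    # Hypothetical decorator; it only mimics the naming rules described above.\n",
"    def wrap(fn: Callable) -> dict:\n",
"        # Fall back to the function name when no explicit name is given.\n",
"        tool_name = name or fn.__name__\n",
"        signature = inspect.signature(fn)\n",
"        description = f'{tool_name}{signature} - {fn.__doc__.strip()}'\n",
"        return {'name': tool_name, 'description': description, 'func': fn}\n",
"    return wrap\n",
"\n",
"\n",
"@simple_tool()\n",
"def search_api(query: str) -> str:\n",
"    '''Searches the API for the query.'''\n",
"    return f'Results for query {query}'\n",
"\n",
"\n",
"print(search_api['description'])\n",
"```\n",
"\n",
"Passing a string, as in `@simple_tool('search')`, overrides the default name in this sketch."
]
},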
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "cc6ee8c1",
|
||
"metadata": {},
|
||
"source": [
|
||
"You can also provide arguments like the tool name and whether to return directly."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"id": "28cdf04d",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"@tool(\"search\", return_direct=True)\n",
|
"def search_api(query: str) -> str:\n",
"    \"\"\"Searches the API for the query.\"\"\"\n",
"    return \"Results\""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"id": "1085a4bd",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.SearchApi'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bd66310>, coroutine=None)"
|
||
]
|
||
},
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"search_api"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "de34a6a3",
|
||
"metadata": {},
|
||
"source": [
|
"You can also pass `args_schema` to give more information about the argument."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"id": "f3a5c106",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"class SearchInput(BaseModel):\n",
"    query: str = Field(description=\"should be a search query\")\n",
"\n",
"\n",
"@tool(\"search\", return_direct=True, args_schema=SearchInput)\n",
"def search_api(query: str) -> str:\n",
"    \"\"\"Searches the API for the query.\"\"\"\n",
"    return \"Results\""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"id": "7914ba6b",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class '__main__.SearchInput'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bcf0ee0>, coroutine=None)"
|
||
]
|
||
},
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"search_api"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "1d0430d6",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Modify existing tools\n",
|
||
"\n",
|
"Next, we show how to load existing tools and modify them. In the example below, we simply rename the Search tool to `Google Search`."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"id": "79213f40",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from langchain.agents import load_tools"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"id": "e1067dcb",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"id": "6c66ffe8",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tools[0].name = \"Google Search\""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"id": "f45b5bc3",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"id": "565e2b9b",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age.\n",
|
||
"Action: Google Search\n",
|
||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mNow I need to find out Camila Morrone's current age.\n",
|
||
"Action: Calculator\n",
|
||
"Action Input: 25^0.43\u001b[0m\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
|
||
"Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n",
|
||
"\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\""
|
||
]
|
||
},
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"agent.run(\"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "376813ed",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Defining the priorities among Tools\n",
|
"When you have made a custom tool, you may want the agent to use it more than the normal tools.\n",
"\n",
"For example, suppose you made a custom tool that gets information on music from your database. When a user wants information on songs, you want the agent to use the custom tool in preference to the normal Search tool, but the agent might prioritize the normal Search tool.\n",
|
||
"\n",
|
||
"This can be accomplished by adding a statement such as `Use this more than the normal search if the question is about Music, like 'who is the singer of yesterday?' or 'what is the most popular song in 2022?'` to the description.\n",
|
||
"\n",
|
||
"An example is below."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"id": "3450512e",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Import things that are needed generically\n",
|
||
"from langchain.agents import initialize_agent, Tool\n",
|
||
"from langchain.agents import AgentType\n",
|
||
"from langchain.llms import OpenAI\n",
|
||
"from langchain import LLMMathChain, SerpAPIWrapper\n",
|
||
"search = SerpAPIWrapper()\n",
|
"tools = [\n",
"    Tool(\n",
"        name=\"Search\",\n",
"        func=search.run,\n",
"        description=\"useful for when you need to answer questions about current events\",\n",
"    ),\n",
"    Tool(\n",
"        name=\"Music Search\",\n",
"        func=lambda x: \"'All I Want For Christmas Is You' by Mariah Carey.\",  # Mock function\n",
"        description=\"A Music search engine. Use this more than the normal search if the question is about Music, like 'who is the singer of yesterday?' or 'what is the most popular song in 2022?'\",\n",
"    ),\n",
"]\n",
|
||
"\n",
|
||
"agent = initialize_agent(tools, OpenAI(temperature=0), agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"id": "4b9a7849",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
"\u001b[32;1m\u001b[1;3m I should use a music search engine to find the answer\n",
|
||
"Action: Music Search\n",
|
||
"Action Input: most famous song of christmas\u001b[0m\u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||
"Final Answer: 'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n",
|
||
"\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"\"'All I Want For Christmas Is You' by Mariah Carey.\""
|
||
]
|
||
},
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"agent.run(\"what is the most famous song of christmas\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "bc477d43",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Using tools to return directly\n",
|
"Often, it is desirable to return a tool's output directly to the user when the tool is called. You can do this in LangChain by setting the tool's `return_direct` flag to `True`."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"id": "3bb6185f",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"llm_math_chain = LLMMathChain(llm=llm)\n",
|
"tools = [\n",
"    Tool(\n",
"        name=\"Calculator\",\n",
"        func=llm_math_chain.run,\n",
"        description=\"useful for when you need to answer questions about math\",\n",
"        return_direct=True,\n",
"    )\n",
"]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"id": "113ddb84",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"llm = OpenAI(temperature=0)\n",
|
||
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"id": "582439a6",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
"\u001b[32;1m\u001b[1;3m I need to calculate this\n",
|
||
"Action: Calculator\n",
|
||
"Action Input: 2**.12\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 1.086734862526058\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
|
||
"\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'Answer: 1.086734862526058'"
|
||
]
|
||
},
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"agent.run(\"whats 2**.12\")"
|
||
]
|
||
},
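{
"cell_type": "markdown",
"id": "sketch-return-direct",
"metadata": {},
"source": [
"The effect of `return_direct` can be sketched without LangChain. The registry and `run_step` below are hypothetical and only mimic the behaviour seen in the run above, where the Calculator output became the final result; they are not LangChain's executor.\n",
"\n",
"```python\n",
"from typing import Callable, Dict, Tuple\n",
"\n",
"\n",
"def calculator(expression: str) -> str:\n",
"    # Mock function; a real tool would evaluate the expression.\n",
"    return 'Answer: 4'\n",
"\n",
"\n",
"def search(query: str) -> str:\n",
"    # Mock function standing in for a search tool.\n",
"    return f'Results for {query}'\n",
"\n",
"\n",
"# Hypothetical registry: tool name -> (function, return_direct flag).\n",
"TOOLS: Dict[str, Tuple[Callable[[str], str], bool]] = {\n",
"    'Calculator': (calculator, True),\n",
"    'Search': (search, False),\n",
"}\n",
"\n",
"\n",
"def run_step(tool_name: str, tool_input: str) -> Tuple[str, bool]:\n",
"    # Returns the observation and whether the loop should stop here.\n",
"    func, return_direct = TOOLS[tool_name]\n",
"    return func(tool_input), return_direct\n",
"\n",
"\n",
"print(run_step('Calculator', '2 + 2'))\n",
"```\n",
"\n",
"In this sketch a `True` flag tells the caller to hand the observation straight back to the user rather than feeding it to the LLM for another step."
]
},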
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "8aa3c353-bd89-467c-9c27-b83a90cd4daa",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Multi-argument tools\n",
|
||
"\n",
|
"Many functions expect structured inputs. These can also be supported using the `@tool` decorator or by directly subclassing `BaseTool`. However, we have to modify the agent's output parser so that it maps the LLM's string output to a dictionary that is passed to the tool."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"id": "537bc628",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"from typing import Optional, Union\n",
|
||
"\n",
|
||
"@tool\n",
|
"def custom_search(k: int, query: str, other_arg: Optional[str] = None):\n",
"    \"\"\"The custom search function.\"\"\"\n",
"    return f\"Here are the results for the custom search: k={k}, query={query}, other_arg={other_arg}\""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"id": "d5c992cf-776a-40cd-a6c4-e7cf65ea709e",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
"import re\n",
"from langchain.schema import (\n",
"    AgentAction,\n",
"    AgentFinish,\n",
")\n",
"from langchain.agents import AgentOutputParser\n",
"\n",
"\n",
"# We will add a custom parser to map the arguments to a dictionary\n",
"class CustomOutputParser(AgentOutputParser):\n",
"\n",
"    def parse_tool_input(self, action_input: str) -> dict:\n",
"        # Regex pattern to match arguments and their values\n",
"        pattern = r\"(\\w+)\\s*=\\s*(None|\\\"[^\\\"]*\\\"|\\d+)\"\n",
"        matches = re.findall(pattern, action_input)\n",
"\n",
"        if not matches:\n",
"            raise ValueError(f\"Could not parse action input: `{action_input}`\")\n",
"\n",
"        # Create a dictionary with the parsed arguments and their values\n",
"        parsed_input = {}\n",
"        for arg, value in matches:\n",
"            if value == \"None\":\n",
"                parsed_value = None\n",
"            elif value.isdigit():\n",
"                parsed_value = int(value)\n",
"            else:\n",
"                parsed_value = value.strip('\"')\n",
"            parsed_input[arg] = parsed_value\n",
"\n",
"        return parsed_input\n",
"\n",
"    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
"        # Check if agent should finish\n",
"        if \"Final Answer:\" in llm_output:\n",
"            return AgentFinish(\n",
"                # Return values is generally always a dictionary with a single `output` key\n",
"                # It is not recommended to try anything else at the moment :)\n",
"                return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
"                log=llm_output,\n",
"            )\n",
"        # Parse out the action and action input\n",
"        regex = r\"Action\\s*\\d*\\s*:(.*?)\\nAction\\s*\\d*\\s*Input\\s*\\d*\\s*:[\\s]*(.*)\"\n",
"        match = re.search(regex, llm_output, re.DOTALL)\n",
"        if not match:\n",
"            raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
"        action = match.group(1).strip()\n",
"        action_input = match.group(2)\n",
"        tool_input = self.parse_tool_input(action_input)\n",
"        # Return the action and action input\n",
"        return AgentAction(tool=action, tool_input=tool_input, log=llm_output)"
|
||
]
|
||
},
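{
"cell_type": "markdown",
"id": "sketch-parse-tool-input",
"metadata": {},
"source": [
"To see what the argument regex does in isolation, here is a standalone sketch that uses only the standard library. The module-level `parse_tool_input` function is a simplified copy of the method above (it omits the error raised on no matches), and the sample input string is an assumption chosen for illustration.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"\n",
"def parse_tool_input(action_input: str) -> dict:\n",
"    # Same pattern as in the parser above: name = None | \"text\" | digits.\n",
"    pattern = r'(\\w+)\\s*=\\s*(None|\"[^\"]*\"|\\d+)'\n",
"    parsed = {}\n",
"    for arg, value in re.findall(pattern, action_input):\n",
"        if value == 'None':\n",
"            parsed[arg] = None\n",
"        elif value.isdigit():\n",
"            parsed[arg] = int(value)\n",
"        else:\n",
"            # Quoted string: drop the surrounding double quotes.\n",
"            parsed[arg] = value.strip('\"')\n",
"    return parsed\n",
"\n",
"\n",
"print(parse_tool_input('k=1, query=\"me\"'))\n",
"```\n",
"\n",
"Note that the pattern only recognises `None`, double-quoted strings, and non-negative integers, so an unquoted string value would be skipped."
]
},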
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"id": "68269547-1482-4138-a6ea-58f00b4a9548",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"llm = OpenAI(temperature=0)\n",
|
||
"agent = initialize_agent([custom_search], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, agent_kwargs={\"output_parser\": CustomOutputParser()})"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"id": "0947835a-691c-4f51-b8f4-6744e0e48ab1",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
"\u001b[32;1m\u001b[1;3m I need to use a search function to find the answer\n",
|
||
"Action: custom_search\n",
|
||
"Action Input: k=1, query=\"me\"\u001b[0m\u001b[36;1m\u001b[1;3mHere are the results for the custom search: k=1, query=me, other_arg=None\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||
"Final Answer: The results of the custom search for k=1, query=me, other_arg=None.\u001b[0m\n",
|
||
"\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'The results of the custom search for k=1, query=me, other_arg=None.'"
|
||
]
|
||
},
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"agent.run(\"Search for me and tell me whatever it says\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "caf39c66-102b-42c1-baf2-777a49886ce4",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.1"
|
||
},
|
||
"vscode": {
|
||
"interpreter": {
|
||
"hash": "e90c8aa204a57276aa905271aff2d11799d0acb3547adabc5892e639a5e45e34"
|
||
}
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|