docs: Re-write custom agent to show to write a tools agent (#15907)

Shows how to write a tools agent rather than a functions agent.
This commit is contained in:
Eugene Yurtsev 2024-01-22 20:28:31 -05:00 committed by GitHub
parent 404abf139a
commit e5672bc944
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,7 +19,7 @@
"\n", "\n",
"This notebook goes through how to create your own custom agent.\n", "This notebook goes through how to create your own custom agent.\n",
"\n", "\n",
"In this example, we will use OpenAI Function Calling to create this agent.\n", "In this example, we will use OpenAI Tool Calling to create this agent.\n",
"**This is generally the most reliable way to create agents.**\n", "**This is generally the most reliable way to create agents.**\n",
"\n", "\n",
"We will first create it WITHOUT memory, but we will then show how to add memory in.\n", "We will first create it WITHOUT memory, but we will then show how to add memory in.\n",
@ -61,10 +61,21 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 2,
"id": "fbe32b5f", "id": "490bab35-adbb-4b45-8d0d-232414121e97",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"from langchain.agents import tool\n", "from langchain.agents import tool\n",
"\n", "\n",
@ -75,6 +86,16 @@
" return len(word)\n", " return len(word)\n",
"\n", "\n",
"\n", "\n",
"get_word_length.invoke(\"abc\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c9821fb3-4449-49a0-a708-88a18d39e068",
"metadata": {},
"outputs": [],
"source": [
"tools = [get_word_length]" "tools = [get_word_length]"
] ]
}, },
@ -91,7 +112,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 4,
"id": "aa4b50ea", "id": "aa4b50ea",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -116,22 +137,24 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Bind tools to LLM\n", "## Bind tools to LLM\n",
"How does the agent know what tools it can use?\n",
"In this case we're relying on OpenAI function calling LLMs, which take functions as a separate argument and have been specifically trained to know when to invoke those functions.\n",
"\n", "\n",
"To pass in our tools to the agent, we just need to format them to the [OpenAI function format](https://openai.com/blog/function-calling-and-other-api-updates) and pass them to our model. (By `bind`-ing the functions, we're making sure that they're passed in each time the model is invoked.)" "How does the agent know what tools it can use?\n",
"\n",
"In this case we're relying on OpenAI tool calling LLMs, which take tools as a separate argument and have been specifically trained to know when to invoke those tools.\n",
"\n",
"To pass in our tools to the agent, we just need to format them to the [OpenAI tool format](https://platform.openai.com/docs/api-reference/chat/create) and pass them to our model. (By `bind`-ing the functions, we're making sure that they're passed in each time the model is invoked.)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"id": "e82713b6", "id": "e82713b6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain_community.tools.convert_to_openai import format_tool_to_openai_function\n", "from langchain_community.tools.convert_to_openai import format_tool_to_openai_tool\n",
"\n", "\n",
"llm_with_tools = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])" "llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(tool) for tool in tools])"
] ]
}, },
{ {
@ -146,30 +169,32 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 6,
"id": "925a8ca4", "id": "925a8ca4",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.agents.format_scratchpad import format_to_openai_function_messages\n", "from langchain.agents.format_scratchpad.openai_tools import (\n",
"from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser\n", " format_to_openai_tool_messages,\n",
")\n",
"from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser\n",
"\n", "\n",
"agent = (\n", "agent = (\n",
" {\n", " {\n",
" \"input\": lambda x: x[\"input\"],\n", " \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_to_openai_function_messages(\n", " \"agent_scratchpad\": lambda x: format_to_openai_tool_messages(\n",
" x[\"intermediate_steps\"]\n", " x[\"intermediate_steps\"]\n",
" ),\n", " ),\n",
" }\n", " }\n",
" | prompt\n", " | prompt\n",
" | llm_with_tools\n", " | llm_with_tools\n",
" | OpenAIFunctionsAgentOutputParser()\n", " | OpenAIToolsAgentOutputParser()\n",
")" ")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 7,
"id": "9af9734e", "id": "9af9734e",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -181,7 +206,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 8,
"id": "653b1617", "id": "653b1617",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -193,10 +218,10 @@
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n", "\u001b[32;1m\u001b[1;3m\n",
"Invoking: `get_word_length` with `{'word': 'educa'}`\n", "Invoking: `get_word_length` with `{'word': 'eudca'}`\n",
"\n", "\n",
"\n", "\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m5\u001b[0m\u001b[32;1m\u001b[1;3mThere are 5 letters in the word \"educa\".\u001b[0m\n", "\u001b[0m\u001b[36;1m\u001b[1;3m5\u001b[0m\u001b[32;1m\u001b[1;3mThere are 5 letters in the word \"eudca\".\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@ -204,17 +229,21 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'input': 'How many letters in the word educa',\n", "[{'actions': [OpenAIToolAgentAction(tool='get_word_length', tool_input={'word': 'eudca'}, log=\"\\nInvoking: `get_word_length` with `{'word': 'eudca'}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_U9SR78eT398r9UbzID2N9LXh', 'function': {'arguments': '{\\n \"word\": \"eudca\"\\n}', 'name': 'get_word_length'}, 'type': 'function'}]})], tool_call_id='call_U9SR78eT398r9UbzID2N9LXh')],\n",
" 'output': 'There are 5 letters in the word \"educa\".'}" " 'messages': [AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_U9SR78eT398r9UbzID2N9LXh', 'function': {'arguments': '{\\n \"word\": \"eudca\"\\n}', 'name': 'get_word_length'}, 'type': 'function'}]})]},\n",
" {'steps': [AgentStep(action=OpenAIToolAgentAction(tool='get_word_length', tool_input={'word': 'eudca'}, log=\"\\nInvoking: `get_word_length` with `{'word': 'eudca'}`\\n\\n\\n\", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_U9SR78eT398r9UbzID2N9LXh', 'function': {'arguments': '{\\n \"word\": \"eudca\"\\n}', 'name': 'get_word_length'}, 'type': 'function'}]})], tool_call_id='call_U9SR78eT398r9UbzID2N9LXh'), observation=5)],\n",
" 'messages': [FunctionMessage(content='5', name='get_word_length')]},\n",
" {'output': 'There are 5 letters in the word \"eudca\".',\n",
" 'messages': [AIMessage(content='There are 5 letters in the word \"eudca\".')]}]"
] ]
}, },
"execution_count": 14, "execution_count": 8,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"agent_executor.invoke({\"input\": \"How many letters in the word educa\"})" "list(agent_executor.stream({\"input\": \"How many letters in the word eudca\"}))"
] ]
}, },
{ {
@ -227,7 +256,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 9,
"id": "60f5dc19", "id": "60f5dc19",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -237,7 +266,7 @@
"AIMessage(content='There are 6 letters in the word \"educa\".')" "AIMessage(content='There are 6 letters in the word \"educa\".')"
] ]
}, },
"execution_count": 15, "execution_count": 9,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -270,7 +299,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 10,
"id": "169006d5", "id": "169006d5",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -301,7 +330,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 11,
"id": "8c03f36c", "id": "8c03f36c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -321,7 +350,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 12,
"id": "5429d97f", "id": "5429d97f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -329,14 +358,14 @@
"agent = (\n", "agent = (\n",
" {\n", " {\n",
" \"input\": lambda x: x[\"input\"],\n", " \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_to_openai_function_messages(\n", " \"agent_scratchpad\": lambda x: format_to_openai_tool_messages(\n",
" x[\"intermediate_steps\"]\n", " x[\"intermediate_steps\"]\n",
" ),\n", " ),\n",
" \"chat_history\": lambda x: x[\"chat_history\"],\n", " \"chat_history\": lambda x: x[\"chat_history\"],\n",
" }\n", " }\n",
" | prompt\n", " | prompt\n",
" | llm_with_tools\n", " | llm_with_tools\n",
" | OpenAIFunctionsAgentOutputParser()\n", " | OpenAIToolsAgentOutputParser()\n",
")\n", ")\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)" "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
] ]
@ -351,7 +380,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 13,
"id": "9d9da346", "id": "9d9da346",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -386,7 +415,7 @@
" 'output': 'No, \"educa\" is not a real word in English.'}" " 'output': 'No, \"educa\" is not a real word in English.'}"
] ]
}, },
"execution_count": 19, "execution_count": 13,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -402,14 +431,6 @@
")\n", ")\n",
"agent_executor.invoke({\"input\": \"is that a real word?\", \"chat_history\": chat_history})" "agent_executor.invoke({\"input\": \"is that a real word?\", \"chat_history\": chat_history})"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f21bcd99",
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {
@ -428,7 +449,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.1" "version": "3.11.4"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {