{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fa6802ac",
   "metadata": {},
   "source": [
    "# Adding Memory to an Agent\n",
    "\n",
    "This notebook goes over adding memory to an Agent. Before going through this notebook, please walk through the following notebooks, as this one builds on top of both of them:\n",
    "\n",
    "- [Adding memory to an LLM Chain](adding_memory.ipynb)\n",
    "- [Custom Agents](../agents/custom_agent.ipynb)\n",
    "\n",
    "To add memory to an agent, we are going to take the following steps:\n",
    "\n",
    "1. Create an `LLMChain` with memory.\n",
    "2. Use that `LLMChain` to create a custom Agent.\n",
    "\n",
    "For the purposes of this exercise, we are going to create a simple custom Agent that has access to a search tool and uses the `ConversationBufferMemory` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8db95912",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import ZeroShotAgent, Tool\n",
    "from langchain.chains.conversation.memory import ConversationBufferMemory\n",
    "from langchain import OpenAI, SerpAPIWrapper, LLMChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "97ad8467",
   "metadata": {},
   "outputs": [],
   "source": [
    "search = SerpAPIWrapper()\n",
    "tools = [\n",
    "    Tool(\n",
    "        name=\"Search\",\n",
    "        func=search.run,\n",
    "        description=\"useful for when you need to answer questions about current events\"\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ad2e708",
   "metadata": {},
   "source": [
    "Notice the `chat_history` variable in the `PromptTemplate`: its name matches the `memory_key` passed to `ConversationBufferMemory`, which is how the memory's contents get inserted into the prompt."
   ]
  },
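  {
   "cell_type": "markdown",
   "id": "7c1e4a90",
   "metadata": {},
   "source": [
    "As a quick sanity check, here is a minimal sketch in plain Python (standard library only, no LangChain) that lists the placeholders in a standalone copy of the suffix template, so you can confirm that the name given as `memory_key` actually appears in the prompt. The helper `template_fields` and the variable `demo_suffix` are hypothetical, not part of LangChain, and the check only inspects the suffix, not the full prompt that `ZeroShotAgent.create_prompt` assembles.\n",
    "\n",
    "```python\n",
    "from string import Formatter\n",
    "\n",
    "# Standalone copy of the suffix template used in the next cell.\n",
    "demo_suffix = \"\"\"Begin!\n",
    "\n",
    "{chat_history}\n",
    "Question: {input}\"\"\"\n",
    "\n",
    "\n",
    "def template_fields(template: str) -> set:\n",
    "    \"\"\"Return the names of the {placeholders} in a str.format-style template.\"\"\"\n",
    "    return {field for _, field, _, _ in Formatter().parse(template) if field}\n",
    "\n",
    "\n",
    "memory_key = \"chat_history\"  # the value passed to ConversationBufferMemory(memory_key=...)\n",
    "fields = template_fields(demo_suffix)\n",
    "print(sorted(fields))  # ['chat_history', 'input']\n",
    "assert memory_key in fields, \"memory_key must match a placeholder in the prompt\"\n",
    "```"
   ]
  },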
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e3439cd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n",
    "suffix = \"\"\"Begin!\n",
    "\n",
    "{chat_history}\n",
    "Question: {input}\"\"\"\n",
    "\n",
    "prompt = ZeroShotAgent.create_prompt(\n",
    "    tools,\n",
    "    prefix=prefix,\n",
    "    suffix=suffix,\n",
    "    input_variables=[\"input\", \"chat_history\"]\n",
    ")\n",
    "memory = ConversationBufferMemory(memory_key=\"chat_history\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0021675b",
   "metadata": {},
   "source": [
    "We can now construct the `LLMChain` with the memory object, and then create the agent."
   ]
  },
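  {
   "cell_type": "markdown",
   "id": "2f9b6d3e",
   "metadata": {},
   "source": [
    "Conceptually, a chain with memory repeats the same cycle on every call: load the stored variables (here `chat_history`), merge them with the new input to fill the prompt, call the model, and save the new exchange back to memory. The sketch below mimics that cycle in plain Python. It is an illustration under stated assumptions, not LangChain's implementation: `SimpleBufferMemory`, `run_chain`, `fake_llm` and `demo_memory` are hypothetical stand-ins, and the `Human:`/`AI:` line format only approximates how `ConversationBufferMemory` renders its buffer.\n",
    "\n",
    "```python\n",
    "from typing import Callable, Dict, List, Tuple\n",
    "\n",
    "\n",
    "class SimpleBufferMemory:\n",
    "    \"\"\"Hypothetical stand-in for a buffer memory that keeps every exchange verbatim.\"\"\"\n",
    "\n",
    "    def __init__(self, memory_key: str = \"history\") -> None:\n",
    "        self.memory_key = memory_key\n",
    "        self.turns: List[Tuple[str, str]] = []\n",
    "\n",
    "    def load_memory_variables(self) -> Dict[str, str]:\n",
    "        # Render the conversation so far as alternating \"Human: ...\" / \"AI: ...\" lines.\n",
    "        lines = []\n",
    "        for human, ai in self.turns:\n",
    "            lines.append(f\"Human: {human}\")\n",
    "            lines.append(f\"AI: {ai}\")\n",
    "        return {self.memory_key: \"\\n\".join(lines)}\n",
    "\n",
    "    def save_context(self, user_input: str, output: str) -> None:\n",
    "        self.turns.append((user_input, output))\n",
    "\n",
    "\n",
    "def run_chain(template: str, memory: SimpleBufferMemory,\n",
    "              llm: Callable[[str], str], user_input: str) -> str:\n",
    "    # 1. Load the stored history and merge it with the new input.\n",
    "    variables = {**memory.load_memory_variables(), \"input\": user_input}\n",
    "    # 2. Fill the prompt template and call the model.\n",
    "    answer = llm(template.format(**variables))\n",
    "    # 3. Save the exchange so the next call can see it.\n",
    "    memory.save_context(user_input, answer)\n",
    "    return answer\n",
    "\n",
    "\n",
    "seen_prompts: List[str] = []\n",
    "\n",
    "\n",
    "def fake_llm(prompt: str) -> str:\n",
    "    # Hypothetical model: records the prompt it was given and returns a placeholder.\n",
    "    seen_prompts.append(prompt)\n",
    "    return \"<answer>\"\n",
    "\n",
    "\n",
    "template = \"{chat_history}\\nQuestion: {input}\"\n",
    "demo_memory = SimpleBufferMemory(memory_key=\"chat_history\")\n",
    "run_chain(template, demo_memory, fake_llm, \"How many people live in canada?\")\n",
    "run_chain(template, demo_memory, fake_llm, \"what is their national anthem called?\")\n",
    "print(seen_prompts[1])\n",
    "```\n",
    "\n",
    "The second prompt printed here starts with the first exchange (`Human: How many people live in canada?` followed by `AI: <answer>`), which is what gives the model a chance to resolve \"their\" in the follow-up question."
   ]
  },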
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c56a0e73",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt, memory=memory)\n",
    "agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ca4bc1fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "How many people live in canada?\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should look up how many people live in canada\n",
      "Action: Search\n",
      "Action Input: \"How many people live in canada?\"\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data. · Canada 2020 ...\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: The current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data.'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.run(\"How many people live in canada?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45627664",
   "metadata": {},
   "source": [
    "To test the memory of this agent, we can ask a follow-up question that can only be answered correctly using information from the previous exchange."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eecc0462",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "what is their national anthem called?\n",
      "Thought:\u001b[32;1m\u001b[1;3m\n",
      "AI: I should look up the name of Canada's national anthem\n",
      "Action: Search\n",
      "Action Input: \"What is the name of Canada's national anthem?\"\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mAfter 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa ...\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m\n",
      "AI: I now know the final answer\n",
      "Final Answer: After 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa Lavallée.\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"After 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa Lavallée.\""
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.run(\"what is their national anthem called?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc3d0aa4",
   "metadata": {},
   "source": [
    "We can see that the agent remembered that the previous question was about Canada, and correctly asked the search tool for the name of Canada's national anthem.\n",
    "\n",
    "For fun, let's compare this to an agent that does NOT have memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3359d043",
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n",
    "suffix = \"\"\"Begin!\n",
    "\n",
    "Question: {input}\"\"\"\n",
    "\n",
    "prompt = ZeroShotAgent.create_prompt(\n",
    "    tools,\n",
    "    prefix=prefix,\n",
    "    suffix=suffix,\n",
    "    input_variables=[\"input\"]\n",
    ")\n",
    "llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)\n",
    "agent_without_memory = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "970d23df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "How many people live in canada?\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should look up how many people live in canada\n",
      "Action: Search\n",
      "Action Input: \"How many people live in canada?\"\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data. · Canada 2020 ...\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: The current population of Canada is 38,533,678\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The current population of Canada is 38,533,678'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_without_memory.run(\"How many people live in canada?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d9ea82f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "what is their national anthem called?\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should probably look this up\n",
      "Action: Search\n",
      "Action Input: \"What is the national anthem of [country]\"\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mMost nation states have an anthem, defined as \"a song, as of praise, devotion, or patriotism\"; most anthems are either marches or hymns in style.\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: The national anthem is called \"the national anthem.\"\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The national anthem is called \"the national anthem.\"'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_without_memory.run(\"what is their national anthem called?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1f9223",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}