mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
324 lines
10 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"id": "fa6802ac",
"metadata": {},
"source": [
"# Adding Memory to an Agent\n",
"\n",
"This notebook goes over adding memory to an Agent. Before going through this notebook, please walk through the following notebooks, as this one builds on both of them:\n",
"\n",
"- [Adding memory to an LLM Chain](adding_memory.ipynb)\n",
"- [Custom Agents](../agents/custom_agent.ipynb)\n",
"\n",
"In order to add memory to an agent, we are going to take the following steps:\n",
"\n",
"1. We are going to create an LLMChain whose prompt has a placeholder for the chat history.\n",
"2. We are going to use that LLMChain to create a custom Agent, and give the `AgentExecutor` that runs it a memory object.\n",
"\n",
"For the purposes of this exercise, we are going to create a simple custom Agent that has access to a search tool and uses the `ConversationBufferMemory` class."
]
},
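{
"cell_type": "markdown",
"id": "toy-buffer-memory-note",
"metadata": {},
"source": [
"Conceptually, a buffer memory keeps a running transcript of the conversation and hands it back under a configurable key. The next cell is a minimal, dependency-free sketch of that idea. `ToyBufferMemory` is hypothetical: its method names are loosely modeled on LangChain's memory classes, but the signatures are simplified and the transcript format produced by the real `ConversationBufferMemory` may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "toy-buffer-memory-code",
"metadata": {},
"outputs": [],
"source": [
"class ToyBufferMemory:\n",
"    \"\"\"Hypothetical stand-in for ConversationBufferMemory: keeps the transcript as one string.\"\"\"\n",
"\n",
"    def __init__(self, memory_key=\"history\", human_prefix=\"Human\", ai_prefix=\"AI\"):\n",
"        self.memory_key = memory_key\n",
"        self.human_prefix = human_prefix\n",
"        self.ai_prefix = ai_prefix\n",
"        self.buffer = \"\"\n",
"\n",
"    def load_memory_variables(self):\n",
"        # Return the transcript under the configured key, e.g. {\"chat_history\": \"...\"}\n",
"        return {self.memory_key: self.buffer}\n",
"\n",
"    def save_context(self, user_input, output):\n",
"        # Append one human/AI exchange to the running transcript\n",
"        turn = f\"{self.human_prefix}: {user_input}\\n{self.ai_prefix}: {output}\"\n",
"        self.buffer = f\"{self.buffer}\\n{turn}\" if self.buffer else turn\n",
"\n",
"\n",
"toy_memory = ToyBufferMemory(memory_key=\"chat_history\")\n",
"toy_memory.save_context(\"How many people live in canada?\", \"The current population of Canada is 38,555,354.\")\n",
"print(toy_memory.load_memory_variables())"
]
},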
{
"cell_type": "code",
"execution_count": 6,
"id": "8db95912",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n",
"from langchain.chains.conversation.memory import ConversationBufferMemory\n",
"from langchain import OpenAI, SerpAPIWrapper, LLMChain"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "97ad8467",
"metadata": {},
"outputs": [],
"source": [
"search = SerpAPIWrapper()\n",
"tools = [\n",
"    Tool(\n",
"        name=\"Search\",\n",
"        func=search.run,\n",
"        description=\"useful for when you need to answer questions about current events\"\n",
"    )\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "4ad2e708",
"metadata": {},
"source": [
"Notice the `chat_history` variable in the prompt template: it matches the `memory_key` passed to `ConversationBufferMemory`, which is how the stored conversation gets inserted into the prompt."
]
},
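{
"cell_type": "markdown",
"id": "prompt-key-matching-note",
"metadata": {},
"source": [
"LangChain's default prompt templates use Python-style `{...}` placeholders, so the key-matching requirement can be illustrated with plain `str.format`. The next cell is a simplified sketch, not what `PromptTemplate` does internally: the memory's variables (here a made-up transcript) are merged with the user's input, and a template whose `{chat_history}` slot has no matching key cannot be formatted."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "prompt-key-matching-code",
"metadata": {},
"outputs": [],
"source": [
"# Simplified illustration using plain str.format (not LangChain's PromptTemplate)\n",
"suffix_template = \"Begin!\\n\\n{chat_history}\\nQuestion: {input}\\n{agent_scratchpad}\"\n",
"\n",
"# What a memory configured with memory_key=\"chat_history\" might hand back (made-up transcript)\n",
"memory_variables = {\n",
"    \"chat_history\": \"Human: How many people live in canada?\\nAI: The current population of Canada is 38,555,354.\"\n",
"}\n",
"\n",
"# The memory's variables are merged with the user's input before formatting\n",
"filled = suffix_template.format(input=\"what is their national anthem called?\", agent_scratchpad=\"\", **memory_variables)\n",
"print(filled)\n",
"\n",
"# If the memory used a different key (e.g. \"history\"), the {chat_history} slot cannot be filled\n",
"try:\n",
"    suffix_template.format(input=\"what is their national anthem called?\", agent_scratchpad=\"\", history=\"...\")\n",
"except KeyError as err:\n",
"    print(\"missing template variable:\", err)"
]
},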
{
"cell_type": "code",
"execution_count": 8,
"id": "e3439cd6",
"metadata": {},
"outputs": [],
"source": [
"prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n",
"suffix = \"\"\"Begin!\n",
"\n",
"{chat_history}\n",
"Question: {input}\n",
"{agent_scratchpad}\"\"\"\n",
"\n",
"prompt = ZeroShotAgent.create_prompt(\n",
"    tools,\n",
"    prefix=prefix,\n",
"    suffix=suffix,\n",
"    input_variables=[\"input\", \"chat_history\", \"agent_scratchpad\"]\n",
")\n",
"memory = ConversationBufferMemory(memory_key=\"chat_history\")"
]
},
{
"cell_type": "markdown",
"id": "0021675b",
"metadata": {},
"source": [
"We can now construct the LLMChain and the agent, and then pass the memory object to the `AgentExecutor` that runs the agent."
]
},
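{
"cell_type": "markdown",
"id": "toy-executor-note",
"metadata": {},
"source": [
"Roughly speaking, the `AgentExecutor` loads the stored conversation into the prompt variables before each run and saves the new exchange afterwards. The next cell mimics that loop without any LangChain dependency. `ToyExecutorWithMemory` and `fake_agent` are hypothetical; `fake_agent` is a hard-coded stand-in for the LLM that can only resolve \"their\" if Canada already appears in the history. A fresh instance starts with an empty history, which is why a memoryless agent cannot answer the same follow-up."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "toy-executor-code",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable, Dict, List\n",
"\n",
"\n",
"class ToyExecutorWithMemory:\n",
"    \"\"\"Hypothetical sketch of how an executor wraps each run with memory.\"\"\"\n",
"\n",
"    def __init__(self, agent: Callable[[Dict[str, str]], str], memory_key: str = \"chat_history\"):\n",
"        self.agent = agent  # stand-in for the agent/LLM; receives all prompt variables\n",
"        self.memory_key = memory_key\n",
"        self.turns: List[str] = []\n",
"\n",
"    def run(self, user_input: str) -> str:\n",
"        # 1. Load the stored conversation and merge it with the new input\n",
"        inputs = {\"input\": user_input, self.memory_key: \"\\n\".join(self.turns)}\n",
"        # 2. Run the agent on the combined variables\n",
"        output = self.agent(inputs)\n",
"        # 3. Save the exchange so the next run can see it\n",
"        self.turns.append(f\"Human: {user_input}\\nAI: {output}\")\n",
"        return output\n",
"\n",
"\n",
"def fake_agent(inputs: Dict[str, str]) -> str:\n",
"    # Hard-coded stand-in for the LLM: \"their\" is only resolvable if Canada is in the history\n",
"    if \"their\" in inputs[\"input\"]:\n",
"        if \"Canada\" in inputs[\"chat_history\"]:\n",
"            return \"O Canada is the national anthem of Canada.\"\n",
"        return \"I don't know whose national anthem you mean.\"\n",
"    return \"The current population of Canada is 38,555,354.\"\n",
"\n",
"\n",
"toy_executor = ToyExecutorWithMemory(fake_agent)\n",
"print(toy_executor.run(\"How many people live in canada?\"))\n",
"print(toy_executor.run(\"what is their national anthem called?\"))"
]
},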
{
"cell_type": "code",
"execution_count": 16,
"id": "c56a0e73",
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)\n",
"agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)\n",
"agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "ca4bc1fb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to find out the population of Canada\n",
"Action: Search\n",
"Action Input: Population of Canada\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,555,354 as of Monday, December 19, 2022, based on Worldometer elaboration of the latest United Nations data. · Canada 2020 ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The current population of Canada is 38,555,354.\u001b[0m\n",
"\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The current population of Canada is 38,555,354.'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_chain.run(input=\"How many people live in canada?\")"
]
},
{
"cell_type": "markdown",
"id": "45627664",
"metadata": {},
"source": [
"To test the memory of this agent, we can ask a follow-up question that relies on information from the previous exchange to be answered correctly."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "eecc0462",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to find out what the national anthem of Canada is called.\n",
"Action: Search\n",
"Action Input: National anthem of Canada\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAfter 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: O Canada is the national anthem of Canada.\u001b[0m\n",
"\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'O Canada is the national anthem of Canada.'"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_chain.run(input=\"what is their national anthem called?\")"
]
},
{
"cell_type": "markdown",
"id": "cc3d0aa4",
"metadata": {},
"source": [
"We can see that the agent remembered that the previous question was about Canada, and correctly used the Search tool to look up the name of Canada's national anthem.\n",
"\n",
"For fun, let's compare this to an agent that does NOT have memory."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "3359d043",
"metadata": {},
"outputs": [],
"source": [
"prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n",
"suffix = \"\"\"Begin!\n",
"\n",
"Question: {input}\n",
"{agent_scratchpad}\"\"\"\n",
"\n",
"prompt = ZeroShotAgent.create_prompt(\n",
"    tools,\n",
"    prefix=prefix,\n",
"    suffix=suffix,\n",
"    input_variables=[\"input\", \"agent_scratchpad\"]\n",
")\n",
"llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)\n",
"agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)\n",
"agent_without_memory = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "970d23df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I should look up the answer\n",
"Action: Search\n",
"Action Input: population of Canada\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,555,354 as of Monday, December 19, 2022, based on Worldometer elaboration of the latest United Nations data. · Canada 2020 ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The current population of Canada is 38,555,354.\u001b[0m\n",
"\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The current population of Canada is 38,555,354.'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_without_memory.run(\"How many people live in canada?\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "d9ea82f0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I should look up the answer\n",
"Action: Search\n",
"Action Input: national anthem of [country]\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m\"Himno Nacional\" (\"National Anthem\") · \"Pátria\" (\"Fatherland\") · \"Salve, Oh Patria\" (\"We Salute You, Our Homeland\") · \"Bilady, Bilady, Bilady\" (\"My Country, My ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The national anthem of [country] is called \"Himno Nacional\", \"Pátria\", \"Salve, Oh Patria\", and \"Bilady, Bilady, Bilady\".\u001b[0m\n",
"\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The national anthem of [country] is called \"Himno Nacional\", \"Pátria\", \"Salve, Oh Patria\", and \"Bilady, Bilady, Bilady\".'"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_without_memory.run(\"what is their national anthem called?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}