diff --git a/docs/modules/memory/examples/agent_with_memory_in_db.ipynb b/docs/modules/memory/examples/agent_with_memory_in_db.ipynb new file mode 100644 index 00000000..201a6533 --- /dev/null +++ b/docs/modules/memory/examples/agent_with_memory_in_db.ipynb @@ -0,0 +1,353 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fa6802ac", + "metadata": {}, + "source": [ + "# Adding Message Memory backed by a database to an Agent\n", + "\n", + "This notebook goes over adding memory to an Agent where the memory uses an external message store. Before going through this notebook, please walkthrough the following notebooks, as this will build on top of both of them:\n", + "\n", + "- [Adding memory to an LLM Chain](adding_memory.ipynb)\n", + "- [Custom Agents](../../agents/examples/custom_agent.ipynb)\n", + "- [Agent with Memory](agetn_with_memory.ipynb)\n", + "\n", + "In order to add a memory with an external message store to an agent we are going to do the following steps:\n", + "\n", + "1. We are going to create a `RedisChatMessageHistory` to connect to an external database to store the messages in.\n", + "2. We are going to create an `LLMChain` useing that chat history as memory.\n", + "3. We are going to use that `LLMChain` to create a custom Agent.\n", + "\n", + "For the purposes of this exercise, we are going to create a simple custom Agent that has access to a search tool and utilizes the `ConversationBufferMemory` class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8db95912", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n", + "from langchain.memory import ConversationBufferMemory\n", + "from langchain.memory.chat_memory import ChatMessageHistory\n", + "from langchain.memory.chat_message_histories import RedisChatMessageHistory\n", + "from langchain import OpenAI, LLMChain\n", + "from langchain.utilities import GoogleSearchAPIWrapper" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "97ad8467", + "metadata": {}, + "outputs": [], + "source": [ + "search = GoogleSearchAPIWrapper()\n", + "tools = [\n", + " Tool(\n", + " name = \"Search\",\n", + " func=search.run,\n", + " description=\"useful for when you need to answer questions about current events\"\n", + " )\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "4ad2e708", + "metadata": {}, + "source": [ + "Notice the usage of the `chat_history` variable in the PromptTemplate, which matches up with the dynamic key name in the ConversationBufferMemory." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e3439cd6", + "metadata": {}, + "outputs": [], + "source": [ + "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n", + "suffix = \"\"\"Begin!\"\n", + "\n", + "{chat_history}\n", + "Question: {input}\n", + "{agent_scratchpad}\"\"\"\n", + "\n", + "prompt = ZeroShotAgent.create_prompt(\n", + " tools, \n", + " prefix=prefix, \n", + " suffix=suffix, \n", + " input_variables=[\"input\", \"chat_history\", \"agent_scratchpad\"]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now we can create the ChatMessageHistory backed by the database." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "message_history = RedisChatMessageHistory(url='redis://localhost:6379/0', ttl=600, session_id='my-session')\n", + "\n", + "memory = ConversationBufferMemory(memory_key=\"chat_history\", chat_memory=message_history)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "id": "0021675b", + "metadata": {}, + "source": [ + "We can now construct the LLMChain, with the Memory object, and then create the agent." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c56a0e73", + "metadata": {}, + "outputs": [], + "source": [ + "llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)\n", + "agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)\n", + "agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ca4bc1fb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mThought: I need to find out the population of Canada\n", + "Action: Search\n", + "Action Input: Population of Canada\u001B[0m\n", + "Observation: \u001B[36;1m\u001B[1;3mThe current population of Canada is 38,566,192 as of Saturday, December 31, 2022, based on Worldometer elaboration of the latest United Nations data. · Canada ... Additional information related to Canadian population trends can be found on Statistics Canada's Population and Demography Portal. Population of Canada (real- ... Index to the latest information from the Census of Population. This survey conducted by Statistics Canada provides a statistical portrait of Canada and its ... 14 records ... Estimated number of persons by quarter of a year and by year, Canada, provinces and territories. The 2021 Canadian census counted a total population of 36,991,981, an increase of around 5.2 percent over the 2016 figure. ... Between 1990 and 2008, the ... ( 2 ) Census reports and other statistical publications from national statistical offices, ( 3 ) Eurostat: Demographic Statistics, ( 4 ) United Nations ... Canada is a country in North America. Its ten provinces and three territories extend from ... Population. • Q4 2022 estimate. 39,292,355 (37th). Information is available for the total Indigenous population and each of the three ... The term 'Aboriginal' or 'Indigenous' used on the Statistics Canada ... Jun 14, 2022 ... Determinants of health are the broad range of personal, social, economic and environmental factors that determine individual and population ... COVID-19 vaccination coverage across Canada by demographics and key populations. Updated every Friday at 12:00 PM Eastern Time.\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3m I now know the final answer\n", + "Final Answer: The current population of Canada is 38,566,192 as of Saturday, December 31, 2022, based on Worldometer elaboration of the latest United Nations data.\u001B[0m\n", + "\u001B[1m> Finished AgentExecutor chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The current population of Canada is 38,566,192 as of Saturday, December 31, 2022, based on Worldometer elaboration of the latest United Nations data.'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_chain.run(input=\"How many people live in canada?\")" + ] + }, + { + "cell_type": "markdown", + "id": "45627664", + "metadata": {}, + "source": [ + "To test the memory of this agent, we can ask a followup question that relies on information in the previous exchange to be answered correctly." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "eecc0462", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mThought: I need to find out what the national anthem of Canada is called.\n", + "Action: Search\n", + "Action Input: National Anthem of Canada\u001B[0m\n", + "Observation: \u001B[36;1m\u001B[1;3mJun 7, 2010 ... https://twitter.com/CanadaImmigrantCanadian National Anthem O Canada in HQ - complete with lyrics, captions, vocals & music.LYRICS:O Canada! Nov 23, 2022 ... After 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa ... O Canada, national anthem of Canada. It was proclaimed the official national anthem on July 1, 1980. “God Save the Queen” remains the royal anthem of Canada ... O Canada! Our home and native land! True patriot love in all of us command. Car ton bras sait porter l'épée,. Il sait porter la croix! \"O Canada\" (French: Ô Canada) is the national anthem of Canada. The song was originally commissioned by Lieutenant Governor of Quebec Théodore Robitaille ... Feb 1, 2018 ... It was a simple tweak — just two words. But with that, Canada just voted to make its national anthem, “O Canada,” gender neutral, ... \"O Canada\" was proclaimed Canada's national anthem on July 1,. 1980, 100 years after it was first sung on June 24, 1880. The music. Patriotic music in Canada dates back over 200 years as a distinct category from British or French patriotism, preceding the first legal steps to ... Feb 4, 2022 ... English version: O Canada! Our home and native land! True patriot love in all of us command. With glowing hearts we ... Feb 1, 2018 ... Canada's Senate has passed a bill making the country's national anthem gender-neutral. If you're not familiar with the words to “O Canada,” ...\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3m I now know the final answer.\n", + "Final Answer: The national anthem of Canada is called \"O Canada\".\u001B[0m\n", + "\u001B[1m> Finished AgentExecutor chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The national anthem of Canada is called \"O Canada\".'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_chain.run(input=\"what is their national anthem called?\")" + ] + }, + { + "cell_type": "markdown", + "id": "cc3d0aa4", + "metadata": {}, + "source": [ + "We can see that the agent remembered that the previous question was about Canada, and properly asked Google Search what the name of Canada's national anthem was.\n", + "\n", + "For fun, let's compare this to an agent that does NOT have memory." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "3359d043", + "metadata": {}, + "outputs": [], + "source": [ + "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n", + "suffix = \"\"\"Begin!\"\n", + "\n", + "Question: {input}\n", + "{agent_scratchpad}\"\"\"\n", + "\n", + "prompt = ZeroShotAgent.create_prompt(\n", + " tools, \n", + " prefix=prefix, \n", + " suffix=suffix, \n", + " input_variables=[\"input\", \"agent_scratchpad\"]\n", + ")\n", + "llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)\n", + "agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)\n", + "agent_without_memory = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "970d23df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mThought: I need to find out the population of Canada\n", + "Action: Search\n", + "Action Input: Population of Canada\u001B[0m\n", + "Observation: \u001B[36;1m\u001B[1;3mThe current population of Canada is 38,566,192 as of Saturday, December 31, 2022, based on Worldometer elaboration of the latest United Nations data. · Canada ... Additional information related to Canadian population trends can be found on Statistics Canada's Population and Demography Portal. Population of Canada (real- ... Index to the latest information from the Census of Population. This survey conducted by Statistics Canada provides a statistical portrait of Canada and its ... 14 records ... Estimated number of persons by quarter of a year and by year, Canada, provinces and territories. The 2021 Canadian census counted a total population of 36,991,981, an increase of around 5.2 percent over the 2016 figure. ... Between 1990 and 2008, the ... ( 2 ) Census reports and other statistical publications from national statistical offices, ( 3 ) Eurostat: Demographic Statistics, ( 4 ) United Nations ... Canada is a country in North America. Its ten provinces and three territories extend from ... Population. • Q4 2022 estimate. 39,292,355 (37th). Information is available for the total Indigenous population and each of the three ... The term 'Aboriginal' or 'Indigenous' used on the Statistics Canada ... Jun 14, 2022 ... Determinants of health are the broad range of personal, social, economic and environmental factors that determine individual and population ... COVID-19 vaccination coverage across Canada by demographics and key populations. Updated every Friday at 12:00 PM Eastern Time.\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3m I now know the final answer\n", + "Final Answer: The current population of Canada is 38,566,192 as of Saturday, December 31, 2022, based on Worldometer elaboration of the latest United Nations data.\u001B[0m\n", + "\u001B[1m> Finished AgentExecutor chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The current population of Canada is 38,566,192 as of Saturday, December 31, 2022, based on Worldometer elaboration of the latest United Nations data.'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_without_memory.run(\"How many people live in canada?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "d9ea82f0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3mThought: I should look up the answer\n", + "Action: Search\n", + "Action Input: national anthem of [country]\u001B[0m\n", + "Observation: \u001B[36;1m\u001B[1;3mMost nation states have an anthem, defined as \"a song, as of praise, devotion, or patriotism\"; most anthems are either marches or hymns in style. List of all countries around the world with its national anthem. ... Title and lyrics in the language of the country and translated into English, Aug 1, 2021 ... 1. Afghanistan, \"Milli Surood\" (National Anthem) · 2. Armenia, \"Mer Hayrenik\" (Our Fatherland) · 3. Azerbaijan (a transcontinental country with ... A national anthem is a patriotic musical composition symbolizing and evoking eulogies of the history and traditions of a country or nation. National Anthem of Every Country ; Fiji, “Meda Dau Doka” (“God Bless Fiji”) ; Finland, “Maamme”. (“Our Land”) ; France, “La Marseillaise” (“The Marseillaise”). You can find an anthem in the menu at the top alphabetically or you can use the search feature. This site is focussed on the scholarly study of national anthems ... Feb 13, 2022 ... The 38-year-old country music artist had the honor of singing the National Anthem during this year's big game, and she did not disappoint. Oldest of the World's National Anthems ; France, La Marseillaise (“The Marseillaise”), 1795 ; Argentina, Himno Nacional Argentino (“Argentine National Anthem”) ... Mar 3, 2022 ... Country music star Jessie James Decker gained the respect of music and hockey fans alike after a jaw-dropping rendition of \"The Star-Spangled ... This list shows the country on the left, the national anthem in the ... There are many countries over the world who have a national anthem of their own.\u001B[0m\n", + "Thought:\u001B[32;1m\u001B[1;3m I now know the final answer\n", + "Final Answer: The national anthem of [country] is [name of anthem].\u001B[0m\n", + "\u001B[1m> Finished AgentExecutor chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The national anthem of [country] is [name of anthem].'" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_without_memory.run(\"what is their national anthem called?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b1f9223", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/modules/memory/examples/redis_chat_message_history.ipynb b/docs/modules/memory/examples/redis_chat_message_history.ipynb new file mode 100644 index 00000000..e4876131 --- /dev/null +++ b/docs/modules/memory/examples/redis_chat_message_history.ipynb @@ -0,0 +1,81 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "91c6a7ef", + "metadata": {}, + "source": [ + "# Redis Chat Message History\n", + "\n", + "This notebook goes over how to use Redis to store chat message history." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d15e3302", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.memory import RedisChatMessageHistory\n", + "\n", + "history = RedisChatMessageHistory(\"foo\")\n", + "\n", + "history.add_user_message(\"hi!\")\n", + "\n", + "history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "64fc465e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[AIMessage(content='whats up?', additional_kwargs={}),\n", + " HumanMessage(content='hi!', additional_kwargs={})]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "history.messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8af285f8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/modules/memory/how_to_guides.rst b/docs/modules/memory/how_to_guides.rst index 6c36cd2f..4994c0ad 100644 --- a/docs/modules/memory/how_to_guides.rst +++ b/docs/modules/memory/how_to_guides.rst @@ -18,7 +18,6 @@ Usage ----- The examples here all highlight how to use memory in different ways. - .. toctree:: :maxdepth: 1 :glob: diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py index 65fe5d62..5799e734 100644 --- a/langchain/memory/__init__.py +++ b/langchain/memory/__init__.py @@ -3,7 +3,9 @@ from langchain.memory.buffer import ( ConversationStringBufferMemory, ) from langchain.memory.buffer_window import ConversationBufferWindowMemory -from langchain.memory.chat_memory import ChatMessageHistory +from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory +from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory +from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.combined import CombinedMemory from langchain.memory.entity import ConversationEntityMemory from langchain.memory.kg import ConversationKGMemory @@ -26,4 +28,6 @@ __all__ = [ "ConversationStringBufferMemory", "ReadOnlySharedMemory", "ConversationTokenBufferMemory", + "RedisChatMessageHistory", + "DynamoDBChatMessageHistory", ] diff --git a/langchain/memory/chat_memory.py b/langchain/memory/chat_memory.py index 0dc80103..3fbf35e7 100644 --- a/langchain/memory/chat_memory.py +++ b/langchain/memory/chat_memory.py @@ -1,27 +1,18 @@ from abc import ABC -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional -from pydantic import BaseModel, Field +from pydantic import Field +from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory from langchain.memory.utils import get_prompt_input_key -from langchain.schema import AIMessage, BaseMemory, BaseMessage, HumanMessage - - -class ChatMessageHistory(BaseModel): - messages: List[BaseMessage] = Field(default_factory=list) - - def add_user_message(self, message: str) -> None: - self.messages.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.messages.append(AIMessage(content=message)) - - def clear(self) -> None: - self.messages = [] +from langchain.schema import ( + BaseChatMessageHistory, + BaseMemory, +) class BaseChatMemory(BaseMemory, ABC): - chat_memory: ChatMessageHistory = Field(default_factory=ChatMessageHistory) + chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory) output_key: Optional[str] = None input_key: Optional[str] = None return_messages: bool = False diff --git a/langchain/memory/chat_message_histories/__init__.py b/langchain/memory/chat_message_histories/__init__.py new file mode 100644 index 00000000..199844ef --- /dev/null +++ b/langchain/memory/chat_message_histories/__init__.py @@ -0,0 +1,7 @@ +from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory +from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory + +__all__ = [ + "DynamoDBChatMessageHistory", + "RedisChatMessageHistory", +] diff --git a/langchain/memory/chat_message_histories/dynamodb.py b/langchain/memory/chat_message_histories/dynamodb.py new file mode 100644 index 00000000..413183ea --- /dev/null +++ b/langchain/memory/chat_message_histories/dynamodb.py @@ -0,0 +1,84 @@ +import logging +from typing import List + +from langchain.schema import ( + AIMessage, + BaseChatMessageHistory, + BaseMessage, + HumanMessage, + _message_to_dict, + messages_from_dict, + messages_to_dict, +) + +logger = logging.getLogger(__name__) + + +class DynamoDBChatMessageHistory(BaseChatMessageHistory): + """Chat message history that stores history in AWS DynamoDB. + This class expects that a DynamoDB table with name `table_name` + and a partition Key of `SessionId` is present. + + Args: + table_name: name of the DynamoDB table + session_id: arbitrary key that is used to store the messages + of a single chat session. + """ + + def __init__(self, table_name: str, session_id: str): + import boto3 + + client = boto3.resource("dynamodb") + self.table = client.Table(table_name) + self.session_id = session_id + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the messages from DynamoDB""" + from botocore.exceptions import ClientError + + try: + response = self.table.get_item(Key={"SessionId": self.session_id}) + except ClientError as error: + if error.response["Error"]["Code"] == "ResourceNotFoundException": + logger.warning("No record found with session id: %s", self.session_id) + else: + logger.error(error) + + if response and "Item" in response: + items = response["Item"]["History"] + else: + items = [] + + messages = messages_from_dict(items) + return messages + + def add_user_message(self, message: str) -> None: + self.append(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + self.append(AIMessage(content=message)) + + def append(self, message: BaseMessage) -> None: + """Append the message to the record in DynamoDB""" + from botocore.exceptions import ClientError + + messages = messages_to_dict(self.messages) + _message = _message_to_dict(message) + messages.append(_message) + + try: + self.table.put_item( + Item={"SessionId": self.session_id, "History": messages} + ) + except ClientError as err: + logger.error(err) + + def clear(self) -> None: + """Clear session memory from DynamoDB""" + from botocore.exceptions import ClientError + + try: + self.table.delete_item(Key={"SessionId": self.session_id}) + except ClientError as err: + logger.error(err) diff --git a/langchain/memory/chat_message_histories/in_memory.py b/langchain/memory/chat_message_histories/in_memory.py new file mode 100644 index 00000000..0760bd3c --- /dev/null +++ b/langchain/memory/chat_message_histories/in_memory.py @@ -0,0 +1,23 @@ +from typing import List + +from pydantic import BaseModel + +from langchain.schema import ( + AIMessage, + BaseChatMessageHistory, + BaseMessage, + HumanMessage, +) + + +class ChatMessageHistory(BaseChatMessageHistory, BaseModel): + messages: List[BaseMessage] = [] + + def add_user_message(self, message: str) -> None: + self.messages.append(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + self.messages.append(AIMessage(content=message)) + + def clear(self) -> None: + self.messages = [] diff --git a/langchain/memory/chat_message_histories/redis.py b/langchain/memory/chat_message_histories/redis.py new file mode 100644 index 00000000..86c025c7 --- /dev/null +++ b/langchain/memory/chat_message_histories/redis.py @@ -0,0 +1,69 @@ +import json +import logging +from typing import List, Optional + +from langchain.schema import ( + AIMessage, + BaseChatMessageHistory, + BaseMessage, + HumanMessage, + _message_to_dict, + messages_from_dict, +) + +logger = logging.getLogger(__name__) + + +class RedisChatMessageHistory(BaseChatMessageHistory): + def __init__( + self, + session_id: str, + url: str = "redis://localhost:6379/0", + key_prefix: str = "message_store:", + ttl: Optional[int] = None, + ): + try: + import redis + except ImportError: + raise ValueError( + "Could not import redis python package. " + "Please install it with `pip install redis`." + ) + + try: + self.redis_client = redis.Redis.from_url(url=url) + except redis.exceptions.ConnectionError as error: + logger.error(error) + + self.session_id = session_id + self.key_prefix = key_prefix + self.ttl = ttl + + @property + def key(self) -> str: + """Construct the record key to use""" + return self.key_prefix + self.session_id + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the messages from Redis""" + _items = self.redis_client.lrange(self.key, 0, -1) + items = [json.loads(m.decode("utf-8")) for m in _items] + messages = messages_from_dict(items) + return messages + + def add_user_message(self, message: str) -> None: + self.append(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + self.append(AIMessage(content=message)) + + def append(self, message: BaseMessage) -> None: + """Append the message to the record in Redis""" + self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message))) + if self.ttl: + self.redis_client.expire(self.key, self.ttl) + + def clear(self) -> None: + """Clear session memory from Redis""" + self.redis_client.delete(self.key) diff --git a/langchain/schema.py b/langchain/schema.py index 8d4676ae..14ec1dfa 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -240,6 +240,57 @@ class BaseMemory(BaseModel, ABC): """Clear memory contents.""" +class BaseChatMessageHistory(ABC): + """Base interface for chat message history + See `ChatMessageHistory` for default implementation. + """ + + """ + Example: + .. code-block:: python + + class FileChatMessageHistory(BaseChatMessageHistory): + storage_path: str + session_id: str + + @property + def messages(self): + with open(os.path.join(storage_path, session_id), 'r:utf-8') as f: + messages = json.loads(f.read()) + return messages_from_dict(messages) + + def add_user_message(self, message: str): + message_ = HumanMessage(content=message) + messages = self.messages.append(_message_to_dict(_message)) + with open(os.path.join(storage_path, session_id), 'w') as f: + json.dump(f, messages) + + def add_ai_message(self, message: str): + message_ = AIMessage(content=message) + messages = self.messages.append(_message_to_dict(_message)) + with open(os.path.join(storage_path, session_id), 'w') as f: + json.dump(f, messages) + + def clear(self): + with open(os.path.join(storage_path, session_id), 'w') as f: + f.write("[]") + """ + + messages: List[BaseMessage] + + @abstractmethod + def add_user_message(self, message: str) -> None: + """Add a user message to the store""" + + @abstractmethod + def add_ai_message(self, message: str) -> None: + """Add an AI message to the store""" + + @abstractmethod + def clear(self) -> None: + """Remove all messages from the store""" + + class Document(BaseModel): """Interface for interacting with a document.""" diff --git a/poetry.lock b/poetry.lock index 6bc3a0aa..a815ba80 100644 --- a/poetry.lock +++ b/poetry.lock @@ -12,48 +12,6 @@ files = [ {file = "absl_py-1.4.0-py3-none-any.whl", hash = "sha256:0d3fe606adfa4f7db64792dd4c7aee4ee0c38ab75dfd353b7a83ed3e957fcb47"}, ] -[[package]] -name = "aioboto3" -version = "10.4.0" -description = "Async boto3 wrapper" -category = "main" -optional = true -python-versions = ">=3.7,<4.0" -files = [ - {file = "aioboto3-10.4.0-py3-none-any.whl", hash = "sha256:6d0f0bf6af0168c27828e108f1a24182669a6ea6939437c27638caf06a693403"}, - {file = "aioboto3-10.4.0.tar.gz", hash = "sha256:e52b5f96b67031ddcbabcc55015bad3f851d3d4e6d5bfc7a1d1518d90e0c1fd8"}, -] - -[package.dependencies] -aiobotocore = {version = "2.4.2", extras = ["boto3"]} - -[package.extras] -chalice = ["chalice (>=1.24.0)"] -s3cse = ["cryptography (>=2.3.1)"] - -[[package]] -name = "aiobotocore" -version = "2.4.2" -description = "Async client for aws services using botocore and aiohttp" -category = "main" -optional = true -python-versions = ">=3.7" -files = [ - {file = "aiobotocore-2.4.2-py3-none-any.whl", hash = "sha256:4acd1ebe2e44be4b100aa553910bda899f6dc090b3da2bc1cf3d5de2146ed208"}, - {file = "aiobotocore-2.4.2.tar.gz", hash = "sha256:0603b74a582dffa7511ce7548d07dc9b10ec87bc5fb657eb0b34f9bd490958bf"}, -] - -[package.dependencies] -aiohttp = ">=3.3.1" -aioitertools = ">=0.5.1" -boto3 = {version = ">=1.24.59,<1.24.60", optional = true, markers = "extra == \"boto3\""} -botocore = ">=1.27.59,<1.27.60" -wrapt = ">=1.10.10" - -[package.extras] -awscli = ["awscli (>=1.25.60,<1.25.61)"] -boto3 = ["boto3 (>=1.24.59,<1.24.60)"] - [[package]] name = "aiodns" version = "3.0.0" @@ -205,21 +163,6 @@ files = [ [package.dependencies] aiohttp = "*" -[[package]] -name = "aioitertools" -version = "0.11.0" -description = "itertools and builtins for AsyncIO and mixed iterables" -category = "main" -optional = true -python-versions = ">=3.6" -files = [ - {file = "aioitertools-0.11.0-py3-none-any.whl", hash = "sha256:04b95e3dab25b449def24d7df809411c10e62aab0cbe31a50ca4e68748c43394"}, - {file = "aioitertools-0.11.0.tar.gz", hash = "sha256:42c68b8dd3a69c2bf7f2233bf7df4bb58b557bca5252ac02ed5187bbc67d6831"}, -] - -[package.dependencies] -typing_extensions = {version = ">=4.0", markers = "python_version < \"3.10\""} - [[package]] name = "aiosignal" version = "1.3.1" @@ -698,18 +641,18 @@ numpy = ">=1.15.0" [[package]] name = "boto3" -version = "1.24.59" +version = "1.26.101" description = "The AWS SDK for Python" category = "main" optional = true python-versions = ">= 3.7" files = [ - {file = "boto3-1.24.59-py3-none-any.whl", hash = "sha256:34ab44146a2c4e7f4e72737f4b27e6eb5e0a7855c2f4599e3d9199b6a0a2d575"}, - {file = "boto3-1.24.59.tar.gz", hash = "sha256:a50b4323f9579cfe22fcf5531fbd40b567d4d74c1adce06aeb5c95fce2a6fb40"}, + {file = "boto3-1.26.101-py3-none-any.whl", hash = "sha256:5f5279a63b359ba8889e9a81b319e745b14216608ffb5a39fcbf269d1af1ea83"}, + {file = "boto3-1.26.101.tar.gz", hash = "sha256:670ae4d1875a2162e11c6e941888817c3e9cf1bb9a3335b3588d805b7d24da31"}, ] [package.dependencies] -botocore = ">=1.27.59,<1.28.0" +botocore = ">=1.29.101,<1.30.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.6.0,<0.7.0" @@ -718,14 +661,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.27.59" +version = "1.29.101" description = "Low-level, data-driven core of boto 3." category = "main" optional = true python-versions = ">= 3.7" files = [ - {file = "botocore-1.27.59-py3-none-any.whl", hash = "sha256:69d756791fc024bda54f6c53f71ae34e695ee41bbbc1743d9179c4837a4929da"}, - {file = "botocore-1.27.59.tar.gz", hash = "sha256:eda4aed6ee719a745d1288eaf1beb12f6f6448ad1fa12f159405db14ba9c92cf"}, + {file = "botocore-1.29.101-py3-none-any.whl", hash = "sha256:60c7a7bf8e2a288735e507007a6769be03dc24815f7dc5c7b59b12743f4a31cf"}, + {file = "botocore-1.29.101.tar.gz", hash = "sha256:7bb60d9d4c49500df55dfb6005c16002703333ff5f69dada565167c8d493dfd5"}, ] [package.dependencies] @@ -734,7 +677,7 @@ python-dateutil = ">=2.1,<3.0.0" urllib3 = ">=1.25.4,<1.27" [package.extras] -crt = ["awscrt (==0.14.0)"] +crt = ["awscrt (==0.16.9)"] [[package]] name = "cachetools" @@ -1220,21 +1163,20 @@ files = [ [[package]] name = "deeplake" -version = "3.2.18" +version = "3.2.15" description = "Activeloop Deep Lake" category = "main" optional = true python-versions = "*" files = [ - {file = "deeplake-3.2.18.tar.gz", hash = "sha256:cb381fc771b08b32415efbb88c4adb57fc54ffa01f19d86b76dd6f839108799a"}, + {file = "deeplake-3.2.15.tar.gz", hash = "sha256:c9a72fa059ee106a90592a6bab411f88f6efcfaa32ed4d4298def59ce737c762"}, ] [package.dependencies] -aioboto3 = {version = "10.4.0", markers = "python_version >= \"3.7\" and sys_platform != \"win32\""} boto3 = "*" click = "*" +hub = ">=2.8.7" humbug = ">=0.2.6" -nest_asyncio = {version = "*", markers = "python_version >= \"3.7\" and sys_platform != \"win32\""} numcodecs = "*" numpy = "*" pathos = "*" @@ -1243,11 +1185,11 @@ pyjwt = "*" tqdm = "*" [package.extras] -all = ["IPython", "av (>=8.1.0)", "flask", "google-api-python-client (>=2.31.0,<2.32.0)", "google-auth (>=2.0.1,<2.1.0)", "google-auth-oauthlib (>=0.4.5,<0.5.0)", "google-cloud-storage (>=1.42.0,<1.43.0)", "laspy", "libdeeplake (==0.0.41)", "nibabel", "oauth2client (>=4.1.3,<4.2.0)", "pydicom"] +all = ["IPython", "av (>=8.1.0)", "flask", "google-api-python-client (>=2.31.0,<2.32.0)", "google-auth (>=2.0.1,<2.1.0)", "google-auth-oauthlib (>=0.4.5,<0.5.0)", "google-cloud-storage (>=1.42.0,<1.43.0)", "laspy", "libdeeplake (==0.0.40)", "nibabel", "oauth2client (>=4.1.3,<4.2.0)", "pydicom"] audio = ["av (>=8.1.0)"] av = ["av (>=8.1.0)"] dicom = ["nibabel", "pydicom"] -enterprise = ["libdeeplake (==0.0.41)", "pyjwt"] +enterprise = ["libdeeplake (==0.0.40)", "pyjwt"] gcp = ["google-auth (>=2.0.1,<2.1.0)", "google-auth-oauthlib (>=0.4.5,<0.5.0)", "google-cloud-storage (>=1.42.0,<1.43.0)"] gdrive = ["google-api-python-client (>=2.31.0,<2.32.0)", "google-auth (>=2.0.1,<2.1.0)", "google-auth-oauthlib (>=0.4.5,<0.5.0)", "oauth2client (>=4.1.3,<4.2.0)"] medical = ["nibabel", "pydicom"] @@ -2366,6 +2308,21 @@ cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<13)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (>=1.0.0,<2.0.0)"] +[[package]] +name = "hub" +version = "3.0.1" +description = "Activeloop Deep Lake" +category = "main" +optional = true +python-versions = "*" +files = [ + {file = "hub-3.0.1-py3-none-any.whl", hash = "sha256:16d20fbf44700b438dc372d697ee683f0d7c58b178ad01d2daf81efe88bc692d"}, + {file = "hub-3.0.1.tar.gz", hash = "sha256:3866425914ed522090f0634887f06ff77517e8e6d7b9370e42009d774b725514"}, +] + +[package.dependencies] +deeplake = "*" + [[package]] name = "huggingface-hub" version = "0.13.3" @@ -3887,7 +3844,7 @@ traitlets = ">=5" name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" -category = "main" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -8535,4 +8492,4 @@ llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifes [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "b8dd776b9c9bfda2413bdc0c58eb9cc8975a0ec770a830af5568cc80e4a1925d" +content-hash = "3dd7ff0edb145aff1ba1ea7e35ffa4b224fb71f934be0e13d4427fd796e54869" diff --git a/pyproject.toml b/pyproject.toml index 1374f5ec..a2feef4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ aleph-alpha-client = {version="^2.15.0", optional = true} deeplake = {version = "^3.2.9", optional = true} pgvector = {version = "^0.1.6", optional = true} psycopg2-binary = {version = "^2.9.5", optional = true} +boto3 = {version = "^1.26.96", optional = true} pyowm = {version = "^3.3.0", optional = true} [tool.poetry.group.docs.dependencies] diff --git a/tests/integration_tests/memory/__init__.py b/tests/integration_tests/memory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/memory/test_redis.py b/tests/integration_tests/memory/test_redis.py new file mode 100644 index 00000000..16ea653c --- /dev/null +++ b/tests/integration_tests/memory/test_redis.py @@ -0,0 +1,30 @@ +import json + +from langchain.memory import ConversationBufferMemory +from langchain.memory.chat_message_histories import RedisChatMessageHistory +from langchain.schema import _message_to_dict + + +def test_memory_with_message_store() -> None: + """Test the memory with a message store.""" + # setup Redis as a message store + message_history = RedisChatMessageHistory( + url="redis://localhost:6379/0", ttl=10, session_id="my-test-session" + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # add some messages + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # get the message history from the memory store and turn it into a json + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # remove the record from Redis, so the next test run won't pick it up + memory.chat_memory.clear()