From 8a338412fa471f4ef0151ed1480bfb8f7d949db3 Mon Sep 17 00:00:00 2001 From: Jinto Jose <129657162+jj701@users.noreply.github.com> Date: Mon, 8 May 2023 08:34:05 -0700 Subject: [PATCH] mongodb support for chat history (#4266) --- langchain/memory/__init__.py | 2 + .../memory/chat_message_histories/__init__.py | 2 + .../memory/chat_message_histories/mongodb.py | 98 +++++++++++++++++++ pyproject.toml | 2 + .../integration_tests/memory/test_mongodb.py | 36 +++++++ 5 files changed, 140 insertions(+) create mode 100644 langchain/memory/chat_message_histories/mongodb.py create mode 100644 tests/integration_tests/memory/test_mongodb.py diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py index 2a6f3495..b5e9950c 100644 --- a/langchain/memory/__init__.py +++ b/langchain/memory/__init__.py @@ -7,6 +7,7 @@ from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessag from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory from langchain.memory.chat_message_histories.file import FileChatMessageHistory from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory +from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.combined import CombinedMemory @@ -44,4 +45,5 @@ __all__ = [ "VectorStoreRetrieverMemory", "CosmosDBChatMessageHistory", "FileChatMessageHistory", + "MongoDBChatMessageHistory", ] diff --git a/langchain/memory/chat_message_histories/__init__.py b/langchain/memory/chat_message_histories/__init__.py index 4a6768a9..cb646aaa 100644 --- a/langchain/memory/chat_message_histories/__init__.py +++ b/langchain/memory/chat_message_histories/__init__.py @@ -4,6 +4,7 @@ from langchain.memory.chat_message_histories.file import FileChatMessageHistory from langchain.memory.chat_message_histories.firestore import ( FirestoreChatMessageHistory, ) +from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory @@ -16,4 +17,5 @@ __all__ = [ "FileChatMessageHistory", "CosmosDBChatMessageHistory", "FirestoreChatMessageHistory", + "MongoDBChatMessageHistory", ] diff --git a/langchain/memory/chat_message_histories/mongodb.py b/langchain/memory/chat_message_histories/mongodb.py new file mode 100644 index 00000000..7995609b --- /dev/null +++ b/langchain/memory/chat_message_histories/mongodb.py @@ -0,0 +1,98 @@ +import json +import logging +from typing import List + +from langchain.schema import ( + AIMessage, + BaseChatMessageHistory, + BaseMessage, + HumanMessage, + _message_to_dict, + messages_from_dict, +) + +logger = logging.getLogger(__name__) + +DEFAULT_DBNAME = "chat_history" +DEFAULT_COLLECTION_NAME = "message_store" + + +class MongoDBChatMessageHistory(BaseChatMessageHistory): + """Chat message history that stores history in MongoDB. + + Args: + connection_string: connection string to connect to MongoDB + session_id: arbitrary key that is used to store the messages + of a single chat session. + database_name: name of the database to use + collection_name: name of the collection to use + """ + + def __init__( + self, + connection_string: str, + session_id: str, + database_name: str = DEFAULT_DBNAME, + collection_name: str = DEFAULT_COLLECTION_NAME, + ): + from pymongo import MongoClient, errors + + self.connection_string = connection_string + self.session_id = session_id + self.database_name = database_name + self.collection_name = collection_name + + try: + self.client: MongoClient = MongoClient(connection_string) + except errors.ConnectionFailure as error: + logger.error(error) + + self.db = self.client[database_name] + self.collection = self.db[collection_name] + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the messages from MongoDB""" + from pymongo import errors + + try: + cursor = self.collection.find({"SessionId": self.session_id}) + except errors.OperationFailure as error: + logger.error(error) + + if cursor: + items = [json.loads(document["History"]) for document in cursor] + else: + items = [] + + messages = messages_from_dict(items) + return messages + + def add_user_message(self, message: str) -> None: + self.append(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + self.append(AIMessage(content=message)) + + def append(self, message: BaseMessage) -> None: + """Append the message to the record in MongoDB""" + from pymongo import errors + + try: + self.collection.insert_one( + { + "SessionId": self.session_id, + "History": json.dumps(_message_to_dict(message)), + } + ) + except errors.WriteError as err: + logger.error(err) + + def clear(self) -> None: + """Clear session memory from MongoDB""" + from pymongo import errors + + try: + self.collection.delete_many({"SessionId": self.session_id}) + except errors.WriteError as err: + logger.error(err) diff --git a/pyproject.toml b/pyproject.toml index a4660c68..07b632b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ pyvespa = {version = "^0.33.0", optional = true} O365 = {version = "^2.0.26", optional = true} jq = {version = "^1.4.1", optional = true} + [tool.poetry.group.docs.dependencies] autodoc_pydantic = "^1.8.0" myst_parser = "^0.18.1" @@ -130,6 +131,7 @@ sentence-transformers = "^2" gptcache = "^0.1.9" promptlayer = "^0.1.80" tair = "^1.3.3" +pymongo = "^4.3.3" [tool.poetry.group.lint.dependencies] ruff = "^0.0.249" diff --git a/tests/integration_tests/memory/test_mongodb.py b/tests/integration_tests/memory/test_mongodb.py new file mode 100644 index 00000000..9e1b0f00 --- /dev/null +++ b/tests/integration_tests/memory/test_mongodb.py @@ -0,0 +1,36 @@ +import json +import os + +from langchain.memory import ConversationBufferMemory +from langchain.memory.chat_message_histories import MongoDBChatMessageHistory +from langchain.schema import _message_to_dict + +# Replace these with your mongodb connection string +connection_string = os.environ["MONGODB_CONNECTION_STRING"] + + +def test_memory_with_message_store() -> None: + """Test the memory with a message store.""" + # setup MongoDB as a message store + message_history = MongoDBChatMessageHistory( + connection_string=connection_string, session_id="test-session" + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # add some messages + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # get the message history from the memory store and turn it into a json + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # remove the record from Azure Cosmos DB, so the next test run won't pick it up + memory.chat_memory.clear() + + assert memory.chat_memory.messages == []