From d4cf1eb60a83ac602bfd049fb2868bf68c0d6206 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 3 May 2023 22:55:47 -0700 Subject: [PATCH] Add firestore memory (#3792) (#3941) If you have any other suggestions or feedback, please let me know. --------- Co-authored-by: yakigac <10434946+yakigac@users.noreply.github.com> --- .../memory/chat_message_histories/__init__.py | 4 + .../chat_message_histories/firestore.py | 112 ++++++++++++++++++ .../memory/test_firestore.py | 43 +++++++ 3 files changed, 159 insertions(+) create mode 100644 langchain/memory/chat_message_histories/firestore.py create mode 100644 tests/integration_tests/memory/test_firestore.py diff --git a/langchain/memory/chat_message_histories/__init__.py b/langchain/memory/chat_message_histories/__init__.py index a891b57f..4a6768a9 100644 --- a/langchain/memory/chat_message_histories/__init__.py +++ b/langchain/memory/chat_message_histories/__init__.py @@ -1,6 +1,9 @@ from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory from langchain.memory.chat_message_histories.file import FileChatMessageHistory +from langchain.memory.chat_message_histories.firestore import ( + FirestoreChatMessageHistory, +) from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory @@ -12,4 +15,5 @@ __all__ = [ "SQLChatMessageHistory", "FileChatMessageHistory", "CosmosDBChatMessageHistory", + "FirestoreChatMessageHistory", ] diff --git a/langchain/memory/chat_message_histories/firestore.py b/langchain/memory/chat_message_histories/firestore.py new file mode 100644 index 00000000..15bbe253 --- /dev/null +++ b/langchain/memory/chat_message_histories/firestore.py @@ -0,0 +1,112 @@ +"""Firestore Chat Message History.""" +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, List, Optional + +from langchain.schema import ( + AIMessage, + BaseChatMessageHistory, + BaseMessage, + HumanMessage, + messages_from_dict, + messages_to_dict, +) + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from google.cloud.firestore import DocumentReference + + +class FirestoreChatMessageHistory(BaseChatMessageHistory): + """Chat history backed by Google Firestore.""" + + def __init__( + self, + collection_name: str, + session_id: str, + user_id: str, + ): + """ + Initialize a new instance of the FirestoreChatMessageHistory class. + + :param collection_name: The name of the collection to use. + :param session_id: The session ID for the chat.. + :param user_id: The user ID for the chat. + """ + self.collection_name = collection_name + self.session_id = session_id + self.user_id = user_id + + self._document: Optional[DocumentReference] = None + self.messages: List[BaseMessage] = [] + + self.prepare_firestore() + + def prepare_firestore(self) -> None: + """Prepare the Firestore client. + + Use this function to make sure your database is ready. + """ + try: + import firebase_admin + from firebase_admin import firestore + except ImportError as e: + logger.error( + "Failed to import Firebase and Firestore: %s. " + "Make sure to install the 'firebase-admin' module.", + e, + ) + raise e + + # For multiple instances, only initialize the app once. + try: + firebase_admin.get_app() + except ValueError as e: + logger.debug("Initializing Firebase app: %s", e) + firebase_admin.initialize_app() + + self.firestore_client = firestore.client() + self._document = self.firestore_client.collection( + self.collection_name + ).document(self.session_id) + self.load_messages() + + def load_messages(self) -> None: + """Retrieve the messages from Firestore""" + if not self._document: + raise ValueError("Document not initialized") + doc = self._document.get() + if doc.exists: + data = doc.to_dict() + if "messages" in data and len(data["messages"]) > 0: + self.messages = messages_from_dict(data["messages"]) + + def add_user_message(self, message: str) -> None: + """Add a user message to the memory.""" + self.upsert_messages(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + """Add a AI message to the memory.""" + self.upsert_messages(AIMessage(content=message)) + + def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None: + """Update the Firestore document.""" + if new_message: + self.messages.append(new_message) + if not self._document: + raise ValueError("Document not initialized") + self._document.set( + { + "id": self.session_id, + "user_id": self.user_id, + "messages": messages_to_dict(self.messages), + } + ) + + def clear(self) -> None: + """Clear session memory from this memory and Firestore.""" + self.messages = [] + if self._document: + self._document.delete() diff --git a/tests/integration_tests/memory/test_firestore.py b/tests/integration_tests/memory/test_firestore.py new file mode 100644 index 00000000..0391b39e --- /dev/null +++ b/tests/integration_tests/memory/test_firestore.py @@ -0,0 +1,43 @@ +import json + +from langchain.memory import ConversationBufferMemory +from langchain.memory.chat_message_histories import FirestoreChatMessageHistory +from langchain.schema import _message_to_dict + + +def test_memory_with_message_store() -> None: + """Test the memory with a message store.""" + + message_history = FirestoreChatMessageHistory( + collection_name="chat_history", + session_id="my-test-session", + user_id="my-test-user", + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # add some messages + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # get the message history from the memory store + # and check if the messages are there as expected + message_history = FirestoreChatMessageHistory( + collection_name="chat_history", + session_id="my-test-session", + user_id="my-test-user", + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # remove the record from Firestore, so the next test run won't pick it up + memory.chat_memory.clear() + + assert memory.chat_memory.messages == []