diff --git a/libs/partners/mongodb/langchain_mongodb/chat_message_histories.py b/libs/partners/mongodb/langchain_mongodb/chat_message_histories.py index 38c3a9271e..d50538b3c0 100644 --- a/libs/partners/mongodb/langchain_mongodb/chat_message_histories.py +++ b/libs/partners/mongodb/langchain_mongodb/chat_message_histories.py @@ -1,6 +1,6 @@ import json import logging -from typing import List +from typing import Dict, List, Optional from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import ( @@ -14,6 +14,8 @@ logger = logging.getLogger(__name__) DEFAULT_DBNAME = "chat_history" DEFAULT_COLLECTION_NAME = "message_store" +DEFAULT_SESSION_ID_KEY = "SessionId" +DEFAULT_HISTORY_KEY = "History" class MongoDBChatMessageHistory(BaseChatMessageHistory): @@ -25,6 +27,10 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): of a single chat session. database_name: name of the database to use collection_name: name of the collection to use + session_id_key: name of the field that stores the session id + history_key: name of the field that stores the chat history + create_index: whether to create an index on the session id field + index_kwargs: additional keyword arguments to pass to the index creation """ def __init__( @@ -33,11 +39,18 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): session_id: str, database_name: str = DEFAULT_DBNAME, collection_name: str = DEFAULT_COLLECTION_NAME, + *, + session_id_key: str = DEFAULT_SESSION_ID_KEY, + history_key: str = DEFAULT_HISTORY_KEY, + create_index: bool = True, + index_kwargs: Optional[Dict] = None, ): self.connection_string = connection_string self.session_id = session_id self.database_name = database_name self.collection_name = collection_name + self.session_id_key = session_id_key + self.history_key = history_key try: self.client: MongoClient = MongoClient(connection_string) @@ -46,18 +59,21 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): self.db = self.client[database_name] self.collection = self.db[collection_name] - self.collection.create_index("SessionId") + + if create_index: + index_kwargs = index_kwargs or {} + self.collection.create_index(self.session_id_key, **index_kwargs) @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve the messages from MongoDB""" try: - cursor = self.collection.find({"SessionId": self.session_id}) + cursor = self.collection.find({self.session_id_key: self.session_id}) except errors.OperationFailure as error: logger.error(error) if cursor: - items = [json.loads(document["History"]) for document in cursor] + items = [json.loads(document[self.history_key]) for document in cursor] else: items = [] @@ -69,8 +85,8 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): try: self.collection.insert_one( { - "SessionId": self.session_id, - "History": json.dumps(message_to_dict(message)), + self.session_id_key: self.session_id, + self.history_key: json.dumps(message_to_dict(message)), } ) except errors.WriteError as err: @@ -79,6 +95,6 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): def clear(self) -> None: """Clear session memory from MongoDB""" try: - self.collection.delete_many({"SessionId": self.session_id}) + self.collection.delete_many({self.session_id_key: self.session_id}) except errors.WriteError as err: logger.error(err) diff --git a/libs/partners/mongodb/tests/unit_tests/test_chat_message_histories.py b/libs/partners/mongodb/tests/unit_tests/test_chat_message_histories.py index 2c1889a43a..9c89c6d424 100644 --- a/libs/partners/mongodb/tests/unit_tests/test_chat_message_histories.py +++ b/libs/partners/mongodb/tests/unit_tests/test_chat_message_histories.py @@ -13,6 +13,8 @@ class PatchedMongoDBChatMessageHistory(MongoDBChatMessageHistory): self.database_name = "test-database" self.collection_name = "test-collection" self.collection = MockCollection() + self.session_id_key = "SessionId" + self.history_key = "History" def test_memory_with_message_store() -> None: