@ -1,6 +1,6 @@
import json
import json
import logging
import logging
from typing import List
from typing import Dict, List, Optional
from langchain_core . chat_history import BaseChatMessageHistory
from langchain_core . chat_history import BaseChatMessageHistory
from langchain_core . messages import (
from langchain_core . messages import (
@ -14,6 +14,8 @@ logger = logging.getLogger(__name__)
DEFAULT_DBNAME = " chat_history "
DEFAULT_DBNAME = " chat_history "
DEFAULT_COLLECTION_NAME = " message_store "
DEFAULT_COLLECTION_NAME = " message_store "
DEFAULT_SESSION_ID_KEY = " SessionId "
DEFAULT_HISTORY_KEY = " History "
class MongoDBChatMessageHistory ( BaseChatMessageHistory ) :
class MongoDBChatMessageHistory ( BaseChatMessageHistory ) :
@ -25,6 +27,10 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
of a single chat session .
of a single chat session .
database_name : name of the database to use
database_name : name of the database to use
collection_name : name of the collection to use
collection_name : name of the collection to use
session_id_key : name of the field that stores the session id
history_key : name of the field that stores the chat history
create_index : whether to create an index on the session id field
index_kwargs : additional keyword arguments to pass to the index creation
"""
"""
def __init__ (
def __init__ (
@ -33,11 +39,18 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
session_id : str ,
session_id : str ,
database_name : str = DEFAULT_DBNAME ,
database_name : str = DEFAULT_DBNAME ,
collection_name : str = DEFAULT_COLLECTION_NAME ,
collection_name : str = DEFAULT_COLLECTION_NAME ,
* ,
session_id_key : str = DEFAULT_SESSION_ID_KEY ,
history_key : str = DEFAULT_HISTORY_KEY ,
create_index : bool = True ,
index_kwargs : Optional [ Dict ] = None ,
) :
) :
self . connection_string = connection_string
self . connection_string = connection_string
self . session_id = session_id
self . session_id = session_id
self . database_name = database_name
self . database_name = database_name
self . collection_name = collection_name
self . collection_name = collection_name
self . session_id_key = session_id_key
self . history_key = history_key
try :
try :
self . client : MongoClient = MongoClient ( connection_string )
self . client : MongoClient = MongoClient ( connection_string )
@ -46,18 +59,21 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
self . db = self . client [ database_name ]
self . db = self . client [ database_name ]
self . collection = self . db [ collection_name ]
self . collection = self . db [ collection_name ]
self . collection . create_index ( " SessionId " )
if create_index :
index_kwargs = index_kwargs or { }
self . collection . create_index ( self . session_id_key , * * index_kwargs )
@property
@property
def messages ( self ) - > List [ BaseMessage ] : # type: ignore
def messages ( self ) - > List [ BaseMessage ] : # type: ignore
""" Retrieve the messages from MongoDB """
""" Retrieve the messages from MongoDB """
try :
try :
cursor = self . collection . find ( { " SessionId " : self . session_id } )
cursor = self . collection . find ( { self . session_id_key : self . session_id } )
except errors . OperationFailure as error :
except errors . OperationFailure as error :
logger . error ( error )
logger . error ( error )
if cursor :
if cursor :
items = [ json . loads ( document [ " History " ] ) for document in cursor ]
items = [ json . loads ( document [ self . history_key ] ) for document in cursor ]
else :
else :
items = [ ]
items = [ ]
@ -69,8 +85,8 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
try :
try :
self . collection . insert_one (
self . collection . insert_one (
{
{
" SessionId " : self . session_id ,
self . session_id_key : self . session_id ,
" History " : json . dumps ( message_to_dict ( message ) ) ,
self . history_key : json . dumps ( message_to_dict ( message ) ) ,
}
}
)
)
except errors . WriteError as err :
except errors . WriteError as err :
@ -79,6 +95,6 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
def clear ( self ) - > None :
def clear ( self ) - > None :
""" Clear session memory from MongoDB """
""" Clear session memory from MongoDB """
try :
try :
self . collection . delete_many ( { " SessionId " : self . session_id } )
self . collection . delete_many ( { self . session_id_key : self . session_id } )
except errors . WriteError as err :
except errors . WriteError as err :
logger . error ( err )
logger . error ( err )