import json
import logging
from datetime import datetime
from typing import List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker

logger = logging.getLogger(__name__)


class TiDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a TiDB database.

    Messages of one chat session are persisted as JSON rows in a single
    table and mirrored in an in-memory list (``self.cache``).
    """

    def __init__(
        self,
        session_id: str,
        connection_string: str,
        table_name: str = "langchain_message_store",
        earliest_time: Optional[datetime] = None,
    ):
        """Initialize the history, create the table and load the cache.

        Args:
            session_id: The ID of the chat session.
            connection_string: SQLAlchemy connection string for the TiDB
                database, e.g.
                ``mysql+pymysql://<user>:<password>@<host>:4000/<db>?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true``
            table_name: Name of the table that stores the chat messages.
                It is interpolated into SQL statements as an identifier
                (identifiers cannot be bound parameters), so it must come
                from a trusted source.
            earliest_time: If set, only messages whose ``create_time`` is at
                or after this moment are loaded into the cache.
        """  # noqa
        self.session_id = session_id
        self.table_name = table_name
        self.earliest_time = earliest_time
        self.cache: List[BaseMessage] = []

        # Set up SQLAlchemy engine and session
        self.engine = create_engine(connection_string)
        Session = sessionmaker(bind=self.engine)
        self.session = Session()

        self._create_table_if_not_exists()
        self._load_messages_to_cache()

    def _create_table_if_not_exists(self) -> None:
        """Create the message table if it does not already exist.

        A ``SQLAlchemyError`` is logged and the transaction rolled back; it
        is not re-raised.
        """
        # NOTE: table_name is an identifier and cannot be a bound parameter;
        # it is interpolated verbatim and must be trusted.
        create_table_query = text(
            f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id INT AUTO_INCREMENT PRIMARY KEY,
                session_id VARCHAR(255) NOT NULL,
                message JSON NOT NULL,
                create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                INDEX session_idx (session_id)
            );"""
        )
        try:
            self.session.execute(create_table_query)
            self.session.commit()
        except SQLAlchemyError as e:
            logger.error("Error creating table: %s", e)
            self.session.rollback()

    def _load_messages_to_cache(self) -> None:
        """Append this session's stored messages to the cache, ordered by id.

        When ``earliest_time`` is set, only rows with ``create_time`` at or
        after it are selected. A ``SQLAlchemyError`` raised by the query is
        logged and the transaction rolled back; it is not re-raised and the
        cache is left untouched.
        """
        params = {"session_id": self.session_id}
        time_condition = ""
        if self.earliest_time:
            # Bind the timestamp instead of formatting it into the SQL text,
            # so the value can never be interpreted as SQL.
            time_condition = "AND create_time >= :earliest_time"
            params["earliest_time"] = self.earliest_time

        query = text(
            f"""
            SELECT message FROM {self.table_name}
            WHERE session_id = :session_id {time_condition}
            ORDER BY id;
            """
        )
        try:
            rows = self.session.execute(query, params).fetchall()
        except SQLAlchemyError as e:
            logger.error("Error loading messages to cache: %s", e)
            # Leave the session usable for subsequent statements.
            self.session.rollback()
            return

        # The JSON column normally arrives as text; pass through a value the
        # driver has already decoded.
        payloads = [
            json.loads(row[0])
            if isinstance(row[0], (str, bytes, bytearray))
            else row[0]
            for row in rows
        ]
        self.cache.extend(messages_from_dict(payloads))

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore[override]
        """Return all cached messages, reloading from the database if empty."""
        if not self.cache:
            self.reload_cache()
        return self.cache

    def add_message(self, message: BaseMessage) -> None:
        """Insert a message into the database and append it to the cache.

        The cache is only updated after a successful commit. A
        ``SQLAlchemyError`` is logged and the transaction rolled back.
        """
        query = text(
            f"INSERT INTO {self.table_name} (session_id, message) "
            "VALUES (:session_id, :message);"
        )
        try:
            self.session.execute(
                query,
                {
                    "session_id": self.session_id,
                    "message": json.dumps(message_to_dict(message)),
                },
            )
            self.session.commit()
            self.cache.append(message)
        except SQLAlchemyError as e:
            logger.error("Error adding message: %s", e)
            self.session.rollback()

    def clear(self) -> None:
        """Delete all messages of this session from the database and cache.

        A ``SQLAlchemyError`` is logged and the transaction rolled back; the
        cache is then left unchanged.
        """
        query = text(f"DELETE FROM {self.table_name} WHERE session_id = :session_id;")
        try:
            self.session.execute(query, {"session_id": self.session_id})
            self.session.commit()
            self.cache.clear()
        except SQLAlchemyError as e:
            logger.error("Error clearing messages: %s", e)
            self.session.rollback()

    def reload_cache(self) -> None:
        """Empty the cache and load the messages from the database again."""
        self.cache.clear()
        self._load_messages_to_cache()

    def __del__(self) -> None:
        """Close the database session, if one was created."""
        # __init__ may have failed before self.session was assigned.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()