From 647bbf61c186cebbec0e2772a3d98ac877a51a0d Mon Sep 17 00:00:00 2001 From: Zura Isakadze Date: Tue, 2 May 2023 02:40:00 +0400 Subject: [PATCH] Add SQLiteChatMessageHistory (#3534) It's based on the already existing `PostgresChatMessageHistory`. Use case: somewhere in between multiple files and Postgres storage. --- .../memory/chat_message_histories/__init__.py | 2 + .../memory/chat_message_histories/sql.py | 83 ++++++++++++++++++ .../memory/chat_message_histories/test_sql.py | 85 +++++++++++++++++++ 3 files changed, 170 insertions(+) create mode 100644 langchain/memory/chat_message_histories/sql.py create mode 100644 tests/unit_tests/memory/chat_message_histories/test_sql.py diff --git a/langchain/memory/chat_message_histories/__init__.py b/langchain/memory/chat_message_histories/__init__.py index 05805ece..a891b57f 100644 --- a/langchain/memory/chat_message_histories/__init__.py +++ b/langchain/memory/chat_message_histories/__init__.py @@ -3,11 +3,13 @@ from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessage from langchain.memory.chat_message_histories.file import FileChatMessageHistory from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory +from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory __all__ = [ "DynamoDBChatMessageHistory", "RedisChatMessageHistory", "PostgresChatMessageHistory", + "SQLChatMessageHistory", "FileChatMessageHistory", "CosmosDBChatMessageHistory", ] diff --git a/langchain/memory/chat_message_histories/sql.py b/langchain/memory/chat_message_histories/sql.py new file mode 100644 index 00000000..e3770133 --- /dev/null +++ b/langchain/memory/chat_message_histories/sql.py @@ -0,0 +1,83 @@ +import json +import logging +from typing import List + +from sqlalchemy import Column, Integer, Text, create_engine +from sqlalchemy.orm import declarative_base, sessionmaker + +from langchain.schema 
import ( + AIMessage, + BaseChatMessageHistory, + BaseMessage, + HumanMessage, + _message_to_dict, + messages_from_dict, +) + +logger = logging.getLogger(__name__) + + +def create_message_model(table_name, DynamicBase): # type: ignore + # Model declared inside a function to have a dynamic table name + class Message(DynamicBase): + __tablename__ = table_name + id = Column(Integer, primary_key=True) + session_id = Column(Text) + message = Column(Text) + + return Message + + +class SQLChatMessageHistory(BaseChatMessageHistory): + def __init__( + self, + session_id: str, + connection_string: str, + table_name: str = "message_store", + ): + self.table_name = table_name + self.connection_string = connection_string + self.engine = create_engine(connection_string, echo=False) + self._create_table_if_not_exists() + + self.session_id = session_id + self.Session = sessionmaker(self.engine) + + def _create_table_if_not_exists(self) -> None: + DynamicBase = declarative_base() + self.Message = create_message_model(self.table_name, DynamicBase) + # create_all() checks for us whether the table already exists. 
+ DynamicBase.metadata.create_all(self.engine) + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve all messages from db""" + with self.Session() as session: + result = session.query(self.Message).where( + self.Message.session_id == self.session_id + ) + items = [json.loads(record.message) for record in result] + messages = messages_from_dict(items) + return messages + + def add_user_message(self, message: str) -> None: + self.append(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + self.append(AIMessage(content=message)) + + def append(self, message: BaseMessage) -> None: + """Append the message to the record in db""" + with self.Session() as session: + jsonstr = json.dumps(_message_to_dict(message)) + session.add(self.Message(session_id=self.session_id, message=jsonstr)) + session.commit() + + def clear(self) -> None: + """Clear session memory from db""" + + with self.Session() as session: + session.query(self.Message).filter( + self.Message.session_id == self.session_id + ).delete() + session.commit() diff --git a/tests/unit_tests/memory/chat_message_histories/test_sql.py b/tests/unit_tests/memory/chat_message_histories/test_sql.py new file mode 100644 index 00000000..0299ad0a --- /dev/null +++ b/tests/unit_tests/memory/chat_message_histories/test_sql.py @@ -0,0 +1,85 @@ +from pathlib import Path +from typing import Tuple + +import pytest + +from langchain.memory.chat_message_histories import SQLChatMessageHistory +from langchain.schema import AIMessage, HumanMessage + + +# @pytest.fixture(params=[("SQLite"), ("postgresql")]) +@pytest.fixture(params=[("SQLite")]) +def sql_histories(request, tmp_path: Path): # type: ignore + if request.param == "SQLite": + file_path = tmp_path / "db.sqlite3" + con_str = f"sqlite:///{file_path}" + elif request.param == "postgresql": + con_str = "postgresql://postgres:postgres@localhost/postgres" + + message_history = SQLChatMessageHistory( + session_id="123", 
connection_string=con_str, table_name="test_table" + ) + # Create history for other session + other_history = SQLChatMessageHistory( + session_id="456", connection_string=con_str, table_name="test_table" + ) + + yield (message_history, other_history) + message_history.clear() + other_history.clear() + + +def test_add_messages( + sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory] +) -> None: + sql_history, other_history = sql_histories + sql_history.add_user_message("Hello!") + sql_history.add_ai_message("Hi there!") + + messages = sql_history.messages + assert len(messages) == 2 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert messages[0].content == "Hello!" + assert messages[1].content == "Hi there!" + + +def test_multiple_sessions( + sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory] +) -> None: + sql_history, other_history = sql_histories + sql_history.add_user_message("Hello!") + sql_history.add_ai_message("Hi there!") + sql_history.add_user_message("Whats cracking?") + + # Ensure the messages are added correctly in the first session + assert len(sql_history.messages) == 3, "waat" + assert sql_history.messages[0].content == "Hello!" + assert sql_history.messages[1].content == "Hi there!" + assert sql_history.messages[2].content == "Whats cracking?" + + # second session + other_history.add_user_message("Hellox") + assert len(other_history.messages) == 1 + assert len(sql_history.messages) == 3 + assert other_history.messages[0].content == "Hellox" + assert sql_history.messages[0].content == "Hello!" + assert sql_history.messages[1].content == "Hi there!" + assert sql_history.messages[2].content == "Whats cracking?" 
+ + +def test_clear_messages( + sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory] +) -> None: + sql_history, other_history = sql_histories + sql_history.add_user_message("Hello!") + sql_history.add_ai_message("Hi there!") + assert len(sql_history.messages) == 2 + # Now create another history with different session id + other_history.add_user_message("Hellox") + assert len(other_history.messages) == 1 + assert len(sql_history.messages) == 2 + # Now clear the first history + sql_history.clear() + assert len(sql_history.messages) == 0 + assert len(other_history.messages) == 1