Add SQLiteChatMessageHistory (#3534)

It's based on already existing `PostgresChatMessageHistory`

Use case somewhere in between multiple files and Postgres storage.
This commit is contained in:
Zura Isakadze 2023-05-02 02:40:00 +04:00 committed by GitHub
parent 921894960b
commit 647bbf61c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 170 additions and 0 deletions

View File

@ -3,11 +3,13 @@ from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessage
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
__all__ = [
"DynamoDBChatMessageHistory",
"RedisChatMessageHistory",
"PostgresChatMessageHistory",
"SQLChatMessageHistory",
"FileChatMessageHistory",
"CosmosDBChatMessageHistory",
]

View File

@ -0,0 +1,83 @@
import json
import logging
from typing import List
from sqlalchemy import Column, Integer, Text, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
logger = logging.getLogger(__name__)
def create_message_model(table_name, DynamicBase): # type: ignore
# Model decleared inside a function to have a dynamic table name
class Message(DynamicBase):
__tablename__ = table_name
id = Column(Integer, primary_key=True)
session_id = Column(Text)
message = Column(Text)
return Message
class SQLChatMessageHistory(BaseChatMessageHistory):
def __init__(
self,
session_id: str,
connection_string: str,
table_name: str = "message_store",
):
self.table_name = table_name
self.connection_string = connection_string
self.engine = create_engine(connection_string, echo=False)
self._create_table_if_not_exists()
self.session_id = session_id
self.Session = sessionmaker(self.engine)
def _create_table_if_not_exists(self) -> None:
DynamicBase = declarative_base()
self.Message = create_message_model(self.table_name, DynamicBase)
# Create all does the check for us in case the table exists.
DynamicBase.metadata.create_all(self.engine)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all messages from db"""
with self.Session() as session:
result = session.query(self.Message).where(
self.Message.session_id == self.session_id
)
items = [json.loads(record.message) for record in result]
messages = messages_from_dict(items)
return messages
def add_user_message(self, message: str) -> None:
self.append(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in db"""
with self.Session() as session:
jsonstr = json.dumps(_message_to_dict(message))
session.add(self.Message(session_id=self.session_id, message=jsonstr))
session.commit()
def clear(self) -> None:
"""Clear session memory from db"""
with self.Session() as session:
session.query(self.Message).filter(
self.Message.session_id == self.session_id
).delete()
session.commit()

View File

@ -0,0 +1,85 @@
from pathlib import Path
from typing import Tuple
import pytest
from langchain.memory.chat_message_histories import SQLChatMessageHistory
from langchain.schema import AIMessage, HumanMessage
# @pytest.fixture(params=[("SQLite"), ("postgresql")])
@pytest.fixture(params=[("SQLite")])
def sql_histories(request, tmp_path: Path): # type: ignore
if request.param == "SQLite":
file_path = tmp_path / "db.sqlite3"
con_str = f"sqlite:///{file_path}"
elif request.param == "postgresql":
con_str = "postgresql://postgres:postgres@localhost/postgres"
message_history = SQLChatMessageHistory(
session_id="123", connection_string=con_str, table_name="test_table"
)
# Create history for other session
other_history = SQLChatMessageHistory(
session_id="456", connection_string=con_str, table_name="test_table"
)
yield (message_history, other_history)
message_history.clear()
other_history.clear()
def test_add_messages(
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
) -> None:
sql_history, other_history = sql_histories
sql_history.add_user_message("Hello!")
sql_history.add_ai_message("Hi there!")
messages = sql_history.messages
assert len(messages) == 2
assert isinstance(messages[0], HumanMessage)
assert isinstance(messages[1], AIMessage)
assert messages[0].content == "Hello!"
assert messages[1].content == "Hi there!"
def test_multiple_sessions(
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
) -> None:
sql_history, other_history = sql_histories
sql_history.add_user_message("Hello!")
sql_history.add_ai_message("Hi there!")
sql_history.add_user_message("Whats cracking?")
# Ensure the messages are added correctly in the first session
assert len(sql_history.messages) == 3, "waat"
assert sql_history.messages[0].content == "Hello!"
assert sql_history.messages[1].content == "Hi there!"
assert sql_history.messages[2].content == "Whats cracking?"
# second session
other_history.add_user_message("Hellox")
assert len(other_history.messages) == 1
assert len(sql_history.messages) == 3
assert other_history.messages[0].content == "Hellox"
assert sql_history.messages[0].content == "Hello!"
assert sql_history.messages[1].content == "Hi there!"
assert sql_history.messages[2].content == "Whats cracking?"
def test_clear_messages(
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
) -> None:
sql_history, other_history = sql_histories
sql_history.add_user_message("Hello!")
sql_history.add_ai_message("Hi there!")
assert len(sql_history.messages) == 2
# Now create another history with different session id
other_history.add_user_message("Hellox")
assert len(other_history.messages) == 1
assert len(sql_history.messages) == 2
# Now clear the first history
sql_history.clear()
assert len(sql_history.messages) == 0
assert len(other_history.messages) == 1