forked from Archives/langchain
Add SQLiteChatMessageHistory (#3534)
It's based on already existing `PostgresChatMessageHistory` Use case somewhere in between multiple files and Postgres storage.
This commit is contained in:
parent
921894960b
commit
647bbf61c1
@ -3,11 +3,13 @@ from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessage
|
||||
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
|
||||
|
||||
__all__ = [
|
||||
"DynamoDBChatMessageHistory",
|
||||
"RedisChatMessageHistory",
|
||||
"PostgresChatMessageHistory",
|
||||
"SQLChatMessageHistory",
|
||||
"FileChatMessageHistory",
|
||||
"CosmosDBChatMessageHistory",
|
||||
]
|
||||
|
83
langchain/memory/chat_message_histories/sql.py
Normal file
83
langchain/memory/chat_message_histories/sql.py
Normal file
@ -0,0 +1,83 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, create_engine
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_message_model(table_name, DynamicBase): # type: ignore
|
||||
# Model decleared inside a function to have a dynamic table name
|
||||
class Message(DynamicBase):
|
||||
__tablename__ = table_name
|
||||
id = Column(Integer, primary_key=True)
|
||||
session_id = Column(Text)
|
||||
message = Column(Text)
|
||||
|
||||
return Message
|
||||
|
||||
|
||||
class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
connection_string: str,
|
||||
table_name: str = "message_store",
|
||||
):
|
||||
self.table_name = table_name
|
||||
self.connection_string = connection_string
|
||||
self.engine = create_engine(connection_string, echo=False)
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
self.session_id = session_id
|
||||
self.Session = sessionmaker(self.engine)
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
DynamicBase = declarative_base()
|
||||
self.Message = create_message_model(self.table_name, DynamicBase)
|
||||
# Create all does the check for us in case the table exists.
|
||||
DynamicBase.metadata.create_all(self.engine)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve all messages from db"""
|
||||
with self.Session() as session:
|
||||
result = session.query(self.Message).where(
|
||||
self.Message.session_id == self.session_id
|
||||
)
|
||||
items = [json.loads(record.message) for record in result]
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_user_message(self, message: str) -> None:
|
||||
self.append(HumanMessage(content=message))
|
||||
|
||||
def add_ai_message(self, message: str) -> None:
|
||||
self.append(AIMessage(content=message))
|
||||
|
||||
def append(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in db"""
|
||||
with self.Session() as session:
|
||||
jsonstr = json.dumps(_message_to_dict(message))
|
||||
session.add(self.Message(session_id=self.session_id, message=jsonstr))
|
||||
session.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from db"""
|
||||
|
||||
with self.Session() as session:
|
||||
session.query(self.Message).filter(
|
||||
self.Message.session_id == self.session_id
|
||||
).delete()
|
||||
session.commit()
|
85
tests/unit_tests/memory/chat_message_histories/test_sql.py
Normal file
85
tests/unit_tests/memory/chat_message_histories/test_sql.py
Normal file
@ -0,0 +1,85 @@
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.memory.chat_message_histories import SQLChatMessageHistory
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
|
||||
|
||||
# @pytest.fixture(params=[("SQLite"), ("postgresql")])
|
||||
@pytest.fixture(params=[("SQLite")])
|
||||
def sql_histories(request, tmp_path: Path): # type: ignore
|
||||
if request.param == "SQLite":
|
||||
file_path = tmp_path / "db.sqlite3"
|
||||
con_str = f"sqlite:///{file_path}"
|
||||
elif request.param == "postgresql":
|
||||
con_str = "postgresql://postgres:postgres@localhost/postgres"
|
||||
|
||||
message_history = SQLChatMessageHistory(
|
||||
session_id="123", connection_string=con_str, table_name="test_table"
|
||||
)
|
||||
# Create history for other session
|
||||
other_history = SQLChatMessageHistory(
|
||||
session_id="456", connection_string=con_str, table_name="test_table"
|
||||
)
|
||||
|
||||
yield (message_history, other_history)
|
||||
message_history.clear()
|
||||
other_history.clear()
|
||||
|
||||
|
||||
def test_add_messages(
|
||||
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
|
||||
) -> None:
|
||||
sql_history, other_history = sql_histories
|
||||
sql_history.add_user_message("Hello!")
|
||||
sql_history.add_ai_message("Hi there!")
|
||||
|
||||
messages = sql_history.messages
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], HumanMessage)
|
||||
assert isinstance(messages[1], AIMessage)
|
||||
assert messages[0].content == "Hello!"
|
||||
assert messages[1].content == "Hi there!"
|
||||
|
||||
|
||||
def test_multiple_sessions(
|
||||
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
|
||||
) -> None:
|
||||
sql_history, other_history = sql_histories
|
||||
sql_history.add_user_message("Hello!")
|
||||
sql_history.add_ai_message("Hi there!")
|
||||
sql_history.add_user_message("Whats cracking?")
|
||||
|
||||
# Ensure the messages are added correctly in the first session
|
||||
assert len(sql_history.messages) == 3, "waat"
|
||||
assert sql_history.messages[0].content == "Hello!"
|
||||
assert sql_history.messages[1].content == "Hi there!"
|
||||
assert sql_history.messages[2].content == "Whats cracking?"
|
||||
|
||||
# second session
|
||||
other_history.add_user_message("Hellox")
|
||||
assert len(other_history.messages) == 1
|
||||
assert len(sql_history.messages) == 3
|
||||
assert other_history.messages[0].content == "Hellox"
|
||||
assert sql_history.messages[0].content == "Hello!"
|
||||
assert sql_history.messages[1].content == "Hi there!"
|
||||
assert sql_history.messages[2].content == "Whats cracking?"
|
||||
|
||||
|
||||
def test_clear_messages(
|
||||
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
|
||||
) -> None:
|
||||
sql_history, other_history = sql_histories
|
||||
sql_history.add_user_message("Hello!")
|
||||
sql_history.add_ai_message("Hi there!")
|
||||
assert len(sql_history.messages) == 2
|
||||
# Now create another history with different session id
|
||||
other_history.add_user_message("Hellox")
|
||||
assert len(other_history.messages) == 1
|
||||
assert len(sql_history.messages) == 2
|
||||
# Now clear the first history
|
||||
sql_history.clear()
|
||||
assert len(sql_history.messages) == 0
|
||||
assert len(other_history.messages) == 1
|
Loading…
Reference in New Issue
Block a user