forked from Archives/langchain
Add SQLiteChatMessageHistory (#3534)
It's based on already existing `PostgresChatMessageHistory` Use case somewhere in between multiple files and Postgres storage.fix_agent_callbacks
parent
921894960b
commit
647bbf61c1
@ -0,0 +1,83 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Integer, Text, create_engine
|
||||||
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||||
|
|
||||||
|
from langchain.schema import (
|
||||||
|
AIMessage,
|
||||||
|
BaseChatMessageHistory,
|
||||||
|
BaseMessage,
|
||||||
|
HumanMessage,
|
||||||
|
_message_to_dict,
|
||||||
|
messages_from_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_message_model(table_name, DynamicBase): # type: ignore
|
||||||
|
# Model decleared inside a function to have a dynamic table name
|
||||||
|
class Message(DynamicBase):
|
||||||
|
__tablename__ = table_name
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
session_id = Column(Text)
|
||||||
|
message = Column(Text)
|
||||||
|
|
||||||
|
return Message
|
||||||
|
|
||||||
|
|
||||||
|
class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
connection_string: str,
|
||||||
|
table_name: str = "message_store",
|
||||||
|
):
|
||||||
|
self.table_name = table_name
|
||||||
|
self.connection_string = connection_string
|
||||||
|
self.engine = create_engine(connection_string, echo=False)
|
||||||
|
self._create_table_if_not_exists()
|
||||||
|
|
||||||
|
self.session_id = session_id
|
||||||
|
self.Session = sessionmaker(self.engine)
|
||||||
|
|
||||||
|
def _create_table_if_not_exists(self) -> None:
|
||||||
|
DynamicBase = declarative_base()
|
||||||
|
self.Message = create_message_model(self.table_name, DynamicBase)
|
||||||
|
# Create all does the check for us in case the table exists.
|
||||||
|
DynamicBase.metadata.create_all(self.engine)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||||
|
"""Retrieve all messages from db"""
|
||||||
|
with self.Session() as session:
|
||||||
|
result = session.query(self.Message).where(
|
||||||
|
self.Message.session_id == self.session_id
|
||||||
|
)
|
||||||
|
items = [json.loads(record.message) for record in result]
|
||||||
|
messages = messages_from_dict(items)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def add_user_message(self, message: str) -> None:
|
||||||
|
self.append(HumanMessage(content=message))
|
||||||
|
|
||||||
|
def add_ai_message(self, message: str) -> None:
|
||||||
|
self.append(AIMessage(content=message))
|
||||||
|
|
||||||
|
def append(self, message: BaseMessage) -> None:
|
||||||
|
"""Append the message to the record in db"""
|
||||||
|
with self.Session() as session:
|
||||||
|
jsonstr = json.dumps(_message_to_dict(message))
|
||||||
|
session.add(self.Message(session_id=self.session_id, message=jsonstr))
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear session memory from db"""
|
||||||
|
|
||||||
|
with self.Session() as session:
|
||||||
|
session.query(self.Message).filter(
|
||||||
|
self.Message.session_id == self.session_id
|
||||||
|
).delete()
|
||||||
|
session.commit()
|
@ -0,0 +1,85 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.memory.chat_message_histories import SQLChatMessageHistory
|
||||||
|
from langchain.schema import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.fixture(params=[("SQLite"), ("postgresql")])
|
||||||
|
@pytest.fixture(params=[("SQLite")])
|
||||||
|
def sql_histories(request, tmp_path: Path): # type: ignore
|
||||||
|
if request.param == "SQLite":
|
||||||
|
file_path = tmp_path / "db.sqlite3"
|
||||||
|
con_str = f"sqlite:///{file_path}"
|
||||||
|
elif request.param == "postgresql":
|
||||||
|
con_str = "postgresql://postgres:postgres@localhost/postgres"
|
||||||
|
|
||||||
|
message_history = SQLChatMessageHistory(
|
||||||
|
session_id="123", connection_string=con_str, table_name="test_table"
|
||||||
|
)
|
||||||
|
# Create history for other session
|
||||||
|
other_history = SQLChatMessageHistory(
|
||||||
|
session_id="456", connection_string=con_str, table_name="test_table"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield (message_history, other_history)
|
||||||
|
message_history.clear()
|
||||||
|
other_history.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_messages(
|
||||||
|
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
|
||||||
|
) -> None:
|
||||||
|
sql_history, other_history = sql_histories
|
||||||
|
sql_history.add_user_message("Hello!")
|
||||||
|
sql_history.add_ai_message("Hi there!")
|
||||||
|
|
||||||
|
messages = sql_history.messages
|
||||||
|
assert len(messages) == 2
|
||||||
|
assert isinstance(messages[0], HumanMessage)
|
||||||
|
assert isinstance(messages[1], AIMessage)
|
||||||
|
assert messages[0].content == "Hello!"
|
||||||
|
assert messages[1].content == "Hi there!"
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_sessions(
|
||||||
|
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
|
||||||
|
) -> None:
|
||||||
|
sql_history, other_history = sql_histories
|
||||||
|
sql_history.add_user_message("Hello!")
|
||||||
|
sql_history.add_ai_message("Hi there!")
|
||||||
|
sql_history.add_user_message("Whats cracking?")
|
||||||
|
|
||||||
|
# Ensure the messages are added correctly in the first session
|
||||||
|
assert len(sql_history.messages) == 3, "waat"
|
||||||
|
assert sql_history.messages[0].content == "Hello!"
|
||||||
|
assert sql_history.messages[1].content == "Hi there!"
|
||||||
|
assert sql_history.messages[2].content == "Whats cracking?"
|
||||||
|
|
||||||
|
# second session
|
||||||
|
other_history.add_user_message("Hellox")
|
||||||
|
assert len(other_history.messages) == 1
|
||||||
|
assert len(sql_history.messages) == 3
|
||||||
|
assert other_history.messages[0].content == "Hellox"
|
||||||
|
assert sql_history.messages[0].content == "Hello!"
|
||||||
|
assert sql_history.messages[1].content == "Hi there!"
|
||||||
|
assert sql_history.messages[2].content == "Whats cracking?"
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_messages(
|
||||||
|
sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
|
||||||
|
) -> None:
|
||||||
|
sql_history, other_history = sql_histories
|
||||||
|
sql_history.add_user_message("Hello!")
|
||||||
|
sql_history.add_ai_message("Hi there!")
|
||||||
|
assert len(sql_history.messages) == 2
|
||||||
|
# Now create another history with different session id
|
||||||
|
other_history.add_user_message("Hellox")
|
||||||
|
assert len(other_history.messages) == 1
|
||||||
|
assert len(sql_history.messages) == 2
|
||||||
|
# Now clear the first history
|
||||||
|
sql_history.clear()
|
||||||
|
assert len(sql_history.messages) == 0
|
||||||
|
assert len(other_history.messages) == 1
|
Loading…
Reference in New Issue