You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/tests/unit_tests/memory/chat_message_histories/test_sql.py

86 lines
3.0 KiB
Python

from pathlib import Path
from typing import Tuple
import pytest
from langchain.memory.chat_message_histories import SQLChatMessageHistory
from langchain.schema import AIMessage, HumanMessage
# @pytest.fixture(params=[("SQLite"), ("postgresql")])
@pytest.fixture(params=[("SQLite")])
def sql_histories(request, tmp_path: Path):  # type: ignore
    """Yield two chat histories that share one table but use different sessions.

    The backend is selected by ``request.param``: ``"SQLite"`` stores its data
    in a ``db.sqlite3`` file under ``tmp_path``, while ``"postgresql"`` uses a
    fixed local connection string.

    Yields:
        A ``(message_history, other_history)`` tuple whose session ids are
        ``"123"`` and ``"456"`` respectively, both bound to ``test_table``.

    Raises:
        ValueError: If ``request.param`` is not a supported backend name.
    """
    if request.param == "SQLite":
        file_path = tmp_path / "db.sqlite3"
        con_str = f"sqlite:///{file_path}"
    elif request.param == "postgresql":
        con_str = "postgresql://postgres:postgres@localhost/postgres"
    else:
        # Fail loudly here rather than with an UnboundLocalError on con_str.
        raise ValueError(f"Unsupported database backend: {request.param!r}")

    message_history = SQLChatMessageHistory(
        session_id="123", connection_string=con_str, table_name="test_table"
    )
    # Create history for other session
    other_history = SQLChatMessageHistory(
        session_id="456", connection_string=con_str, table_name="test_table"
    )

    yield (message_history, other_history)

    # Teardown: clear both sessions once the test is done with them.
    message_history.clear()
    other_history.clear()
def test_add_messages(
    sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
) -> None:
    """Add one user and one AI message, then check type and content of each."""
    history, _other = sql_histories

    history.add_user_message("Hello!")
    history.add_ai_message("Hi there!")

    # Read the stored messages once and inspect that snapshot.
    stored = history.messages
    assert len(stored) == 2

    human_msg = stored[0]
    ai_msg = stored[1]
    assert isinstance(human_msg, HumanMessage)
    assert isinstance(ai_msg, AIMessage)
    assert human_msg.content == "Hello!"
    assert ai_msg.content == "Hi there!"
def test_multiple_sessions(
    sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
) -> None:
    """Check that messages added to one session do not appear in the other."""
    first_session, second_session = sql_histories
    first_session_texts = ("Hello!", "Hi there!", "Whats cracking?")

    first_session.add_user_message("Hello!")
    first_session.add_ai_message("Hi there!")
    first_session.add_user_message("Whats cracking?")

    # Ensure the messages are added correctly in the first session
    assert len(first_session.messages) == 3, "waat"
    for position in range(3):
        assert (
            first_session.messages[position].content
            == first_session_texts[position]
        )

    # second session
    second_session.add_user_message("Hellox")
    assert len(second_session.messages) == 1
    assert len(first_session.messages) == 3
    assert second_session.messages[0].content == "Hellox"

    # The first session must still hold its own three messages, in order.
    for position in range(3):
        assert (
            first_session.messages[position].content
            == first_session_texts[position]
        )
def test_clear_messages(
    sql_histories: Tuple[SQLChatMessageHistory, SQLChatMessageHistory]
) -> None:
    """Check that clearing one session leaves the other session's messages."""
    cleared_session, kept_session = sql_histories

    cleared_session.add_user_message("Hello!")
    cleared_session.add_ai_message("Hi there!")
    assert len(cleared_session.messages) == 2

    # Populate the session with the different session id
    kept_session.add_user_message("Hellox")
    assert len(kept_session.messages) == 1
    assert len(cleared_session.messages) == 2

    # Clear only the first session and confirm the second is unaffected
    cleared_session.clear()
    assert len(cleared_session.messages) == 0
    assert len(kept_session.messages) == 1