You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

65 lines
2.3 KiB

"""Unit tests for StreamlitChatMessageHistory functionality."""
import pytest
# Streamlit app driven by ``test_memory_with_message_store``. The test inspects
# rendered elements by type and index, so element order matters:
#   - checkbox[0] "add initial messages" (checked by default): adds one AI and
#     one human message; when unchecked, writes markdown "Skipped add".
#   - checkbox[1] "clear messages": when checked, writes markdown "Cleared!"
#     and clears the chat history.
#   - The final ``st.text`` element holds the JSON-serialized message list.
# The script body starts at column 0 so it is valid Python whether or not the
# harness dedents it.
test_script = """
import json
import streamlit as st
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_core.messages import message_to_dict

message_history = StreamlitChatMessageHistory()
memory = ConversationBufferMemory(chat_memory=message_history, return_messages=True)

# Add some messages
if st.checkbox("add initial messages", value=True):
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")
else:
    st.markdown("Skipped add")

# Clear messages if checked
if st.checkbox("clear messages"):
    st.markdown("Cleared!")
    memory.chat_memory.clear()

# Write the output to st.text as a json blob for inspection
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
st.text(messages_json)
"""
def test_memory_with_message_store() -> None:
    """Messages stored through the memory persist across reruns and can be cleared.

    Runs ``test_script`` through Streamlit's script-interaction harness:

    1. The initial run adds an AI and a human message.
    2. Unchecking "add initial messages" renders "Skipped add", and both
       messages must still be present (they persist in session_state).
    3. Checking "clear messages" renders "Cleared!" and empties the history.

    Skipped when the installed Streamlit does not provide
    ``streamlit.testing.script_interactions.InteractiveScriptTests``.
    """
    try:
        from streamlit.testing.script_interactions import InteractiveScriptTests
    except ImportError:
        # Covers both a missing module and a module without the class.
        pytest.skip("Incorrect version of Streamlit installed")

    test_handler = InteractiveScriptTests()
    # Only the construction call is guarded, so a TypeError raised while the
    # script runs is not mistaken for an API-signature mismatch.
    try:
        script = test_handler.script_from_string(test_script)
    except TypeError:
        # Earlier versions expected 2 arguments: a script file name, then the
        # script source. Use a real file name rather than an empty string.
        script = test_handler.script_from_string("memory_test.py", test_script)
    sr = script.run()

    # Initial run should write two messages
    messages_json = sr.get("text")[-1].value
    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # Uncheck the initial write; the messages should persist in session_state
    sr = sr.get("checkbox")[0].uncheck().run()
    assert sr.get("markdown")[0].value == "Skipped add"
    messages_json = sr.get("text")[-1].value
    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # Clear the message history. checkbox[0] is expected to stay unchecked from
    # the previous step, so "Skipped add" is markdown[0] and "Cleared!" is
    # markdown[1].
    sr = sr.get("checkbox")[1].check().run()
    assert sr.get("markdown")[1].value == "Cleared!"
    messages_json = sr.get("text")[-1].value
    assert messages_json == "[]"