Add StreamlitChatMessageHistory (#8497)

Add a StreamlitChatMessageHistory class that stores chat messages in
[Streamlit's Session
State](https://docs.streamlit.io/library/api-reference/session-state).

Note: The integration test uses a currently-experimental Streamlit
testing framework to simulate the execution of a Streamlit app. Marking
this PR as draft until I confirm with the Streamlit team that we're
comfortable supporting it.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Joshua Carroll 2023-08-01 14:28:15 -07:00 committed by GitHub
parent 8961c720b8
commit 6705928b9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 171 additions and 0 deletions

View File

@ -0,0 +1,61 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "91c6a7ef",
"metadata": {},
"source": [
"# Streamlit Chat Message History\n",
"\n",
"This notebook goes over how to use Streamlit to store chat message history. Note, StreamlitChatMessageHistory only works when run in a Streamlit app. For more on Streamlit check out their\n",
"[getting started documentation](https://docs.streamlit.io/library/get-started)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d15e3302",
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory import StreamlitChatMessageHistory\n",
"\n",
"history = StreamlitChatMessageHistory(\"foo\")\n",
"\n",
"history.add_user_message(\"hi!\")\n",
"history.add_ai_message(\"whats up?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64fc465e",
"metadata": {},
"outputs": [],
"source": [
"history.messages"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv",
"language": "python",
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -42,6 +42,7 @@ from langchain.memory.chat_message_histories import (
PostgresChatMessageHistory,
RedisChatMessageHistory,
SQLChatMessageHistory,
StreamlitChatMessageHistory,
ZepChatMessageHistory,
)
from langchain.memory.combined import CombinedMemory
@ -87,6 +88,7 @@ __all__ = [
"SQLChatMessageHistory",
"SQLiteEntityStore",
"SimpleMemory",
"StreamlitChatMessageHistory",
"VectorStoreRetrieverMemory",
"ZepChatMessageHistory",
"ZepMemory",

View File

@ -13,6 +13,9 @@ from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHi
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
from langchain.memory.chat_message_histories.streamlit import (
StreamlitChatMessageHistory,
)
from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory
__all__ = [
@ -27,5 +30,6 @@ __all__ = [
"PostgresChatMessageHistory",
"RedisChatMessageHistory",
"SQLChatMessageHistory",
"StreamlitChatMessageHistory",
"ZepChatMessageHistory",
]

View File

@ -0,0 +1,40 @@
from typing import List
from langchain.schema import (
BaseChatMessageHistory,
)
from langchain.schema.messages import BaseMessage
class StreamlitChatMessageHistory(BaseChatMessageHistory):
"""
Chat message history that stores messages in Streamlit session state.
Args:
key: The key to use in Streamlit session state for storing messages.
"""
def __init__(self, key: str = "langchain_messages"):
try:
import streamlit as st
except ImportError as e:
raise ImportError(
"Unable to import streamlit, please run `pip install streamlit`."
) from e
if key not in st.session_state:
st.session_state[key] = []
self._messages = st.session_state[key]
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the current list of messages"""
return self._messages
def add_message(self, message: BaseMessage) -> None:
"""Add a message to the session memory"""
self._messages.append(message)
def clear(self) -> None:
"""Clear session memory"""
self._messages.clear()

View File

@ -0,0 +1,64 @@
"""Unit tests for StreamlitChatMessageHistory functionality."""
import pytest
test_script = """
import json
import streamlit as st
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain.schema.messages import _message_to_dict
message_history = StreamlitChatMessageHistory()
memory = ConversationBufferMemory(chat_memory=message_history, return_messages=True)
# Add some messages
if st.checkbox("add initial messages", value=True):
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
else:
st.markdown("Skipped add")
# Clear messages if checked
if st.checkbox("clear messages"):
st.markdown("Cleared!")
memory.chat_memory.clear()
# Write the output to st.code as a json blob for inspection
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
st.text(messages_json)
"""
@pytest.mark.requires("streamlit")
def test_memory_with_message_store() -> None:
try:
from streamlit.testing.script_interactions import InteractiveScriptTests
except ModuleNotFoundError:
pytest.skip("Incorrect version of Streamlit installed")
test_handler = InteractiveScriptTests()
test_handler.setUp()
try:
sr = test_handler.script_from_string(test_script).run()
except TypeError:
# Earlier version expected 2 arguments
sr = test_handler.script_from_string("memory_test.py", test_script).run()
# Initial run should write two messages
messages_json = sr.get("text")[-1].value
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# Uncheck the initial write, they should persist in session_state
sr = sr.get("checkbox")[0].uncheck().run()
assert sr.get("markdown")[0].value == "Skipped add"
messages_json = sr.get("text")[-1].value
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# Clear the message history
sr = sr.get("checkbox")[1].check().run()
assert sr.get("markdown")[1].value == "Cleared!"
messages_json = sr.get("text")[-1].value
assert messages_json == "[]"