From 6705928b9d71a67551d463697c06d136f134f7d6 Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Tue, 1 Aug 2023 14:28:15 -0700 Subject: [PATCH] Add StreamlitChatMessageHistory (#8497) Add a StreamlitChatMessageHistory class that stores chat messages in [Streamlit's Session State](https://docs.streamlit.io/library/api-reference/session-state). Note: The integration test uses a currently-experimental Streamlit testing framework to simulate the execution of a Streamlit app. Marking this PR as draft until I confirm with the Streamlit team that we're comfortable supporting it. --------- Co-authored-by: Bagatur --- .../streamlit_chat_message_history.ipynb | 61 ++++++++++++++++++ libs/langchain/langchain/memory/__init__.py | 2 + .../memory/chat_message_histories/__init__.py | 4 ++ .../chat_message_histories/streamlit.py | 40 ++++++++++++ .../chat_message_histories/test_streamlit.py | 64 +++++++++++++++++++ 5 files changed, 171 insertions(+) create mode 100644 docs/extras/integrations/memory/streamlit_chat_message_history.ipynb create mode 100644 libs/langchain/langchain/memory/chat_message_histories/streamlit.py create mode 100644 libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py diff --git a/docs/extras/integrations/memory/streamlit_chat_message_history.ipynb b/docs/extras/integrations/memory/streamlit_chat_message_history.ipynb new file mode 100644 index 0000000000..3f7e0ebecd --- /dev/null +++ b/docs/extras/integrations/memory/streamlit_chat_message_history.ipynb @@ -0,0 +1,61 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "91c6a7ef", + "metadata": {}, + "source": [ + "# Streamlit Chat Message History\n", + "\n", + "This notebook goes over how to use Streamlit to store chat message history. Note, StreamlitChatMessageHistory only works when run in a Streamlit app. For more on Streamlit check out their\n", + "[getting started documentation](https://docs.streamlit.io/library/get-started)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d15e3302", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.memory import StreamlitChatMessageHistory\n", + "\n", + "history = StreamlitChatMessageHistory(\"foo\")\n", + "\n", + "history.add_user_message(\"hi!\")\n", + "history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64fc465e", + "metadata": {}, + "outputs": [], + "source": [ + "history.messages" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "poetry-venv", + "language": "python", + "name": "poetry-venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/memory/__init__.py b/libs/langchain/langchain/memory/__init__.py index 76fc7c92ed..e0ac6371fe 100644 --- a/libs/langchain/langchain/memory/__init__.py +++ b/libs/langchain/langchain/memory/__init__.py @@ -42,6 +42,7 @@ from langchain.memory.chat_message_histories import ( PostgresChatMessageHistory, RedisChatMessageHistory, SQLChatMessageHistory, + StreamlitChatMessageHistory, ZepChatMessageHistory, ) from langchain.memory.combined import CombinedMemory @@ -87,6 +88,7 @@ __all__ = [ "SQLChatMessageHistory", "SQLiteEntityStore", "SimpleMemory", + "StreamlitChatMessageHistory", "VectorStoreRetrieverMemory", "ZepChatMessageHistory", "ZepMemory", diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py index 80aa0fcf24..b118eb5ae5 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/__init__.py +++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py @@ -13,6 +13,9 @@ from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHi from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory +from langchain.memory.chat_message_histories.streamlit import ( + StreamlitChatMessageHistory, +) from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory __all__ = [ @@ -27,5 +30,6 @@ __all__ = [ "PostgresChatMessageHistory", "RedisChatMessageHistory", "SQLChatMessageHistory", + "StreamlitChatMessageHistory", "ZepChatMessageHistory", ] diff --git a/libs/langchain/langchain/memory/chat_message_histories/streamlit.py b/libs/langchain/langchain/memory/chat_message_histories/streamlit.py new file mode 100644 index 0000000000..34280356f7 --- /dev/null +++ b/libs/langchain/langchain/memory/chat_message_histories/streamlit.py @@ -0,0 +1,40 @@ +from typing import List + +from langchain.schema import ( + BaseChatMessageHistory, +) +from langchain.schema.messages import BaseMessage + + +class StreamlitChatMessageHistory(BaseChatMessageHistory): + """ + Chat message history that stores messages in Streamlit session state. + + Args: + key: The key to use in Streamlit session state for storing messages. + """ + + def __init__(self, key: str = "langchain_messages"): + try: + import streamlit as st + except ImportError as e: + raise ImportError( + "Unable to import streamlit, please run `pip install streamlit`." + ) from e + + if key not in st.session_state: + st.session_state[key] = [] + self._messages = st.session_state[key] + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the current list of messages""" + return self._messages + + def add_message(self, message: BaseMessage) -> None: + """Add a message to the session memory""" + self._messages.append(message) + + def clear(self) -> None: + """Clear session memory""" + self._messages.clear() diff --git a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py new file mode 100644 index 0000000000..5ed50191c1 --- /dev/null +++ b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_streamlit.py @@ -0,0 +1,64 @@ +"""Unit tests for StreamlitChatMessageHistory functionality.""" +import pytest + +test_script = """ + import json + import streamlit as st + from langchain.memory import ConversationBufferMemory + from langchain.memory.chat_message_histories import StreamlitChatMessageHistory + from langchain.schema.messages import _message_to_dict + + message_history = StreamlitChatMessageHistory() + memory = ConversationBufferMemory(chat_memory=message_history, return_messages=True) + + # Add some messages + if st.checkbox("add initial messages", value=True): + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + else: + st.markdown("Skipped add") + + # Clear messages if checked + if st.checkbox("clear messages"): + st.markdown("Cleared!") + memory.chat_memory.clear() + + # Write the output to st.code as a json blob for inspection + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + st.text(messages_json) +""" + + +@pytest.mark.requires("streamlit") +def test_memory_with_message_store() -> None: + try: + from streamlit.testing.script_interactions import InteractiveScriptTests + except ModuleNotFoundError: + pytest.skip("Incorrect version of Streamlit installed") + + test_handler = InteractiveScriptTests() + test_handler.setUp() + try: + sr = test_handler.script_from_string(test_script).run() + except TypeError: + # Earlier version expected 2 arguments + sr = test_handler.script_from_string("memory_test.py", test_script).run() + + # Initial run should write two messages + messages_json = sr.get("text")[-1].value + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # Uncheck the initial write, they should persist in session_state + sr = sr.get("checkbox")[0].uncheck().run() + assert sr.get("markdown")[0].value == "Skipped add" + messages_json = sr.get("text")[-1].value + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # Clear the message history + sr = sr.get("checkbox")[1].check().run() + assert sr.get("markdown")[1].value == "Cleared!" + messages_json = sr.get("text")[-1].value + assert messages_json == "[]"