From 3f64b8a76184490edfeba8fa8a329b9409bf974d Mon Sep 17 00:00:00 2001 From: Aarav Borthakur <69025547+gadhagod@users.noreply.github.com> Date: Tue, 8 Aug 2023 18:54:07 -0700 Subject: [PATCH] Integrate Rockset as a chat history store (#8940) Description: Adds Rockset as a chat history store Dependencies: no changes Tag maintainer: @hwchase17 This PR passes linting and testing. I added a test for the integration and an example notebook showing its use. --- .../memory/rockset_chat_message_history.ipynb | 67 +++++ .../extras/integrations/providers/rockset.mdx | 7 + .../memory/chat_message_histories/__init__.py | 2 + .../chat_message_histories/rocksetdb.py | 258 ++++++++++++++++++ .../integration_tests/memory/test_rockset.py | 63 +++++ 5 files changed, 397 insertions(+) create mode 100644 docs/extras/integrations/memory/rockset_chat_message_history.ipynb create mode 100644 libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py create mode 100644 libs/langchain/tests/integration_tests/memory/test_rockset.py diff --git a/docs/extras/integrations/memory/rockset_chat_message_history.ipynb b/docs/extras/integrations/memory/rockset_chat_message_history.ipynb new file mode 100644 index 0000000000..1bf7c5e3ff --- /dev/null +++ b/docs/extras/integrations/memory/rockset_chat_message_history.ipynb @@ -0,0 +1,67 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Rockset Chat Message History\n", + "\n", + "This notebook goes over how to use [Rockset](https://rockset.com/docs) to store chat message history. \n", + "\n", + "To begin, get your API key from the [Rockset console](https://console.rockset.com/apikeys). Find your API region in the Rockset [API reference](https://rockset.com/docs/rest-api#introduction)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "from langchain.memory.chat_message_histories import RocksetChatMessageHistory\n", + "from rockset import RocksetClient, Regions\n", + "\n", + "history = RocksetChatMessageHistory(\n", + " session_id=\"MySession\",\n", + " client=RocksetClient(\n", + " api_key=\"YOUR API KEY\", \n", + " host=Regions.usw2a1 # us-west-2 Oregon\n", + " ),\n", + " collection=\"langchain_demo\",\n", + " sync=True\n", + ")\n", + "history.add_user_message(\"hi!\")\n", + "history.add_ai_message(\"whats up?\")\n", + "print(history.messages)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The output should be something like:\n", + "\n", + "```python\n", + "[\n", + " HumanMessage(content='hi!', additional_kwargs={'id': '2e62f1c2-e9f7-465e-b551-49bae07fe9f0'}, example=False), \n", + " AIMessage(content='whats up?', additional_kwargs={'id': 'b9be8eda-4c18-4cf8-81c3-e91e876927d0'}, example=False)\n", + "]\n", + "\n", + "```" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/extras/integrations/providers/rockset.mdx b/docs/extras/integrations/providers/rockset.mdx index 1b9ba56590..b13b4fb944 100644 --- a/docs/extras/integrations/providers/rockset.mdx +++ b/docs/extras/integrations/providers/rockset.mdx @@ -23,4 +23,11 @@ from langchain.vectorstores import Rockset See a [usage example](/docs/integrations/document_loaders/rockset). ```python from langchain.document_loaders import RocksetLoader +``` + +## Chat Message History + +See a [usage example](/docs/integrations/memory/rockset_chat_message_history). 
+```python
+from langchain.memory.chat_message_histories import RocksetChatMessageHistory
 ```
\ No newline at end of file
diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py
index b118eb5ae5..02241675b1 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/__init__.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py
@@ -12,6 +12,7 @@ from langchain.memory.chat_message_histories.momento import MomentoChatMessageHi
 from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
 from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
 from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
+from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory
 from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
 from langchain.memory.chat_message_histories.streamlit import (
     StreamlitChatMessageHistory,
@@ -29,6 +30,7 @@ __all__ = [
     "MongoDBChatMessageHistory",
     "PostgresChatMessageHistory",
     "RedisChatMessageHistory",
+    "RocksetChatMessageHistory",
     "SQLChatMessageHistory",
     "StreamlitChatMessageHistory",
     "ZepChatMessageHistory",
diff --git a/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py b/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py
new file mode 100644
index 0000000000..ce19d55693
--- /dev/null
+++ b/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py
@@ -0,0 +1,266 @@
+from time import monotonic, sleep
+from typing import Any, Callable, List, Union
+from uuid import uuid4
+
+from langchain.schema import BaseChatMessageHistory
+from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
+
+
+class RocksetChatMessageHistory(BaseChatMessageHistory):
+    """Uses Rockset to store chat messages.
+
+    To use, ensure that the `rockset` python package is installed.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.memory.chat_message_histories import (
+                RocksetChatMessageHistory
+            )
+            from rockset import RocksetClient
+
+            history = RocksetChatMessageHistory(
+                session_id="MySession",
+                client=RocksetClient(),
+                collection="langchain_demo",
+                sync=True
+            )
+
+            history.add_user_message("hi!")
+            history.add_ai_message("whats up?")
+
+            print(history.messages)
+    """
+
+    # You should set these values based on your VI.
+    # These values are configured for the typical
+    # free VI. Read more about VIs here:
+    # https://rockset.com/docs/instances
+    SLEEP_INTERVAL_MS = 5
+    ADD_TIMEOUT_MS = 5000
+    CREATE_TIMEOUT_MS = 20000
+
+    def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
+        """Sleeps until method() evaluates to true. Passes kwargs into
+        method.
+
+        Raises:
+            TimeoutError: If method() is still falsy after `timeout` ms.
+        """
+        # Use a monotonic clock so that wall-clock adjustments made while
+        # polling cannot shorten or extend the timeout.
+        start = monotonic()
+        while not method(**method_params):
+            if (monotonic() - start) * 1000 > timeout:
+                raise TimeoutError(f"{method} timed out at {timeout} ms")
+            sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
+
+    def _query(self, query: str, **query_params: Any) -> List[Any]:
+        """Executes an SQL statement and returns the result
+        Args:
+            - query: The SQL string
+            - **query_params: Parameters to pass into the query
+        """
+        return self.client.sql(query, params=query_params).results
+
+    def _create_collection(self) -> None:
+        """Creates a collection for this message history"""
+        self.client.Collections.create_s3_collection(
+            name=self.collection, workspace=self.workspace
+        )
+
+    def _collection_exists(self) -> bool:
+        """Checks whether a collection exists for this message history"""
+        try:
+            # Pass the workspace explicitly; without it the client looks
+            # the collection up in its default workspace.
+            self.client.Collections.get(
+                collection=self.collection, workspace=self.workspace
+            )
+        except self.rockset.exceptions.NotFoundException:
+            return False
+        return True
+
+    def _collection_is_ready(self) -> bool:
+        """Checks whether the collection for this message history is ready
+        to be queried
+        """
+        collection = self.client.Collections.get(
+            collection=self.collection, workspace=self.workspace
+        )
+        return collection.data.status == "READY"
+
+    def _document_exists(self) -> bool:
+        """Checks whether a document exists for this chat session"""
+        return (
+            len(
+                self._query(
+                    f"""
+                        SELECT 1
+                        FROM {self.location}
+                        WHERE _id=:session_id
+                        LIMIT 1
+                    """,
+                    session_id=self.session_id,
+                )
+            )
+            != 0
+        )
+
+    def _wait_until_collection_created(self) -> None:
+        """Sleeps until the collection for this message history is ready
+        to be queried
+        """
+        self._wait_until(
+            self._collection_is_ready,
+            RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
+        )
+
+    def _wait_until_message_added(self, message_id: Union[str, int]) -> None:
+        """Sleeps until a message is added to the messages list"""
+        self._wait_until(
+            lambda message_id: len(
+                self._query(
+                    f"""
+                        SELECT *
+                        FROM UNNEST((
+                            SELECT "{self.messages_key}"
+                            FROM {self.location}
+                            WHERE _id = :session_id
+                        )) AS message
+                        WHERE message.data.additional_kwargs.id = :message_id
+                        LIMIT 1
+                    """,
+                    session_id=self.session_id,
+                    message_id=message_id,
+                ),
+            )
+            != 0,
+            RocksetChatMessageHistory.ADD_TIMEOUT_MS,
+            message_id=message_id,
+        )
+
+    def _create_empty_doc(self) -> None:
+        """Creates or replaces a document for this message history with no
+        messages"""
+        self.client.Documents.add_documents(
+            collection=self.collection,
+            workspace=self.workspace,
+            data=[{"_id": self.session_id, self.messages_key: []}],
+        )
+
+    def __init__(
+        self,
+        session_id: str,
+        client: Any,
+        collection: str,
+        workspace: str = "commons",
+        messages_key: str = "messages",
+        sync: bool = False,
+        message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
+    ) -> None:
+        """Constructs a new RocksetChatMessageHistory.
+
+        Args:
+            - session_id: The ID of the chat session
+            - client: The RocksetClient object to use to query
+            - collection: The name of the collection to use to store chat
+                          messages. If a collection with the given name
+                          does not exist in the workspace, it is created.
+            - workspace: The workspace containing `collection`. Defaults
+                         to `"commons"`
+            - messages_key: The DB column containing message history.
+                            Defaults to `"messages"`
+            - sync: Whether to wait for messages to be added. Defaults
+                    to `False`. NOTE: setting this to `True` will slow
+                    down performance.
+            - message_uuid_method: The method that generates message IDs.
+                    Only used when `sync` is `True`: a message without an
+                    `id` in its `additional_kwargs` is given one so that
+                    the write can be polled for. Defaults to a method
+                    returning `str(uuid.uuid4())`. When `sync` is
+                    `False`, message IDs are not created.
+        """
+        try:
+            import rockset
+        except ImportError as e:
+            raise ImportError(
+                "Could not import rockset client python package. "
+                "Please install it with `pip install rockset`."
+            ) from e
+
+        if not isinstance(client, rockset.RocksetClient):
+            raise ValueError(
+                f"client should be an instance of rockset.RocksetClient, "
+                f"got {type(client)}"
+            )
+
+        self.session_id = session_id
+        self.client = client
+        self.collection = collection
+        self.workspace = workspace
+        self.location = f'"{self.workspace}"."{self.collection}"'
+        self.rockset = rockset
+        self.messages_key = messages_key
+        self.message_uuid_method = message_uuid_method
+        self.sync = sync
+
+        if not self._collection_exists():
+            self._create_collection()
+            self._wait_until_collection_created()
+            self._create_empty_doc()
+        elif not self._document_exists():
+            self._create_empty_doc()
+
+    @property
+    def messages(self) -> List[BaseMessage]:  # type: ignore
+        """Messages in this chat history."""
+        return messages_from_dict(
+            self._query(
+                f"""
+                    SELECT *
+                    FROM UNNEST ((
+                        SELECT "{self.messages_key}"
+                        FROM {self.location}
+                        WHERE _id = :session_id
+                    ))
+                """,
+                session_id=self.session_id,
+            )
+        )
+
+    def add_message(self, message: BaseMessage) -> None:
+        """Add a Message object to the history.
+
+        Args:
+            message: A BaseMessage object to store.
+        """
+        if self.sync and "id" not in message.additional_kwargs:
+            message.additional_kwargs["id"] = self.message_uuid_method()
+        self.client.Documents.patch_documents(
+            collection=self.collection,
+            workspace=self.workspace,
+            data=[
+                self.rockset.model.patch_document.PatchDocument(
+                    id=self.session_id,
+                    patch=[
+                        self.rockset.model.patch_operation.PatchOperation(
+                            op="ADD",
+                            path=f"/{self.messages_key}/-",
+                            value=_message_to_dict(message),
+                        )
+                    ],
+                )
+            ],
+        )
+        if self.sync:
+            self._wait_until_message_added(message.additional_kwargs["id"])
+
+    def clear(self) -> None:
+        """Removes all messages from the chat history"""
+        self._create_empty_doc()
+        if self.sync:
+            self._wait_until(
+                lambda: not self.messages,
+                RocksetChatMessageHistory.ADD_TIMEOUT_MS,
+            )
diff --git a/libs/langchain/tests/integration_tests/memory/test_rockset.py b/libs/langchain/tests/integration_tests/memory/test_rockset.py
new file mode 100644
index 0000000000..e1cb50d05f
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/memory/test_rockset.py
@@ -0,0 +1,63 @@
+"""Tests RocksetChatMessageHistory by creating a collection
+for message history, adding to it, and clearing it.
+
+To run these tests, make sure you have the ROCKSET_API_KEY
+and ROCKSET_REGION environment variables set.
+"""
+
+import json
+import os
+
+from rockset import DevRegions, Regions, RocksetClient
+
+from langchain.memory import ConversationBufferMemory
+from langchain.memory.chat_message_histories import RocksetChatMessageHistory
+from langchain.schema.messages import _message_to_dict
+
+collection_name = "langchain_demo"
+session_id = "MySession"
+
+
+class TestRockset:
+    memory: RocksetChatMessageHistory
+
+    @classmethod
+    def setup_class(cls) -> None:
+        assert os.environ.get("ROCKSET_API_KEY") is not None
+        assert os.environ.get("ROCKSET_REGION") is not None
+
+        api_key = os.environ.get("ROCKSET_API_KEY")
+        region = os.environ.get("ROCKSET_REGION")
+        if region == "use1a1":
+            host = Regions.use1a1
+        elif region == "usw2a1" or not region:
+            host = Regions.usw2a1
+        elif region == "euc1a1":
+            host = Regions.euc1a1
+        elif region == "dev":
+            host = DevRegions.usw2a1
+        else:
+            host = region
+
+        client = RocksetClient(host, api_key)
+        cls.memory = RocksetChatMessageHistory(
+            session_id, client, collection_name, sync=True
+        )
+
+    def test_memory_with_message_store(self) -> None:
+        memory = ConversationBufferMemory(
+            memory_key="messages", chat_memory=self.memory, return_messages=True
+        )
+
+        memory.chat_memory.add_ai_message("This is me, the AI")
+        memory.chat_memory.add_user_message("This is me, the human")
+
+        messages = memory.chat_memory.messages
+        messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
+
+        assert "This is me, the AI" in messages_json
+        assert "This is me, the human" in messages_json
+
+        memory.chat_memory.clear()
+
+        assert memory.chat_memory.messages == []