Integrate Rockset as a chat history store (#8940)

Description: Adds Rockset as a chat history store
Dependencies: no changes
Tag maintainer: @hwchase17

This PR passes linting and testing. 

I added a test for the integration and an example notebook showing its
use.
pull/8828/head
Aarav Borthakur 11 months ago committed by GitHub
parent 0a1be1d501
commit 3f64b8a761
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,67 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Rockset Chat Message History\n",
"\n",
"This notebook goes over how to use [Rockset](https://rockset.com/docs) to store chat message history. \n",
"\n",
"To begin, with get your API key from the [Rockset console](https://console.rockset.com/apikeys). Find your API region for the Rockset [API reference](https://rockset.com/docs/rest-api#introduction)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from langchain.memory.chat_message_histories import RocksetChatMessageHistory\n",
"from rockset import RocksetClient, Regions\n",
"\n",
"history = RocksetChatMessageHistory(\n",
" session_id=\"MySession\",\n",
" client=RocksetClient(\n",
" api_key=\"YOUR API KEY\", \n",
" host=Regions.usw2a1 # us-west-2 Oregon\n",
" ),\n",
" collection=\"langchain_demo\",\n",
" sync=True\n",
")\n",
"history.add_user_message(\"hi!\")\n",
"history.add_ai_message(\"whats up?\")\n",
"print(history.messages)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The output should be something like:\n",
"\n",
"```python\n",
"[\n",
" HumanMessage(content='hi!', additional_kwargs={'id': '2e62f1c2-e9f7-465e-b551-49bae07fe9f0'}, example=False), \n",
" AIMessage(content='whats up?', additional_kwargs={'id': 'b9be8eda-4c18-4cf8-81c3-e91e876927d0'}, example=False)\n",
"]\n",
"\n",
"```"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -23,4 +23,11 @@ from langchain.vectorstores import Rockset
See a [usage example](/docs/integrations/document_loaders/rockset).
```python
from langchain.document_loaders import RocksetLoader
```
## Chat Message History
See a [usage example](/docs/integrations/memory/rockset_chat_message_history).
```python
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
```

@ -12,6 +12,7 @@ from langchain.memory.chat_message_histories.momento import MomentoChatMessageHi
from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
from langchain.memory.chat_message_histories.streamlit import (
StreamlitChatMessageHistory,
@ -29,6 +30,7 @@ __all__ = [
"MongoDBChatMessageHistory",
"PostgresChatMessageHistory",
"RedisChatMessageHistory",
"RocksetChatMessageHistory",
"SQLChatMessageHistory",
"StreamlitChatMessageHistory",
"ZepChatMessageHistory",

@ -0,0 +1,258 @@
from datetime import datetime
from time import sleep
from typing import Any, Callable, List, Union
from uuid import uuid4
from langchain.schema import BaseChatMessageHistory
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
class RocksetChatMessageHistory(BaseChatMessageHistory):
"""Uses Rockset to store chat messages.
To use, ensure that the `rockset` python package installed.
Example:
.. code-block:: python
from langchain.memory.chat_message_histories import (
RocksetChatMessageHistory
)
from rockset import RocksetClient
history = RocksetChatMessageHistory(
session_id="MySession",
client=RocksetClient(),
collection="langchain_demo",
sync=True
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
print(history.messages)
"""
# You should set these values based on your VI.
# These values are configured for the typical
# free VI. Read more about VIs here:
# https://rockset.com/docs/instances
SLEEP_INTERVAL_MS = 5
ADD_TIMEOUT_MS = 5000
CREATE_TIMEOUT_MS = 20000
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
"""Sleeps until meth() evaluates to true. Passes kwargs into
meth.
"""
start = datetime.now()
while not method(**method_params):
curr = datetime.now()
if (curr - start).total_seconds() * 1000 > timeout:
raise TimeoutError(f"{method} timed out at {timeout} ms")
sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
def _query(self, query: str, **query_params: Any) -> List[Any]:
"""Executes an SQL statement and returns the result
Args:
- query: The SQL string
- **query_params: Parameters to pass into the query
"""
return self.client.sql(query, params=query_params).results
def _create_collection(self) -> None:
"""Creates a collection for this message history"""
self.client.Collections.create_s3_collection(
name=self.collection, workspace=self.workspace
)
def _collection_exists(self) -> bool:
"""Checks whether a collection exists for this message history"""
try:
self.client.Collections.get(collection=self.collection)
except self.rockset.exceptions.NotFoundException:
return False
return True
def _collection_is_ready(self) -> bool:
"""Checks whether the collection for this message history is ready
to be queried
"""
return (
self.client.Collections.get(collection=self.collection).data.status
== "READY"
)
def _document_exists(self) -> bool:
return (
len(
self._query(
f"""
SELECT 1
FROM {self.location}
WHERE _id=:session_id
LIMIT 1
""",
session_id=self.session_id,
)
)
!= 0
)
def _wait_until_collection_created(self) -> None:
"""Sleeps until the collection for this message history is ready
to be queried
"""
self._wait_until(
lambda: self._collection_is_ready(),
RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
)
def _wait_until_message_added(self, message_id: str) -> None:
"""Sleeps until a message is added to the messages list"""
self._wait_until(
lambda message_id: len(
self._query(
f"""
SELECT *
FROM UNNEST((
SELECT {self.messages_key}
FROM {self.location}
WHERE _id = :session_id
)) AS message
WHERE message.data.additional_kwargs.id = :message_id
LIMIT 1
""",
session_id=self.session_id,
message_id=message_id,
),
)
!= 0,
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
message_id=message_id,
)
def _create_empty_doc(self) -> None:
"""Creates or replaces a document for this message history with no
messages"""
self.client.Documents.add_documents(
collection=self.collection,
workspace=self.workspace,
data=[{"_id": self.session_id, self.messages_key: []}],
)
def __init__(
self,
session_id: str,
client: Any,
collection: str,
workspace: str = "commons",
messages_key: str = "messages",
sync: bool = False,
message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
) -> None:
"""Constructs a new RocksetChatMessageHistory.
Args:
- session_id: The ID of the chat session
- client: The RocksetClient object to use to query
- collection: The name of the collection to use to store chat
messages. If a collection with the given name
does not exist in the workspace, it is created.
- workspace: The workspace containing `collection`. Defaults
to `"commons"`
- messages_key: The DB column containing message history.
Defaults to `"messages"`
- sync: Whether to wait for messages to be added. Defaults
to `False`. NOTE: setting this to `True` will slow
down performance.
- message_uuid_method: The method that generates message IDs.
If set, all messages will have an `id` field within the
`additional_kwargs` property. If this param is not set
and `sync` is `False`, message IDs will not be created.
If this param is not set and `sync` is `True`, the
`uuid.uuid4` method will be used to create message IDs.
"""
try:
import rockset
except ImportError:
raise ImportError(
"Could not import rockset client python package. "
"Please install it with `pip install rockset`."
)
if not isinstance(client, rockset.RocksetClient):
raise ValueError(
f"client should be an instance of rockset.RocksetClient, "
f"got {type(client)}"
)
self.session_id = session_id
self.client = client
self.collection = collection
self.workspace = workspace
self.location = f'"{self.workspace}"."{self.collection}"'
self.rockset = rockset
self.messages_key = messages_key
self.message_uuid_method = message_uuid_method
self.sync = sync
if not self._collection_exists():
self._create_collection()
self._wait_until_collection_created()
self._create_empty_doc()
elif not self._document_exists():
self._create_empty_doc()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Messages in this chat history."""
return messages_from_dict(
self._query(
f"""
SELECT *
FROM UNNEST ((
SELECT "{self.messages_key}"
FROM {self.location}
WHERE _id = :session_id
))
""",
session_id=self.session_id,
)
)
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the history.
Args:
message: A BaseMessage object to store.
"""
if self.sync and "id" not in message.additional_kwargs:
message.additional_kwargs["id"] = self.message_uuid_method()
self.client.Documents.patch_documents(
collection=self.collection,
workspace=self.workspace,
data=[
self.rockset.model.patch_document.PatchDocument(
id=self.session_id,
patch=[
self.rockset.model.patch_operation.PatchOperation(
op="ADD",
path=f"/{self.messages_key}/-",
value=_message_to_dict(message),
)
],
)
],
)
if self.sync:
self._wait_until_message_added(message.additional_kwargs["id"])
def clear(self) -> None:
"""Removes all messages from the chat history"""
self._create_empty_doc()
if self.sync:
self._wait_until(
lambda: not self.messages,
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
)

@ -0,0 +1,63 @@
"""Tests RocksetChatMessageHistory by creating a collection
for message history, adding to it, and clearing it.
To run these tests, make sure you have the ROCKSET_API_KEY
and ROCKSET_REGION environment variables set.
"""
import json
import os
from rockset import DevRegions, Regions, RocksetClient
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
from langchain.schema.messages import _message_to_dict
collection_name = "langchain_demo"
session_id = "MySession"
class TestRockset:
memory: RocksetChatMessageHistory
@classmethod
def setup_class(cls) -> None:
assert os.environ.get("ROCKSET_API_KEY") is not None
assert os.environ.get("ROCKSET_REGION") is not None
api_key = os.environ.get("ROCKSET_API_KEY")
region = os.environ.get("ROCKSET_REGION")
if region == "use1a1":
host = Regions.use1a1
elif region == "usw2a1" or not region:
host = Regions.usw2a1
elif region == "euc1a1":
host = Regions.euc1a1
elif region == "dev":
host = DevRegions.usw2a1
else:
host = region
client = RocksetClient(host, api_key)
cls.memory = RocksetChatMessageHistory(
session_id, client, collection_name, sync=True
)
def test_memory_with_message_store(self) -> None:
memory = ConversationBufferMemory(
memory_key="messages", chat_memory=self.memory, return_messages=True
)
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
memory.chat_memory.clear()
assert memory.chat_memory.messages == []
Loading…
Cancel
Save