mirror of https://github.com/hwchase17/langchain
Integrate Rockset as a chat history store (#8940)
Description: Adds Rockset as a chat history store Dependencies: no changes Tag maintainer: @hwchase17 This PR passes linting and testing. I added a test for the integration and an example notebook showing its use.pull/8828/head
parent
0a1be1d501
commit
3f64b8a761
@ -0,0 +1,67 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Rockset Chat Message History\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use [Rockset](https://rockset.com/docs) to store chat message history. \n",
|
||||
"\n",
|
||||
"To begin, with get your API key from the [Rockset console](https://console.rockset.com/apikeys). Find your API region for the Rockset [API reference](https://rockset.com/docs/rest-api#introduction)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "plaintext"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.memory.chat_message_histories import RocksetChatMessageHistory\n",
|
||||
"from rockset import RocksetClient, Regions\n",
|
||||
"\n",
|
||||
"history = RocksetChatMessageHistory(\n",
|
||||
" session_id=\"MySession\",\n",
|
||||
" client=RocksetClient(\n",
|
||||
" api_key=\"YOUR API KEY\", \n",
|
||||
" host=Regions.usw2a1 # us-west-2 Oregon\n",
|
||||
" ),\n",
|
||||
" collection=\"langchain_demo\",\n",
|
||||
" sync=True\n",
|
||||
")\n",
|
||||
"history.add_user_message(\"hi!\")\n",
|
||||
"history.add_ai_message(\"whats up?\")\n",
|
||||
"print(history.messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The output should be something like:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"[\n",
|
||||
" HumanMessage(content='hi!', additional_kwargs={'id': '2e62f1c2-e9f7-465e-b551-49bae07fe9f0'}, example=False), \n",
|
||||
" AIMessage(content='whats up?', additional_kwargs={'id': 'b9be8eda-4c18-4cf8-81c3-e91e876927d0'}, example=False)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"```"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,258 @@
|
||||
from datetime import datetime
|
||||
from time import sleep
|
||||
from typing import Any, Callable, List, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.schema import BaseChatMessageHistory
|
||||
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
||||
|
||||
|
||||
class RocksetChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Uses Rockset to store chat messages.
|
||||
|
||||
To use, ensure that the `rockset` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.memory.chat_message_histories import (
|
||||
RocksetChatMessageHistory
|
||||
)
|
||||
from rockset import RocksetClient
|
||||
|
||||
history = RocksetChatMessageHistory(
|
||||
session_id="MySession",
|
||||
client=RocksetClient(),
|
||||
collection="langchain_demo",
|
||||
sync=True
|
||||
)
|
||||
|
||||
history.add_user_message("hi!")
|
||||
history.add_ai_message("whats up?")
|
||||
|
||||
print(history.messages)
|
||||
"""
|
||||
|
||||
# You should set these values based on your VI.
|
||||
# These values are configured for the typical
|
||||
# free VI. Read more about VIs here:
|
||||
# https://rockset.com/docs/instances
|
||||
SLEEP_INTERVAL_MS = 5
|
||||
ADD_TIMEOUT_MS = 5000
|
||||
CREATE_TIMEOUT_MS = 20000
|
||||
|
||||
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
|
||||
"""Sleeps until meth() evaluates to true. Passes kwargs into
|
||||
meth.
|
||||
"""
|
||||
start = datetime.now()
|
||||
while not method(**method_params):
|
||||
curr = datetime.now()
|
||||
if (curr - start).total_seconds() * 1000 > timeout:
|
||||
raise TimeoutError(f"{method} timed out at {timeout} ms")
|
||||
sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
|
||||
|
||||
def _query(self, query: str, **query_params: Any) -> List[Any]:
|
||||
"""Executes an SQL statement and returns the result
|
||||
Args:
|
||||
- query: The SQL string
|
||||
- **query_params: Parameters to pass into the query
|
||||
"""
|
||||
return self.client.sql(query, params=query_params).results
|
||||
|
||||
def _create_collection(self) -> None:
|
||||
"""Creates a collection for this message history"""
|
||||
self.client.Collections.create_s3_collection(
|
||||
name=self.collection, workspace=self.workspace
|
||||
)
|
||||
|
||||
def _collection_exists(self) -> bool:
|
||||
"""Checks whether a collection exists for this message history"""
|
||||
try:
|
||||
self.client.Collections.get(collection=self.collection)
|
||||
except self.rockset.exceptions.NotFoundException:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _collection_is_ready(self) -> bool:
|
||||
"""Checks whether the collection for this message history is ready
|
||||
to be queried
|
||||
"""
|
||||
return (
|
||||
self.client.Collections.get(collection=self.collection).data.status
|
||||
== "READY"
|
||||
)
|
||||
|
||||
def _document_exists(self) -> bool:
|
||||
return (
|
||||
len(
|
||||
self._query(
|
||||
f"""
|
||||
SELECT 1
|
||||
FROM {self.location}
|
||||
WHERE _id=:session_id
|
||||
LIMIT 1
|
||||
""",
|
||||
session_id=self.session_id,
|
||||
)
|
||||
)
|
||||
!= 0
|
||||
)
|
||||
|
||||
def _wait_until_collection_created(self) -> None:
|
||||
"""Sleeps until the collection for this message history is ready
|
||||
to be queried
|
||||
"""
|
||||
self._wait_until(
|
||||
lambda: self._collection_is_ready(),
|
||||
RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
|
||||
)
|
||||
|
||||
def _wait_until_message_added(self, message_id: str) -> None:
|
||||
"""Sleeps until a message is added to the messages list"""
|
||||
self._wait_until(
|
||||
lambda message_id: len(
|
||||
self._query(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM UNNEST((
|
||||
SELECT {self.messages_key}
|
||||
FROM {self.location}
|
||||
WHERE _id = :session_id
|
||||
)) AS message
|
||||
WHERE message.data.additional_kwargs.id = :message_id
|
||||
LIMIT 1
|
||||
""",
|
||||
session_id=self.session_id,
|
||||
message_id=message_id,
|
||||
),
|
||||
)
|
||||
!= 0,
|
||||
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def _create_empty_doc(self) -> None:
|
||||
"""Creates or replaces a document for this message history with no
|
||||
messages"""
|
||||
self.client.Documents.add_documents(
|
||||
collection=self.collection,
|
||||
workspace=self.workspace,
|
||||
data=[{"_id": self.session_id, self.messages_key: []}],
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
client: Any,
|
||||
collection: str,
|
||||
workspace: str = "commons",
|
||||
messages_key: str = "messages",
|
||||
sync: bool = False,
|
||||
message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
|
||||
) -> None:
|
||||
"""Constructs a new RocksetChatMessageHistory.
|
||||
|
||||
Args:
|
||||
- session_id: The ID of the chat session
|
||||
- client: The RocksetClient object to use to query
|
||||
- collection: The name of the collection to use to store chat
|
||||
messages. If a collection with the given name
|
||||
does not exist in the workspace, it is created.
|
||||
- workspace: The workspace containing `collection`. Defaults
|
||||
to `"commons"`
|
||||
- messages_key: The DB column containing message history.
|
||||
Defaults to `"messages"`
|
||||
- sync: Whether to wait for messages to be added. Defaults
|
||||
to `False`. NOTE: setting this to `True` will slow
|
||||
down performance.
|
||||
- message_uuid_method: The method that generates message IDs.
|
||||
If set, all messages will have an `id` field within the
|
||||
`additional_kwargs` property. If this param is not set
|
||||
and `sync` is `False`, message IDs will not be created.
|
||||
If this param is not set and `sync` is `True`, the
|
||||
`uuid.uuid4` method will be used to create message IDs.
|
||||
"""
|
||||
try:
|
||||
import rockset
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import rockset client python package. "
|
||||
"Please install it with `pip install rockset`."
|
||||
)
|
||||
|
||||
if not isinstance(client, rockset.RocksetClient):
|
||||
raise ValueError(
|
||||
f"client should be an instance of rockset.RocksetClient, "
|
||||
f"got {type(client)}"
|
||||
)
|
||||
|
||||
self.session_id = session_id
|
||||
self.client = client
|
||||
self.collection = collection
|
||||
self.workspace = workspace
|
||||
self.location = f'"{self.workspace}"."{self.collection}"'
|
||||
self.rockset = rockset
|
||||
self.messages_key = messages_key
|
||||
self.message_uuid_method = message_uuid_method
|
||||
self.sync = sync
|
||||
|
||||
if not self._collection_exists():
|
||||
self._create_collection()
|
||||
self._wait_until_collection_created()
|
||||
self._create_empty_doc()
|
||||
elif not self._document_exists():
|
||||
self._create_empty_doc()
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Messages in this chat history."""
|
||||
return messages_from_dict(
|
||||
self._query(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM UNNEST ((
|
||||
SELECT "{self.messages_key}"
|
||||
FROM {self.location}
|
||||
WHERE _id = :session_id
|
||||
))
|
||||
""",
|
||||
session_id=self.session_id,
|
||||
)
|
||||
)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a Message object to the history.
|
||||
|
||||
Args:
|
||||
message: A BaseMessage object to store.
|
||||
"""
|
||||
if self.sync and "id" not in message.additional_kwargs:
|
||||
message.additional_kwargs["id"] = self.message_uuid_method()
|
||||
self.client.Documents.patch_documents(
|
||||
collection=self.collection,
|
||||
workspace=self.workspace,
|
||||
data=[
|
||||
self.rockset.model.patch_document.PatchDocument(
|
||||
id=self.session_id,
|
||||
patch=[
|
||||
self.rockset.model.patch_operation.PatchOperation(
|
||||
op="ADD",
|
||||
path=f"/{self.messages_key}/-",
|
||||
value=_message_to_dict(message),
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
if self.sync:
|
||||
self._wait_until_message_added(message.additional_kwargs["id"])
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Removes all messages from the chat history"""
|
||||
self._create_empty_doc()
|
||||
if self.sync:
|
||||
self._wait_until(
|
||||
lambda: not self.messages,
|
||||
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
||||
)
|
@ -0,0 +1,63 @@
|
||||
"""Tests RocksetChatMessageHistory by creating a collection
|
||||
for message history, adding to it, and clearing it.
|
||||
|
||||
To run these tests, make sure you have the ROCKSET_API_KEY
|
||||
and ROCKSET_REGION environment variables set.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from rockset import DevRegions, Regions, RocksetClient
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
|
||||
from langchain.schema.messages import _message_to_dict
|
||||
|
||||
collection_name = "langchain_demo"
|
||||
session_id = "MySession"
|
||||
|
||||
|
||||
class TestRockset:
|
||||
memory: RocksetChatMessageHistory
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls) -> None:
|
||||
assert os.environ.get("ROCKSET_API_KEY") is not None
|
||||
assert os.environ.get("ROCKSET_REGION") is not None
|
||||
|
||||
api_key = os.environ.get("ROCKSET_API_KEY")
|
||||
region = os.environ.get("ROCKSET_REGION")
|
||||
if region == "use1a1":
|
||||
host = Regions.use1a1
|
||||
elif region == "usw2a1" or not region:
|
||||
host = Regions.usw2a1
|
||||
elif region == "euc1a1":
|
||||
host = Regions.euc1a1
|
||||
elif region == "dev":
|
||||
host = DevRegions.usw2a1
|
||||
else:
|
||||
host = region
|
||||
|
||||
client = RocksetClient(host, api_key)
|
||||
cls.memory = RocksetChatMessageHistory(
|
||||
session_id, client, collection_name, sync=True
|
||||
)
|
||||
|
||||
def test_memory_with_message_store(self) -> None:
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="messages", chat_memory=self.memory, return_messages=True
|
||||
)
|
||||
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
memory.chat_memory.clear()
|
||||
|
||||
assert memory.chat_memory.messages == []
|
Loading…
Reference in New Issue