mirror of https://github.com/hwchase17/langchain
Integrate Rockset as a chat history store (#8940)
Description: Adds Rockset as a chat history store Dependencies: no changes Tag maintainer: @hwchase17 This PR passes linting and testing. I added a test for the integration and an example notebook showing its use.pull/8828/head
parent
0a1be1d501
commit
3f64b8a761
@ -0,0 +1,67 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Rockset Chat Message History\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use [Rockset](https://rockset.com/docs) to store chat message history. \n",
|
||||||
|
"\n",
|
||||||
|
"To begin, with get your API key from the [Rockset console](https://console.rockset.com/apikeys). Find your API region for the Rockset [API reference](https://rockset.com/docs/rest-api#introduction)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"vscode": {
|
||||||
|
"languageId": "plaintext"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.memory.chat_message_histories import RocksetChatMessageHistory\n",
|
||||||
|
"from rockset import RocksetClient, Regions\n",
|
||||||
|
"\n",
|
||||||
|
"history = RocksetChatMessageHistory(\n",
|
||||||
|
" session_id=\"MySession\",\n",
|
||||||
|
" client=RocksetClient(\n",
|
||||||
|
" api_key=\"YOUR API KEY\", \n",
|
||||||
|
" host=Regions.usw2a1 # us-west-2 Oregon\n",
|
||||||
|
" ),\n",
|
||||||
|
" collection=\"langchain_demo\",\n",
|
||||||
|
" sync=True\n",
|
||||||
|
")\n",
|
||||||
|
"history.add_user_message(\"hi!\")\n",
|
||||||
|
"history.add_ai_message(\"whats up?\")\n",
|
||||||
|
"print(history.messages)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"The output should be something like:\n",
|
||||||
|
"\n",
|
||||||
|
"```python\n",
|
||||||
|
"[\n",
|
||||||
|
" HumanMessage(content='hi!', additional_kwargs={'id': '2e62f1c2-e9f7-465e-b551-49bae07fe9f0'}, example=False), \n",
|
||||||
|
" AIMessage(content='whats up?', additional_kwargs={'id': 'b9be8eda-4c18-4cf8-81c3-e91e876927d0'}, example=False)\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"```"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -0,0 +1,258 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from time import sleep
|
||||||
|
from typing import Any, Callable, List, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from langchain.schema import BaseChatMessageHistory
|
||||||
|
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
||||||
|
|
||||||
|
|
||||||
|
class RocksetChatMessageHistory(BaseChatMessageHistory):
|
||||||
|
"""Uses Rockset to store chat messages.
|
||||||
|
|
||||||
|
To use, ensure that the `rockset` python package installed.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.memory.chat_message_histories import (
|
||||||
|
RocksetChatMessageHistory
|
||||||
|
)
|
||||||
|
from rockset import RocksetClient
|
||||||
|
|
||||||
|
history = RocksetChatMessageHistory(
|
||||||
|
session_id="MySession",
|
||||||
|
client=RocksetClient(),
|
||||||
|
collection="langchain_demo",
|
||||||
|
sync=True
|
||||||
|
)
|
||||||
|
|
||||||
|
history.add_user_message("hi!")
|
||||||
|
history.add_ai_message("whats up?")
|
||||||
|
|
||||||
|
print(history.messages)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# You should set these values based on your VI.
|
||||||
|
# These values are configured for the typical
|
||||||
|
# free VI. Read more about VIs here:
|
||||||
|
# https://rockset.com/docs/instances
|
||||||
|
SLEEP_INTERVAL_MS = 5
|
||||||
|
ADD_TIMEOUT_MS = 5000
|
||||||
|
CREATE_TIMEOUT_MS = 20000
|
||||||
|
|
||||||
|
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
|
||||||
|
"""Sleeps until meth() evaluates to true. Passes kwargs into
|
||||||
|
meth.
|
||||||
|
"""
|
||||||
|
start = datetime.now()
|
||||||
|
while not method(**method_params):
|
||||||
|
curr = datetime.now()
|
||||||
|
if (curr - start).total_seconds() * 1000 > timeout:
|
||||||
|
raise TimeoutError(f"{method} timed out at {timeout} ms")
|
||||||
|
sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
|
||||||
|
|
||||||
|
def _query(self, query: str, **query_params: Any) -> List[Any]:
|
||||||
|
"""Executes an SQL statement and returns the result
|
||||||
|
Args:
|
||||||
|
- query: The SQL string
|
||||||
|
- **query_params: Parameters to pass into the query
|
||||||
|
"""
|
||||||
|
return self.client.sql(query, params=query_params).results
|
||||||
|
|
||||||
|
def _create_collection(self) -> None:
|
||||||
|
"""Creates a collection for this message history"""
|
||||||
|
self.client.Collections.create_s3_collection(
|
||||||
|
name=self.collection, workspace=self.workspace
|
||||||
|
)
|
||||||
|
|
||||||
|
def _collection_exists(self) -> bool:
|
||||||
|
"""Checks whether a collection exists for this message history"""
|
||||||
|
try:
|
||||||
|
self.client.Collections.get(collection=self.collection)
|
||||||
|
except self.rockset.exceptions.NotFoundException:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _collection_is_ready(self) -> bool:
|
||||||
|
"""Checks whether the collection for this message history is ready
|
||||||
|
to be queried
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
self.client.Collections.get(collection=self.collection).data.status
|
||||||
|
== "READY"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _document_exists(self) -> bool:
|
||||||
|
return (
|
||||||
|
len(
|
||||||
|
self._query(
|
||||||
|
f"""
|
||||||
|
SELECT 1
|
||||||
|
FROM {self.location}
|
||||||
|
WHERE _id=:session_id
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
session_id=self.session_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
!= 0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _wait_until_collection_created(self) -> None:
|
||||||
|
"""Sleeps until the collection for this message history is ready
|
||||||
|
to be queried
|
||||||
|
"""
|
||||||
|
self._wait_until(
|
||||||
|
lambda: self._collection_is_ready(),
|
||||||
|
RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _wait_until_message_added(self, message_id: str) -> None:
|
||||||
|
"""Sleeps until a message is added to the messages list"""
|
||||||
|
self._wait_until(
|
||||||
|
lambda message_id: len(
|
||||||
|
self._query(
|
||||||
|
f"""
|
||||||
|
SELECT *
|
||||||
|
FROM UNNEST((
|
||||||
|
SELECT {self.messages_key}
|
||||||
|
FROM {self.location}
|
||||||
|
WHERE _id = :session_id
|
||||||
|
)) AS message
|
||||||
|
WHERE message.data.additional_kwargs.id = :message_id
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
session_id=self.session_id,
|
||||||
|
message_id=message_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
!= 0,
|
||||||
|
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_empty_doc(self) -> None:
|
||||||
|
"""Creates or replaces a document for this message history with no
|
||||||
|
messages"""
|
||||||
|
self.client.Documents.add_documents(
|
||||||
|
collection=self.collection,
|
||||||
|
workspace=self.workspace,
|
||||||
|
data=[{"_id": self.session_id, self.messages_key: []}],
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
client: Any,
|
||||||
|
collection: str,
|
||||||
|
workspace: str = "commons",
|
||||||
|
messages_key: str = "messages",
|
||||||
|
sync: bool = False,
|
||||||
|
message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
|
||||||
|
) -> None:
|
||||||
|
"""Constructs a new RocksetChatMessageHistory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- session_id: The ID of the chat session
|
||||||
|
- client: The RocksetClient object to use to query
|
||||||
|
- collection: The name of the collection to use to store chat
|
||||||
|
messages. If a collection with the given name
|
||||||
|
does not exist in the workspace, it is created.
|
||||||
|
- workspace: The workspace containing `collection`. Defaults
|
||||||
|
to `"commons"`
|
||||||
|
- messages_key: The DB column containing message history.
|
||||||
|
Defaults to `"messages"`
|
||||||
|
- sync: Whether to wait for messages to be added. Defaults
|
||||||
|
to `False`. NOTE: setting this to `True` will slow
|
||||||
|
down performance.
|
||||||
|
- message_uuid_method: The method that generates message IDs.
|
||||||
|
If set, all messages will have an `id` field within the
|
||||||
|
`additional_kwargs` property. If this param is not set
|
||||||
|
and `sync` is `False`, message IDs will not be created.
|
||||||
|
If this param is not set and `sync` is `True`, the
|
||||||
|
`uuid.uuid4` method will be used to create message IDs.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import rockset
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import rockset client python package. "
|
||||||
|
"Please install it with `pip install rockset`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(client, rockset.RocksetClient):
|
||||||
|
raise ValueError(
|
||||||
|
f"client should be an instance of rockset.RocksetClient, "
|
||||||
|
f"got {type(client)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session_id = session_id
|
||||||
|
self.client = client
|
||||||
|
self.collection = collection
|
||||||
|
self.workspace = workspace
|
||||||
|
self.location = f'"{self.workspace}"."{self.collection}"'
|
||||||
|
self.rockset = rockset
|
||||||
|
self.messages_key = messages_key
|
||||||
|
self.message_uuid_method = message_uuid_method
|
||||||
|
self.sync = sync
|
||||||
|
|
||||||
|
if not self._collection_exists():
|
||||||
|
self._create_collection()
|
||||||
|
self._wait_until_collection_created()
|
||||||
|
self._create_empty_doc()
|
||||||
|
elif not self._document_exists():
|
||||||
|
self._create_empty_doc()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||||
|
"""Messages in this chat history."""
|
||||||
|
return messages_from_dict(
|
||||||
|
self._query(
|
||||||
|
f"""
|
||||||
|
SELECT *
|
||||||
|
FROM UNNEST ((
|
||||||
|
SELECT "{self.messages_key}"
|
||||||
|
FROM {self.location}
|
||||||
|
WHERE _id = :session_id
|
||||||
|
))
|
||||||
|
""",
|
||||||
|
session_id=self.session_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
|
"""Add a Message object to the history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: A BaseMessage object to store.
|
||||||
|
"""
|
||||||
|
if self.sync and "id" not in message.additional_kwargs:
|
||||||
|
message.additional_kwargs["id"] = self.message_uuid_method()
|
||||||
|
self.client.Documents.patch_documents(
|
||||||
|
collection=self.collection,
|
||||||
|
workspace=self.workspace,
|
||||||
|
data=[
|
||||||
|
self.rockset.model.patch_document.PatchDocument(
|
||||||
|
id=self.session_id,
|
||||||
|
patch=[
|
||||||
|
self.rockset.model.patch_operation.PatchOperation(
|
||||||
|
op="ADD",
|
||||||
|
path=f"/{self.messages_key}/-",
|
||||||
|
value=_message_to_dict(message),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if self.sync:
|
||||||
|
self._wait_until_message_added(message.additional_kwargs["id"])
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Removes all messages from the chat history"""
|
||||||
|
self._create_empty_doc()
|
||||||
|
if self.sync:
|
||||||
|
self._wait_until(
|
||||||
|
lambda: not self.messages,
|
||||||
|
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
||||||
|
)
|
@ -0,0 +1,63 @@
|
|||||||
|
"""Tests RocksetChatMessageHistory by creating a collection
|
||||||
|
for message history, adding to it, and clearing it.
|
||||||
|
|
||||||
|
To run these tests, make sure you have the ROCKSET_API_KEY
|
||||||
|
and ROCKSET_REGION environment variables set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from rockset import DevRegions, Regions, RocksetClient
|
||||||
|
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from langchain.memory.chat_message_histories import RocksetChatMessageHistory
|
||||||
|
from langchain.schema.messages import _message_to_dict
|
||||||
|
|
||||||
|
collection_name = "langchain_demo"
|
||||||
|
session_id = "MySession"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRockset:
|
||||||
|
memory: RocksetChatMessageHistory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls) -> None:
|
||||||
|
assert os.environ.get("ROCKSET_API_KEY") is not None
|
||||||
|
assert os.environ.get("ROCKSET_REGION") is not None
|
||||||
|
|
||||||
|
api_key = os.environ.get("ROCKSET_API_KEY")
|
||||||
|
region = os.environ.get("ROCKSET_REGION")
|
||||||
|
if region == "use1a1":
|
||||||
|
host = Regions.use1a1
|
||||||
|
elif region == "usw2a1" or not region:
|
||||||
|
host = Regions.usw2a1
|
||||||
|
elif region == "euc1a1":
|
||||||
|
host = Regions.euc1a1
|
||||||
|
elif region == "dev":
|
||||||
|
host = DevRegions.usw2a1
|
||||||
|
else:
|
||||||
|
host = region
|
||||||
|
|
||||||
|
client = RocksetClient(host, api_key)
|
||||||
|
cls.memory = RocksetChatMessageHistory(
|
||||||
|
session_id, client, collection_name, sync=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_memory_with_message_store(self) -> None:
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
memory_key="messages", chat_memory=self.memory, return_messages=True
|
||||||
|
)
|
||||||
|
|
||||||
|
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||||
|
memory.chat_memory.add_user_message("This is me, the human")
|
||||||
|
|
||||||
|
messages = memory.chat_memory.messages
|
||||||
|
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||||
|
|
||||||
|
assert "This is me, the AI" in messages_json
|
||||||
|
assert "This is me, the human" in messages_json
|
||||||
|
|
||||||
|
memory.chat_memory.clear()
|
||||||
|
|
||||||
|
assert memory.chat_memory.messages == []
|
Loading…
Reference in New Issue