langchain/libs/community/langchain_community/chat_message_histories/rocksetdb.py
Erick Friis 3a2eb6e12b
infra: add print rule to ruff (#16221)
Added noqa for existing prints. Can slowly remove / will prevent more
being intro'd
2024-02-09 16:13:30 -08:00

269 lines
9.3 KiB
Python

from datetime import datetime
from time import sleep
from typing import Any, Callable, List, Union
from uuid import uuid4
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
class RocksetChatMessageHistory(BaseChatMessageHistory):
"""Uses Rockset to store chat messages.
To use, ensure that the `rockset` python package installed.
Example:
.. code-block:: python
from langchain_community.chat_message_histories import (
RocksetChatMessageHistory
)
from rockset import RocksetClient
history = RocksetChatMessageHistory(
session_id="MySession",
client=RocksetClient(),
collection="langchain_demo",
sync=True
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
print(history.messages) # noqa: T201
"""
# You should set these values based on your VI.
# These values are configured for the typical
# free VI. Read more about VIs here:
# https://rockset.com/docs/instances
SLEEP_INTERVAL_MS: int = 5
ADD_TIMEOUT_MS: int = 5000
CREATE_TIMEOUT_MS: int = 20000
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
"""Sleeps until meth() evaluates to true. Passes kwargs into
meth.
"""
start = datetime.now()
while not method(**method_params):
curr = datetime.now()
if (curr - start).total_seconds() * 1000 > timeout:
raise TimeoutError(f"{method} timed out at {timeout} ms")
sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
def _query(self, query: str, **query_params: Any) -> List[Any]:
"""Executes an SQL statement and returns the result
Args:
- query: The SQL string
- **query_params: Parameters to pass into the query
"""
return self.client.sql(query, params=query_params).results
def _create_collection(self) -> None:
"""Creates a collection for this message history"""
self.client.Collections.create_s3_collection(
name=self.collection, workspace=self.workspace
)
def _collection_exists(self) -> bool:
"""Checks whether a collection exists for this message history"""
try:
self.client.Collections.get(collection=self.collection)
except self.rockset.exceptions.NotFoundException:
return False
return True
def _collection_is_ready(self) -> bool:
"""Checks whether the collection for this message history is ready
to be queried
"""
return (
self.client.Collections.get(collection=self.collection).data.status
== "READY"
)
def _document_exists(self) -> bool:
return (
len(
self._query(
f"""
SELECT 1
FROM {self.location}
WHERE _id=:session_id
LIMIT 1
""",
session_id=self.session_id,
)
)
!= 0
)
def _wait_until_collection_created(self) -> None:
"""Sleeps until the collection for this message history is ready
to be queried
"""
self._wait_until(
lambda: self._collection_is_ready(),
RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
)
def _wait_until_message_added(self, message_id: str) -> None:
"""Sleeps until a message is added to the messages list"""
self._wait_until(
lambda message_id: len(
self._query(
f"""
SELECT *
FROM UNNEST((
SELECT {self.messages_key}
FROM {self.location}
WHERE _id = :session_id
)) AS message
WHERE message.data.additional_kwargs.id = :message_id
LIMIT 1
""",
session_id=self.session_id,
message_id=message_id,
),
)
!= 0,
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
message_id=message_id,
)
def _create_empty_doc(self) -> None:
"""Creates or replaces a document for this message history with no
messages"""
self.client.Documents.add_documents(
collection=self.collection,
workspace=self.workspace,
data=[{"_id": self.session_id, self.messages_key: []}],
)
def __init__(
self,
session_id: str,
client: Any,
collection: str,
workspace: str = "commons",
messages_key: str = "messages",
sync: bool = False,
message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
) -> None:
"""Constructs a new RocksetChatMessageHistory.
Args:
- session_id: The ID of the chat session
- client: The RocksetClient object to use to query
- collection: The name of the collection to use to store chat
messages. If a collection with the given name
does not exist in the workspace, it is created.
- workspace: The workspace containing `collection`. Defaults
to `"commons"`
- messages_key: The DB column containing message history.
Defaults to `"messages"`
- sync: Whether to wait for messages to be added. Defaults
to `False`. NOTE: setting this to `True` will slow
down performance.
- message_uuid_method: The method that generates message IDs.
If set, all messages will have an `id` field within the
`additional_kwargs` property. If this param is not set
and `sync` is `False`, message IDs will not be created.
If this param is not set and `sync` is `True`, the
`uuid.uuid4` method will be used to create message IDs.
"""
try:
import rockset
except ImportError:
raise ImportError(
"Could not import rockset client python package. "
"Please install it with `pip install rockset`."
)
if not isinstance(client, rockset.RocksetClient):
raise ValueError(
f"client should be an instance of rockset.RocksetClient, "
f"got {type(client)}"
)
self.session_id = session_id
self.client = client
self.collection = collection
self.workspace = workspace
self.location = f'"{self.workspace}"."{self.collection}"'
self.rockset = rockset
self.messages_key = messages_key
self.message_uuid_method = message_uuid_method
self.sync = sync
try:
self.client.set_application("langchain")
except AttributeError:
# ignore
pass
if not self._collection_exists():
self._create_collection()
self._wait_until_collection_created()
self._create_empty_doc()
elif not self._document_exists():
self._create_empty_doc()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Messages in this chat history."""
return messages_from_dict(
self._query(
f"""
SELECT *
FROM UNNEST ((
SELECT "{self.messages_key}"
FROM {self.location}
WHERE _id = :session_id
))
""",
session_id=self.session_id,
)
)
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the history.
Args:
message: A BaseMessage object to store.
"""
if self.sync and "id" not in message.additional_kwargs:
message.additional_kwargs["id"] = self.message_uuid_method()
self.client.Documents.patch_documents(
collection=self.collection,
workspace=self.workspace,
data=[
self.rockset.model.patch_document.PatchDocument(
id=self.session_id,
patch=[
self.rockset.model.patch_operation.PatchOperation(
op="ADD",
path=f"/{self.messages_key}/-",
value=message_to_dict(message),
)
],
)
],
)
if self.sync:
self._wait_until_message_added(message.additional_kwargs["id"])
def clear(self) -> None:
"""Removes all messages from the chat history"""
self._create_empty_doc()
if self.sync:
self._wait_until(
lambda: not self.messages,
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
)