mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
3a2eb6e12b
Added noqa for existing prints. Can slowly remove / will prevent more being intro'd
269 lines
9.3 KiB
Python
269 lines
9.3 KiB
Python
from datetime import datetime
|
|
from time import sleep
|
|
from typing import Any, Callable, List, Union
|
|
from uuid import uuid4
|
|
|
|
from langchain_core.chat_history import BaseChatMessageHistory
|
|
from langchain_core.messages import (
|
|
BaseMessage,
|
|
message_to_dict,
|
|
messages_from_dict,
|
|
)
|
|
|
|
|
|
class RocksetChatMessageHistory(BaseChatMessageHistory):
|
|
"""Uses Rockset to store chat messages.
|
|
|
|
To use, ensure that the `rockset` python package installed.
|
|
|
|
Example:
|
|
.. code-block:: python
|
|
|
|
from langchain_community.chat_message_histories import (
|
|
RocksetChatMessageHistory
|
|
)
|
|
from rockset import RocksetClient
|
|
|
|
history = RocksetChatMessageHistory(
|
|
session_id="MySession",
|
|
client=RocksetClient(),
|
|
collection="langchain_demo",
|
|
sync=True
|
|
)
|
|
|
|
history.add_user_message("hi!")
|
|
history.add_ai_message("whats up?")
|
|
|
|
print(history.messages) # noqa: T201
|
|
"""
|
|
|
|
# You should set these values based on your VI.
|
|
# These values are configured for the typical
|
|
# free VI. Read more about VIs here:
|
|
# https://rockset.com/docs/instances
|
|
SLEEP_INTERVAL_MS: int = 5
|
|
ADD_TIMEOUT_MS: int = 5000
|
|
CREATE_TIMEOUT_MS: int = 20000
|
|
|
|
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
|
|
"""Sleeps until meth() evaluates to true. Passes kwargs into
|
|
meth.
|
|
"""
|
|
start = datetime.now()
|
|
while not method(**method_params):
|
|
curr = datetime.now()
|
|
if (curr - start).total_seconds() * 1000 > timeout:
|
|
raise TimeoutError(f"{method} timed out at {timeout} ms")
|
|
sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)
|
|
|
|
def _query(self, query: str, **query_params: Any) -> List[Any]:
|
|
"""Executes an SQL statement and returns the result
|
|
Args:
|
|
- query: The SQL string
|
|
- **query_params: Parameters to pass into the query
|
|
"""
|
|
return self.client.sql(query, params=query_params).results
|
|
|
|
def _create_collection(self) -> None:
|
|
"""Creates a collection for this message history"""
|
|
self.client.Collections.create_s3_collection(
|
|
name=self.collection, workspace=self.workspace
|
|
)
|
|
|
|
def _collection_exists(self) -> bool:
|
|
"""Checks whether a collection exists for this message history"""
|
|
try:
|
|
self.client.Collections.get(collection=self.collection)
|
|
except self.rockset.exceptions.NotFoundException:
|
|
return False
|
|
return True
|
|
|
|
def _collection_is_ready(self) -> bool:
|
|
"""Checks whether the collection for this message history is ready
|
|
to be queried
|
|
"""
|
|
return (
|
|
self.client.Collections.get(collection=self.collection).data.status
|
|
== "READY"
|
|
)
|
|
|
|
def _document_exists(self) -> bool:
|
|
return (
|
|
len(
|
|
self._query(
|
|
f"""
|
|
SELECT 1
|
|
FROM {self.location}
|
|
WHERE _id=:session_id
|
|
LIMIT 1
|
|
""",
|
|
session_id=self.session_id,
|
|
)
|
|
)
|
|
!= 0
|
|
)
|
|
|
|
def _wait_until_collection_created(self) -> None:
|
|
"""Sleeps until the collection for this message history is ready
|
|
to be queried
|
|
"""
|
|
self._wait_until(
|
|
lambda: self._collection_is_ready(),
|
|
RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
|
|
)
|
|
|
|
def _wait_until_message_added(self, message_id: str) -> None:
|
|
"""Sleeps until a message is added to the messages list"""
|
|
self._wait_until(
|
|
lambda message_id: len(
|
|
self._query(
|
|
f"""
|
|
SELECT *
|
|
FROM UNNEST((
|
|
SELECT {self.messages_key}
|
|
FROM {self.location}
|
|
WHERE _id = :session_id
|
|
)) AS message
|
|
WHERE message.data.additional_kwargs.id = :message_id
|
|
LIMIT 1
|
|
""",
|
|
session_id=self.session_id,
|
|
message_id=message_id,
|
|
),
|
|
)
|
|
!= 0,
|
|
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
|
message_id=message_id,
|
|
)
|
|
|
|
def _create_empty_doc(self) -> None:
|
|
"""Creates or replaces a document for this message history with no
|
|
messages"""
|
|
self.client.Documents.add_documents(
|
|
collection=self.collection,
|
|
workspace=self.workspace,
|
|
data=[{"_id": self.session_id, self.messages_key: []}],
|
|
)
|
|
|
|
def __init__(
|
|
self,
|
|
session_id: str,
|
|
client: Any,
|
|
collection: str,
|
|
workspace: str = "commons",
|
|
messages_key: str = "messages",
|
|
sync: bool = False,
|
|
message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
|
|
) -> None:
|
|
"""Constructs a new RocksetChatMessageHistory.
|
|
|
|
Args:
|
|
- session_id: The ID of the chat session
|
|
- client: The RocksetClient object to use to query
|
|
- collection: The name of the collection to use to store chat
|
|
messages. If a collection with the given name
|
|
does not exist in the workspace, it is created.
|
|
- workspace: The workspace containing `collection`. Defaults
|
|
to `"commons"`
|
|
- messages_key: The DB column containing message history.
|
|
Defaults to `"messages"`
|
|
- sync: Whether to wait for messages to be added. Defaults
|
|
to `False`. NOTE: setting this to `True` will slow
|
|
down performance.
|
|
- message_uuid_method: The method that generates message IDs.
|
|
If set, all messages will have an `id` field within the
|
|
`additional_kwargs` property. If this param is not set
|
|
and `sync` is `False`, message IDs will not be created.
|
|
If this param is not set and `sync` is `True`, the
|
|
`uuid.uuid4` method will be used to create message IDs.
|
|
"""
|
|
try:
|
|
import rockset
|
|
except ImportError:
|
|
raise ImportError(
|
|
"Could not import rockset client python package. "
|
|
"Please install it with `pip install rockset`."
|
|
)
|
|
|
|
if not isinstance(client, rockset.RocksetClient):
|
|
raise ValueError(
|
|
f"client should be an instance of rockset.RocksetClient, "
|
|
f"got {type(client)}"
|
|
)
|
|
|
|
self.session_id = session_id
|
|
self.client = client
|
|
self.collection = collection
|
|
self.workspace = workspace
|
|
self.location = f'"{self.workspace}"."{self.collection}"'
|
|
self.rockset = rockset
|
|
self.messages_key = messages_key
|
|
self.message_uuid_method = message_uuid_method
|
|
self.sync = sync
|
|
|
|
try:
|
|
self.client.set_application("langchain")
|
|
except AttributeError:
|
|
# ignore
|
|
pass
|
|
|
|
if not self._collection_exists():
|
|
self._create_collection()
|
|
self._wait_until_collection_created()
|
|
self._create_empty_doc()
|
|
elif not self._document_exists():
|
|
self._create_empty_doc()
|
|
|
|
@property
|
|
def messages(self) -> List[BaseMessage]: # type: ignore
|
|
"""Messages in this chat history."""
|
|
return messages_from_dict(
|
|
self._query(
|
|
f"""
|
|
SELECT *
|
|
FROM UNNEST ((
|
|
SELECT "{self.messages_key}"
|
|
FROM {self.location}
|
|
WHERE _id = :session_id
|
|
))
|
|
""",
|
|
session_id=self.session_id,
|
|
)
|
|
)
|
|
|
|
def add_message(self, message: BaseMessage) -> None:
|
|
"""Add a Message object to the history.
|
|
|
|
Args:
|
|
message: A BaseMessage object to store.
|
|
"""
|
|
if self.sync and "id" not in message.additional_kwargs:
|
|
message.additional_kwargs["id"] = self.message_uuid_method()
|
|
self.client.Documents.patch_documents(
|
|
collection=self.collection,
|
|
workspace=self.workspace,
|
|
data=[
|
|
self.rockset.model.patch_document.PatchDocument(
|
|
id=self.session_id,
|
|
patch=[
|
|
self.rockset.model.patch_operation.PatchOperation(
|
|
op="ADD",
|
|
path=f"/{self.messages_key}/-",
|
|
value=message_to_dict(message),
|
|
)
|
|
],
|
|
)
|
|
],
|
|
)
|
|
if self.sync:
|
|
self._wait_until_message_added(message.additional_kwargs["id"])
|
|
|
|
def clear(self) -> None:
|
|
"""Removes all messages from the chat history"""
|
|
self._create_empty_doc()
|
|
if self.sync:
|
|
self._wait_until(
|
|
lambda: not self.messages,
|
|
RocksetChatMessageHistory.ADD_TIMEOUT_MS,
|
|
)
|