diff --git a/libs/community/langchain_community/chat_message_histories/cassandra.py b/libs/community/langchain_community/chat_message_histories/cassandra.py index bc3e079465..3eb3d673ac 100644 --- a/libs/community/langchain_community/chat_message_histories/cassandra.py +++ b/libs/community/langchain_community/chat_message_histories/cassandra.py @@ -3,6 +3,7 @@ from __future__ import annotations import json import typing +import uuid from typing import List if typing.TYPE_CHECKING: @@ -41,7 +42,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): ttl_seconds: typing.Optional[int] = DEFAULT_TTL_SECONDS, ) -> None: try: - from cassio.history import StoredBlobHistory + from cassio.table import ClusteredCassandraTable except (ImportError, ModuleNotFoundError): raise ImportError( "Could not import cassio python package. " @@ -49,24 +50,39 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): ) self.session_id = session_id self.ttl_seconds = ttl_seconds - self.blob_history = StoredBlobHistory(session, keyspace, table_name) + self.table = ClusteredCassandraTable( + session=session, + keyspace=keyspace, + table=table_name, + ttl_seconds=ttl_seconds, + primary_key_type=["TEXT", "TIMEUUID"], + ordering_in_partition="DESC", + ) @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all session messages from DB""" - message_blobs = self.blob_history.retrieve( - self.session_id, - ) + # The latest are returned, in chronological order + message_blobs = [ + row["body_blob"] + for row in self.table.get_partition( + partition_id=self.session_id, + ) + ][::-1] items = [json.loads(message_blob) for message_blob in message_blobs] messages = messages_from_dict(items) return messages def add_message(self, message: BaseMessage) -> None: """Write a message to the table""" - self.blob_history.store( - self.session_id, json.dumps(message_to_dict(message)), self.ttl_seconds + this_row_id = uuid.uuid1() + self.table.put( + partition_id=self.session_id, + row_id=this_row_id, + body_blob=json.dumps(message_to_dict(message)), + ttl_seconds=self.ttl_seconds, ) def clear(self) -> None: """Clear session memory from DB""" - self.blob_history.clear_session_id(self.session_id) + self.table.delete_partition(self.session_id)