Cassandra support for chat history using CassIO library (#6771)

### Overview

This PR builds on #4378, expanding its capabilities and using the
`cassIO` library to interface with the database (as opposed to using
the core drivers directly).

Usage of `cassIO` (a library abstracting Cassandra access for
ML/GenAI-specific purposes) has been established since #6426 was
merged, so no new dependencies are introduced.

In the same spirit, we try to unify the interface for using Cassandra
instances throughout LangChain. With all our appreciation for the work
by @jj701, who paved the way for this incremental work (thank you!), we
identified a few reasons for changing the way a
`CassandraChatMessageHistory` is instantiated. Advocating a syntax
change is something we don't take lightly, so we add some
explanations about this below.

Additionally, this PR expands on integration testing, enables use of
Cassandra's native Time-to-Live (TTL) features and improves the phrasing
around the notebook example and the short "integrations" documentation
paragraph.

We would kindly request @hwchase to review (since this is an elaboration
and proposed improvement of #4378, which had the same reviewer).

### About the `__init__` breaking changes

There are
[many](https://docs.datastax.com/en/developer/python-driver/3.28/api/cassandra/cluster/)
options when creating the `Cluster` object, and new ones might be added
at any time. Choosing some of them and exposing them as `__init__`
parameters of `CassandraChatMessageHistory` will prove to be insufficient
for at least some users.

On the other hand, working through `kwargs` or adding a long, long list
of arguments to `__init__` is not a desirable option either. For this
reason (as done in #6426), we propose that whoever instantiates the
Chat Message History class provide a Cassandra `Session` object, ready
to use. This also enables easier injection of mocks and usage of
Cassandra-compatible connections (such as those to the cloud database
DataStax Astra DB, obtained with a different set of init parameters than
`contact_points` and `port`).
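
To make the mock-injection point concrete, here is a minimal sketch. `HistorySketch` is a hypothetical stand-in, not the real `CassandraChatMessageHistory`, and the table layout in the CQL string is assumed purely for illustration. The only idea taken from this PR is that the caller passes in a ready-made session object.

```python
import json
from typing import Any
from unittest.mock import MagicMock


class HistorySketch:
    """Hypothetical stand-in for a session-injected chat history class."""

    def __init__(
        self,
        session_id: str,
        session: Any,
        keyspace: str,
        table_name: str = "message_store",
    ) -> None:
        # The session is created and owned by the caller; nothing is
        # connected or closed here.
        self.session_id = session_id
        self.session = session
        self.keyspace = keyspace
        self.table_name = table_name

    def add_message(self, content: str) -> None:
        # Assumed, illustrative schema: (session_id, body).
        self.session.execute(
            f"INSERT INTO {self.keyspace}.{self.table_name} "
            "(session_id, body) VALUES (%s, %s)",
            (self.session_id, json.dumps({"content": content})),
        )


# In a unit test, a mock can take the place of a real Cassandra Session.
fake_session = MagicMock()
history = HistorySketch("test-session", fake_session, "my_keyspace")
history.add_message("hi!")

# Inspect what the class sent to the (fake) database.
statement, params = fake_session.execute.call_args.args
```

Because the class never builds its own `Cluster`, the same constructor would accept a session for a local cluster, a cloud database, or a test double without any change to its signature.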

We feel that a breaking change might still be acceptable since LangChain
is at `0.*`. However, while maintaining that the approach we propose
will be more flexible in the future, room could be made for a
"compatibility layer" that respects the current init method. Honestly,
we would do that only if there are strong reasons for it, as that would
entail an additional maintenance burden.

### Other changes

We propose to remove the keyspace creation from the class code for two
reasons. First, production Cassandra instances often employ RBAC, so
the database user reading/writing from tables does not necessarily (and
generally shouldn't) have permission to create keyspaces. Second,
programmatic keyspace creation is not a best practice (it should be
done more or less manually, with extra care about schema mismatches
among nodes, etc). Removing this (usually unnecessary) operation from
the `__init__` path would also shorten initialization time.

We suggest, likewise, removing the `__del__` method (which would close
the database connection), for the following reason: it is the
recommended best practice to create a single Cassandra `Session` object
throughout an application (it is a resource-heavy object capable of
handling concurrency internally), so if Cassandra is used in other
ways by the app, there is a risk of cutting off the connection for all
usages when the history instance is destroyed. Moreover, the `Session`
object, in typical applications, is best left to garbage-collect itself
automatically.
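
The ownership argument can be illustrated with a small sketch. Everything here is hypothetical: `FakeSession` only counts `shutdown()` calls and `ComponentSketch` stands for any consumer (a chat history, a vector store) that borrows the session. Whether the application closes the session explicitly at exit, as below, or simply lets it be garbage-collected is a separate choice; the point is only that consumers never close it.

```python
from contextlib import contextmanager
from typing import Iterator


class FakeSession:
    """Hypothetical stand-in for a driver session; counts shutdown calls."""

    def __init__(self) -> None:
        self.shutdown_calls = 0

    def shutdown(self) -> None:
        self.shutdown_calls += 1


class ComponentSketch:
    """Hypothetical consumer that borrows a session and never closes it."""

    def __init__(self, session: FakeSession) -> None:
        self.session = session


@contextmanager
def application_session() -> Iterator[FakeSession]:
    # The application creates the session once and is its only owner.
    session = FakeSession()
    try:
        yield session
    finally:
        session.shutdown()


with application_session() as shared:
    history = ComponentSketch(shared)
    vector_store = ComponentSketch(shared)
    del history  # dropping one consumer leaves the session open
    still_open = shared.shutdown_calls == 0

# The session is shut down exactly once, by its owner.
closed_once = shared.shutdown_calls == 1
```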

As mentioned above, we defer the actual database I/O to the `cassIO`
library, which is designed to encode practices optimized for LLM
applications (among others) without the need to expose LangChain
developers to the internals of CQL (Cassandra Query Language). CassIO is
already employed by LangChain's Vector Store support for Cassandra.

We added a few more connection options in the companion notebook example
(most notably, Astra DB) to encourage usage by anyone who cannot run
their own Cassandra cluster.

We surface the `ttl_seconds` option for automatic handling of an
expiration time for chat history messages, a likely useful feature given
that very old messages generally lose their importance.
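
As a rough illustration of the semantics, the sketch below mimics per-message expiry in memory. It is not how Cassandra implements TTL (the database handles expiry natively at write time); `TTLHistorySketch` and its injectable `clock` argument are hypothetical names chosen so the example runs without a database and without sleeping.

```python
import time
from typing import Callable, List, Optional, Tuple


class TTLHistorySketch:
    """In-memory illustration of per-message expiry (not Cassandra's mechanism)."""

    def __init__(
        self,
        ttl_seconds: Optional[int] = None,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.ttl_seconds = ttl_seconds
        self._clock = clock
        # Each entry is (expiry time or None, message text).
        self._rows: List[Tuple[Optional[float], str]] = []

    def add_message(self, text: str) -> None:
        # The expiry is fixed when the message is written.
        expires_at = (
            None if self.ttl_seconds is None
            else self._clock() + self.ttl_seconds
        )
        self._rows.append((expires_at, text))

    @property
    def messages(self) -> List[str]:
        # Expired entries are simply no longer returned.
        now = self._clock()
        return [
            text
            for expires_at, text in self._rows
            if expires_at is None or now < expires_at
        ]


# A controllable clock avoids real sleeping in the example.
fake_now = [0.0]
history = TTLHistorySketch(ttl_seconds=5, clock=lambda: fake_now[0])
history.add_message("Nothing special here.")

fake_now[0] = 2.0
before_expiry = history.messages  # still within the 5-second window

fake_now[0] = 7.0
after_expiry = history.messages  # past the window
```

With `ttl_seconds=None` (the default in the sketch, mirroring the option's default in this PR), messages never expire.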

We elaborated a bit more on the integration testing (Time-to-live,
separation of "session ids", ...).

### Remarks from linter & co.

We reinstated `cassio` as a dependency both in the "optional" group and
in the "integration testing" group of `pyproject.toml`. This might not
be the right thing to do, in which case the author of this PR offers his
apologies (lack of confidence with Poetry - happy to be pointed in the
right direction, though!).

During linter tests, we were hit by some errors which appear unrelated
to the code in the PR. We left them untouched and report them here for
awareness:

```
langchain/vectorstores/mongodb_atlas.py:137: error: Argument 1 to "insert_many" of "Collection" has incompatible type "List[Dict[str, Sequence[object]]]"; expected "Iterable[Union[MongoDBDocumentType, RawBSONDocument]]"  [arg-type]
langchain/vectorstores/mongodb_atlas.py:186: error: Argument 1 to "aggregate" of "Collection" has incompatible type "List[object]"; expected "Sequence[Mapping[str, Any]]"  [arg-type]

langchain/vectorstores/qdrant.py:16: error: Name "grpc" is not defined  [name-defined]
langchain/vectorstores/qdrant.py:19: error: Name "grpc" is not defined  [name-defined]
langchain/vectorstores/qdrant.py:20: error: Name "grpc" is not defined  [name-defined]
langchain/vectorstores/qdrant.py:22: error: Name "grpc" is not defined  [name-defined]
langchain/vectorstores/qdrant.py:23: error: Name "grpc" is not defined  [name-defined]
```

In the same spirit, we observe that to even get `import langchain` to
run, it seems that a `pip install bs4` is missing from the minimal
package installation path.

Thank you!
pull/6114/head
Stefano Lottini 1 year ago committed by GitHub
parent f5663603cf
commit 75fb9d2fdc

@ -1,19 +1,21 @@
# Cassandra
>[Cassandra](https://en.wikipedia.org/wiki/Apache_Cassandra) is a free and open-source, distributed, wide-column
>[Apache Cassandra®](https://cassandra.apache.org/) is a free and open-source, distributed, wide-column
> store, NoSQL database management system designed to handle large amounts of data across many commodity servers,
> providing high availability with no single point of failure. `Cassandra` offers support for clusters spanning
> providing high availability with no single point of failure. Cassandra offers support for clusters spanning
> multiple datacenters, with asynchronous masterless replication allowing low latency operations for all clients.
> `Cassandra` was designed to implement a combination of `Amazon's Dynamo` distributed storage and replication
> techniques combined with `Google's Bigtable` data and storage engine model.
> Cassandra was designed to implement a combination of _Amazon's Dynamo_ distributed storage and replication
> techniques combined with _Google's Bigtable_ data and storage engine model.
## Installation and Setup
```bash
pip install cassandra-drive
pip install cassandra-driver
pip install cassio
```
## Memory
See a [usage example](/docs/modules/memory/integrations/cassandra_chat_message_history.html).

@ -1,34 +1,116 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "91c6a7ef",
"id": "90cd3ded",
"metadata": {},
"source": [
"# Cassandra Chat Message History\n",
"\n",
">[Apache Cassandra®](https://cassandra.apache.org) is a NoSQL, row-oriented, highly scalable and highly available database, well suited for storing large amounts of data.\n",
"\n",
"Cassandra is a good choice for storing chat message history because it is easy to scale and can handle a large number of writes.\n",
"\n",
"This notebook goes over how to use Cassandra to store chat message history.\n",
"\n",
"Cassandra is a distributed database that is well suited for storing large amounts of data. \n",
"To run this notebook you need either a running Cassandra cluster or a DataStax Astra DB instance running in the cloud (you can get one for free at [datastax.com](https://astra.datastax.com)). Check [cassio.org](https://cassio.org/start_here/) for more information."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7092199",
"metadata": {},
"outputs": [],
"source": [
"!pip install \"cassio>=0.0.6\""
]
},
{
"cell_type": "markdown",
"id": "e3d97b65",
"metadata": {},
"source": [
"### Please provide database connection parameters and secrets:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "163d97f0",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"\n",
"database_mode = (input('\\n(C)assandra or (A)stra DB? ')).upper()\n",
"\n",
"keyspace_name = input('\\nKeyspace name? ')\n",
"\n",
"It is a good choice for storing chat message history because it is easy to scale and can handle a large number of writes.\n"
"if database_mode == 'A':\n",
" ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
" #\n",
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')\n",
"elif database_mode == 'C':\n",
" CASSANDRA_CONTACT_POINTS = input('Contact points? (comma-separated, empty for localhost) ').strip()"
]
},
{
"cell_type": "markdown",
"id": "55860b2d",
"metadata": {},
"source": [
"#### depending on whether local or cloud-based Astra DB, create the corresponding database connection \"Session\" object"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "47a601d2",
"execution_count": null,
"id": "8dff2798",
"metadata": {},
"outputs": [],
"source": [
"# List of contact points to try connecting to Cassandra cluster.\n",
"contact_points = [\"cassandra\"]"
"from cassandra.cluster import Cluster\n",
"from cassandra.auth import PlainTextAuthProvider\n",
"\n",
"if database_mode == 'C':\n",
" if CASSANDRA_CONTACT_POINTS:\n",
" cluster = Cluster([\n",
" cp.strip()\n",
" for cp in CASSANDRA_CONTACT_POINTS.split(',')\n",
" if cp.strip()\n",
" ])\n",
" else:\n",
" cluster = Cluster()\n",
" session = cluster.connect()\n",
"elif database_mode == 'A':\n",
" ASTRA_DB_CLIENT_ID = \"token\"\n",
" cluster = Cluster(\n",
" cloud={\n",
" \"secure_connect_bundle\": ASTRA_DB_SECURE_BUNDLE_PATH,\n",
" },\n",
" auth_provider=PlainTextAuthProvider(\n",
" ASTRA_DB_CLIENT_ID,\n",
" ASTRA_DB_APPLICATION_TOKEN,\n",
" ),\n",
" )\n",
" session = cluster.connect()\n",
"else:\n",
" raise NotImplementedError"
]
},
{
"cell_type": "markdown",
"id": "36c163e8",
"metadata": {},
"source": [
"### Creation and usage of the Chat Message History"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "d15e3302",
"metadata": {},
"outputs": [],
@ -36,7 +118,9 @@
"from langchain.memory import CassandraChatMessageHistory\n",
"\n",
"message_history = CassandraChatMessageHistory(\n",
" contact_points=contact_points, session_id=\"test-session\"\n",
" session_id=\"test-session\",\n",
" session=session,\n",
" keyspace=keyspace_name,\n",
")\n",
"\n",
"message_history.add_user_message(\"hi!\")\n",
@ -46,22 +130,10 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "64fc465e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n",
" AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"message_history.messages"
]
@ -83,7 +155,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.10.6"
}
},
"nbformat": 4,

@ -1,7 +1,13 @@
"""Cassandra-based chat message history, based on cassIO."""
from __future__ import annotations
import json
import logging
import typing
from typing import List
if typing.TYPE_CHECKING:
from cassandra.cluster import Session
from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
@ -9,171 +15,58 @@ from langchain.schema import (
messages_from_dict,
)
logger = logging.getLogger(__name__)
DEFAULT_KEYSPACE_NAME = "chat_history"
DEFAULT_TABLE_NAME = "message_store"
DEFAULT_USERNAME = "cassandra"
DEFAULT_PASSWORD = "cassandra"
DEFAULT_PORT = 9042
DEFAULT_TTL_SECONDS = None
class CassandraChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in Cassandra.
Args:
contact_points: list of ips to connect to Cassandra cluster
session_id: arbitrary key that is used to store the messages
of a single chat session.
port: port to connect to Cassandra cluster
username: username to connect to Cassandra cluster
password: password to connect to Cassandra cluster
keyspace_name: name of the keyspace to use
table_name: name of the table to use
session: a Cassandra `Session` object (an open DB connection)
keyspace: name of the keyspace to use.
table_name: name of the table to use.
ttl_seconds: time-to-live (seconds) for automatic expiration
of stored entries. None (default) for no expiration.
"""
def __init__(
self,
contact_points: List[str],
session_id: str,
port: int = DEFAULT_PORT,
username: str = DEFAULT_USERNAME,
password: str = DEFAULT_PASSWORD,
keyspace_name: str = DEFAULT_KEYSPACE_NAME,
session: Session,
keyspace: str,
table_name: str = DEFAULT_TABLE_NAME,
):
self.contact_points = contact_points
self.session_id = session_id
self.port = port
self.username = username
self.password = password
self.keyspace_name = keyspace_name
self.table_name = table_name
ttl_seconds: int | None = DEFAULT_TTL_SECONDS,
) -> None:
try:
from cassandra import (
AuthenticationFailed,
OperationTimedOut,
UnresolvableContactPoints,
)
from cassandra.cluster import Cluster, PlainTextAuthProvider
except ImportError:
from cassio.history import StoredBlobHistory
except (ImportError, ModuleNotFoundError):
raise ValueError(
"Could not import cassandra-driver python package. "
"Please install it with `pip install cassandra-driver`."
)
self.cluster: Cluster = Cluster(
contact_points,
port=port,
auth_provider=PlainTextAuthProvider(
username=self.username, password=self.password
),
)
try:
self.session = self.cluster.connect()
except (
AuthenticationFailed,
UnresolvableContactPoints,
OperationTimedOut,
) as error:
logger.error(
"Unable to establish connection with \
cassandra chat message history database"
)
raise error
self._prepare_cassandra()
def _prepare_cassandra(self) -> None:
"""Create the keyspace and table if they don't exist yet"""
from cassandra import OperationTimedOut, Unavailable
try:
self.session.execute(
f"""CREATE KEYSPACE IF NOT EXISTS
{self.keyspace_name} WITH REPLICATION =
{{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }};"""
)
except (OperationTimedOut, Unavailable) as error:
logger.error(
f"Unable to create cassandra \
chat message history keyspace: {self.keyspace_name}."
)
raise error
self.session.set_keyspace(self.keyspace_name)
try:
self.session.execute(
f"""CREATE TABLE IF NOT EXISTS
{self.table_name} (id UUID, session_id varchar,
history text, PRIMARY KEY ((session_id), id) );"""
)
except (OperationTimedOut, Unavailable) as error:
logger.error(
f"Unable to create cassandra \
chat message history table: {self.table_name}"
"Could not import cassio python package. "
"Please install it with `pip install cassio`."
)
raise error
self.session_id = session_id
self.ttl_seconds = ttl_seconds
self.blob_history = StoredBlobHistory(session, keyspace, table_name)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from Cassandra"""
from cassandra import ReadFailure, ReadTimeout, Unavailable
try:
rows = self.session.execute(
f"""SELECT * FROM {self.table_name}
WHERE session_id = '{self.session_id}' ;"""
)
except (Unavailable, ReadTimeout, ReadFailure) as error:
logger.error("Unable to Retreive chat history messages from cassadra")
raise error
if rows:
items = [json.loads(row.history) for row in rows]
else:
items = []
"""Retrieve all session messages from DB"""
message_blobs = self.blob_history.retrieve(
self.session_id,
)
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Cassandra"""
import uuid
from cassandra import Unavailable, WriteFailure, WriteTimeout
try:
self.session.execute(
"""INSERT INTO message_store
(id, session_id, history) VALUES (%s, %s, %s);""",
(uuid.uuid4(), self.session_id, json.dumps(_message_to_dict(message))),
)
except (Unavailable, WriteTimeout, WriteFailure) as error:
logger.error("Unable to write chat history messages to cassandra")
raise error
"""Write a message to the table"""
self.blob_history.store(
self.session_id, json.dumps(_message_to_dict(message)), self.ttl_seconds
)
def clear(self) -> None:
"""Clear session memory from Cassandra"""
from cassandra import OperationTimedOut, Unavailable
try:
self.session.execute(
f"DELETE FROM {self.table_name} WHERE session_id = '{self.session_id}';"
)
except (Unavailable, OperationTimedOut) as error:
logger.error("Unable to clear chat history messages from cassandra")
raise error
def __del__(self) -> None:
if self.session:
self.session.shutdown()
if self.cluster:
self.cluster.shutdown()
"""Clear session memory from DB"""
self.blob_history.clear_session_id(self.session_id)

575
poetry.lock generated

File diff suppressed because it is too large

@ -113,6 +113,7 @@ esprima = {version = "^4.0.1", optional = true}
openllm = {version = ">=0.1.6", optional = true}
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
psychicapi = {version = "^0.8.0", optional = true}
cassio = {version = "^0.0.6", optional = true}
[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"
@ -187,7 +188,7 @@ gptcache = "^0.1.9"
promptlayer = "^0.1.80"
tair = "^1.3.3"
wikipedia = "^1"
cassandra-driver = "^3.27.0"
cassio = "^0.0.6"
arxiv = "^1.4"
mastodon-py = "^1.8.1"
momento = "^1.5.0"
@ -316,6 +317,7 @@ all = [
extended_testing = [
"beautifulsoup4",
"bibtexparser",
"cassio",
"chardet",
"esprima",
"jq",

@ -1,42 +1,118 @@
import json
import os
import time
from typing import Optional
from cassandra.cluster import Cluster
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories.cassandra import (
CassandraChatMessageHistory,
)
from langchain.schema import _message_to_dict
# Replace these with your cassandra contact points
contact_points = (
os.environ["CONTACT_POINTS"].split(",")
if "CONTACT_POINTS" in os.environ
else ["cassandra"]
from langchain.schema import (
AIMessage,
HumanMessage,
)
def _chat_message_history(
session_id: str = "test-session",
drop: bool = True,
ttl_seconds: Optional[int] = None,
) -> CassandraChatMessageHistory:
keyspace = "cmh_test_keyspace"
table_name = "cmh_test_table"
# get db connection
if "CASSANDRA_CONTACT_POINTS" in os.environ:
contact_points = os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
cluster = Cluster(contact_points)
else:
cluster = Cluster()
#
session = cluster.connect()
# ensure keyspace exists
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
)
)
# drop table if required
if drop:
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}")
#
return CassandraChatMessageHistory(
session_id=session_id,
session=session,
keyspace=keyspace,
table_name=table_name,
**({} if ttl_seconds is None else {"ttl_seconds": ttl_seconds}),
)
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup cassandra as a message store
message_history = CassandraChatMessageHistory(
contact_points=contact_points, session_id="test-session"
)
message_history = _chat_message_history()
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
memory_key="baz",
chat_memory=message_history,
return_messages=True,
)
assert memory.chat_memory.messages == []
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
expected = [
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
assert messages == expected
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Cassandra, so the next test run won't pick it up
# clear the store
memory.chat_memory.clear()
assert memory.chat_memory.messages == []
def test_memory_separate_session_ids() -> None:
"""Test that separate session IDs do not share entries."""
message_history1 = _chat_message_history(session_id="test-session1")
memory1 = ConversationBufferMemory(
memory_key="mk1",
chat_memory=message_history1,
return_messages=True,
)
message_history2 = _chat_message_history(session_id="test-session2")
memory2 = ConversationBufferMemory(
memory_key="mk2",
chat_memory=message_history2,
return_messages=True,
)
memory1.chat_memory.add_ai_message("Just saying.")
assert memory2.chat_memory.messages == []
memory1.chat_memory.clear()
memory2.chat_memory.clear()
def test_memory_ttl() -> None:
"""Test time-to-live feature of the memory."""
message_history = _chat_message_history(ttl_seconds=5)
memory = ConversationBufferMemory(
memory_key="baz",
chat_memory=message_history,
return_messages=True,
)
#
assert memory.chat_memory.messages == []
memory.chat_memory.add_ai_message("Nothing special here.")
time.sleep(2)
assert memory.chat_memory.messages != []
time.sleep(5)
assert memory.chat_memory.messages == []
