Second Attempt - Add concurrent insertion of vector rows in the Cassandra Vector Store (#7017)

Retrying with the same improvements as in #6772, this time taking care not to
mess up the branches.

@rlancemartin I am opening a fresh PR from a branch with a new name. This
should do it. Thank you for your help!

---------

Co-authored-by: Jonathan Ellis <jbellis@datastax.com>
Co-authored-by: rlm <pexpresss31@gmail.com>
pull/7028/head
Stefano Lottini 1 year ago committed by GitHub
parent 3bfe7cf467
commit 8d2281a8ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,10 +1,10 @@
# Cassandra # Cassandra
>[Apache Cassandra®](https://cassandra.apache.org/) is a free and open-source, distributed, wide-column >[Apache Cassandra®](https://cassandra.apache.org/) is a free and open-source, distributed, wide-column
> store, NoSQL database management system designed to handle large amounts of data across many commodity servers, > store, NoSQL database management system designed to handle large amounts of data across many commodity servers,
> providing high availability with no single point of failure. Cassandra offers support for clusters spanning > providing high availability with no single point of failure. Cassandra offers support for clusters spanning
> multiple datacenters, with asynchronous masterless replication allowing low latency operations for all clients. > multiple datacenters, with asynchronous masterless replication allowing low latency operations for all clients.
> Cassandra was designed to implement a combination of _Amazon's Dynamo_ distributed storage and replication > Cassandra was designed to implement a combination of _Amazon's Dynamo_ distributed storage and replication
> techniques combined with _Google's Bigtable_ data and storage engine model. > techniques combined with _Google's Bigtable_ data and storage engine model.
## Installation and Setup ## Installation and Setup
@ -16,6 +16,16 @@ pip install cassio
## Vector Store
See a [usage example](/docs/modules/data_connection/vectorstores/integrations/cassandra.html).
```python
from langchain.vectorstores import Cassandra
```
## Memory ## Memory
See a [usage example](/docs/modules/memory/integrations/cassandra_chat_message_history.html). See a [usage example](/docs/modules/memory/integrations/cassandra_chat_message_history.html).

@ -23,7 +23,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install \"cassio>=0.0.5\"" "!pip install \"cassio>=0.0.7\""
] ]
}, },
{ {
@ -44,14 +44,16 @@
"import os\n", "import os\n",
"import getpass\n", "import getpass\n",
"\n", "\n",
"database_mode = (input('\\n(L)ocal Cassandra or (A)stra DB? ')).upper()\n", "database_mode = (input('\\n(C)assandra or (A)stra DB? ')).upper()\n",
"\n", "\n",
"keyspace_name = input('\\nKeyspace name? ')\n", "keyspace_name = input('\\nKeyspace name? ')\n",
"\n", "\n",
"if database_mode == 'A':\n", "if database_mode == 'A':\n",
" ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n", " ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
" #\n", " #\n",
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')" " ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')\n",
"elif database_mode == 'C':\n",
" CASSANDRA_CONTACT_POINTS = input('Contact points? (comma-separated, empty for localhost) ').strip()"
] ]
}, },
{ {
@ -72,8 +74,15 @@
"from cassandra.cluster import Cluster\n", "from cassandra.cluster import Cluster\n",
"from cassandra.auth import PlainTextAuthProvider\n", "from cassandra.auth import PlainTextAuthProvider\n",
"\n", "\n",
"if database_mode == 'L':\n", "if database_mode == 'C':\n",
" cluster = Cluster()\n", " if CASSANDRA_CONTACT_POINTS:\n",
" cluster = Cluster([\n",
" cp.strip()\n",
" for cp in CASSANDRA_CONTACT_POINTS.split(',')\n",
" if cp.strip()\n",
" ])\n",
" else:\n",
" cluster = Cluster()\n",
" session = cluster.connect()\n", " session = cluster.connect()\n",
"elif database_mode == 'A':\n", "elif database_mode == 'A':\n",
" ASTRA_DB_CLIENT_ID = \"token\"\n", " ASTRA_DB_CLIENT_ID = \"token\"\n",
@ -261,7 +270,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.10" "version": "3.10.6"
} }
}, },
"nbformat": 4, "nbformat": 4,

@ -1,8 +1,8 @@
"""Wrapper around Cassandra vector-store capabilities, based on cassIO.""" """Wrapper around Cassandra vector-store capabilities, based on cassIO."""
from __future__ import annotations from __future__ import annotations
import hashlib
import typing import typing
import uuid
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar
import numpy as np import numpy as np
@ -17,14 +17,6 @@ from langchain.vectorstores.utils import maximal_marginal_relevance
CVST = TypeVar("CVST", bound="Cassandra") CVST = TypeVar("CVST", bound="Cassandra")
# a positive number of seconds to expire entries, or None for no expiration.
CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS = None
def _hash(_input: str) -> str:
"""Use a deterministic hashing approach."""
return hashlib.md5(_input.encode()).hexdigest()
class Cassandra(VectorStore): class Cassandra(VectorStore):
"""Wrapper around Cassandra embeddings platform. """Wrapper around Cassandra embeddings platform.
@ -46,7 +38,7 @@ class Cassandra(VectorStore):
_embedding_dimension: int | None _embedding_dimension: int | None
def _getEmbeddingDimension(self) -> int: def _get_embedding_dimension(self) -> int:
if self._embedding_dimension is None: if self._embedding_dimension is None:
self._embedding_dimension = len( self._embedding_dimension = len(
self.embedding.embed_query("This is a sample sentence.") self.embedding.embed_query("This is a sample sentence.")
@ -59,7 +51,7 @@ class Cassandra(VectorStore):
session: Session, session: Session,
keyspace: str, keyspace: str,
table_name: str, table_name: str,
ttl_seconds: int | None = CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS, ttl_seconds: Optional[int] = None,
) -> None: ) -> None:
try: try:
from cassio.vector import VectorTable from cassio.vector import VectorTable
@ -81,8 +73,8 @@ class Cassandra(VectorStore):
session=session, session=session,
keyspace=keyspace, keyspace=keyspace,
table=table_name, table=table_name,
embedding_dimension=self._getEmbeddingDimension(), embedding_dimension=self._get_embedding_dimension(),
auto_id=False, # the `add_texts` contract admits user-provided ids primary_key_type="TEXT",
) )
def delete_collection(self) -> None: def delete_collection(self) -> None:
@ -99,11 +91,27 @@ class Cassandra(VectorStore):
def delete_by_document_id(self, document_id: str) -> None: def delete_by_document_id(self, document_id: str) -> None:
return self.table.delete(document_id) return self.table.delete(document_id)
def delete(self, ids: List[str]) -> Optional[bool]:
    """Remove the entries with the given IDs from the store.

    Each ID is handed, in the order given, to ``delete_by_document_id``.

    Args:
        ids: List of document IDs to delete.

    Returns:
        Optional[bool]: Always True once every ID has been processed.
    """
    for doc_id in ids:
        self.delete_by_document_id(doc_id)
    return True
def add_texts( def add_texts(
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore. """Run more texts through the embeddings and add to the vectorstore.
@ -112,33 +120,39 @@ class Cassandra(VectorStore):
texts (Iterable[str]): Texts to add to the vectorstore. texts (Iterable[str]): Texts to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas. metadatas (Optional[List[dict]], optional): Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of IDs. ids (Optional[List[str]], optional): Optional list of IDs.
batch_size (int): Number of concurrent requests to send to the server.
ttl_seconds (Optional[int], optional): Optional time-to-live
for the added texts.
Returns: Returns:
List[str]: List of IDs of the added texts. List[str]: List of IDs of the added texts.
""" """
_texts = list(texts) # lest it be a generator or something _texts = list(texts) # lest it be a generator or something
if ids is None: if ids is None:
# unless otherwise specified, we have deterministic IDs: ids = [uuid.uuid4().hex for _ in _texts]
# re-inserting an existing document will not create a duplicate.
# (and effectively update the metadata)
ids = [_hash(text) for text in _texts]
if metadatas is None: if metadatas is None:
metadatas = [{} for _ in _texts] metadatas = [{} for _ in _texts]
# #
ttl_seconds = kwargs.get("ttl_seconds", self.ttl_seconds) ttl_seconds = ttl_seconds or self.ttl_seconds
# #
embedding_vectors = self.embedding.embed_documents(_texts) embedding_vectors = self.embedding.embed_documents(_texts)
for text, embedding_vector, text_id, metadata in zip(
_texts, embedding_vectors, ids, metadatas
):
self.table.put(
document=text,
embedding_vector=embedding_vector,
document_id=text_id,
metadata=metadata,
ttl_seconds=ttl_seconds,
)
# #
for i in range(0, len(_texts), batch_size):
batch_texts = _texts[i : i + batch_size]
batch_embedding_vectors = embedding_vectors[i : i + batch_size]
batch_ids = ids[i : i + batch_size]
batch_metadatas = metadatas[i : i + batch_size]
futures = [
self.table.put_async(
text, embedding_vector, text_id, metadata, ttl_seconds
)
for text, embedding_vector, text_id, metadata in zip(
batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
)
]
for future in futures:
future.result()
return ids return ids
# id-returning search facilities # id-returning search facilities
@ -181,7 +195,6 @@ class Cassandra(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float, str]]: ) -> List[Tuple[Document, float, str]]:
embedding_vector = self.embedding.embed_query(query) embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector( return self.similarity_search_with_score_id_by_vector(
@ -219,12 +232,10 @@ class Cassandra(VectorStore):
k: int = 4, k: int = 4,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
#
embedding_vector = self.embedding.embed_query(query) embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_by_vector( return self.similarity_search_by_vector(
embedding_vector, embedding_vector,
k, k,
**kwargs,
) )
def similarity_search_by_vector( def similarity_search_by_vector(
@ -245,7 +256,6 @@ class Cassandra(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
embedding_vector = self.embedding.embed_query(query) embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_by_vector( return self.similarity_search_with_score_by_vector(
@ -266,7 +276,6 @@ class Cassandra(VectorStore):
return self.similarity_search_with_score( return self.similarity_search_with_score(
query, query,
k, k,
**kwargs,
) )
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
@ -352,6 +361,7 @@ class Cassandra(VectorStore):
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
batch_size: int = 16,
**kwargs: Any, **kwargs: Any,
) -> CVST: ) -> CVST:
"""Create a Cassandra vectorstore from raw texts. """Create a Cassandra vectorstore from raw texts.
@ -378,6 +388,7 @@ class Cassandra(VectorStore):
cls: Type[CVST], cls: Type[CVST],
documents: List[Document], documents: List[Document],
embedding: Embeddings, embedding: Embeddings,
batch_size: int = 16,
**kwargs: Any, **kwargs: Any,
) -> CVST: ) -> CVST:
"""Create a Cassandra vectorstore from a document list. """Create a Cassandra vectorstore from a document list.

552
poetry.lock generated

File diff suppressed because it is too large Load Diff

@ -113,7 +113,7 @@ esprima = {version = "^4.0.1", optional = true}
openllm = {version = ">=0.1.19", optional = true} openllm = {version = ">=0.1.19", optional = true}
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"} streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
psychicapi = {version = "^0.8.0", optional = true} psychicapi = {version = "^0.8.0", optional = true}
cassio = {version = "^0.0.6", optional = true} cassio = {version = "^0.0.7", optional = true}
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0" autodoc_pydantic = "^1.8.0"
@ -188,7 +188,7 @@ gptcache = "^0.1.9"
promptlayer = "^0.1.80" promptlayer = "^0.1.80"
tair = "^1.3.3" tair = "^1.3.3"
wikipedia = "^1" wikipedia = "^1"
cassio = "^0.0.6" cassio = "^0.0.7"
arxiv = "^1.4" arxiv = "^1.4"
mastodon-py = "^1.8.1" mastodon-py = "^1.8.1"
momento = "^1.5.0" momento = "^1.5.0"

@ -84,7 +84,7 @@ def test_cassandra_max_marginal_relevance_search() -> None:
With fetch_k==3 and k==2, when query is at (1, ), With fetch_k==3 and k==2, when query is at (1, ),
one expects that v2 and v0 are returned (in some order). one expects that v2 and v0 are returned (in some order).
""" """
texts = ["-0.125", "+0.125", "+0.25", "+1.0"] texts = ["-0.124", "+0.127", "+0.25", "+1.0"]
metadatas = [{"page": i} for i in range(len(texts))] metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vectorstore_from_texts( docsearch = _vectorstore_from_texts(
texts, metadatas=metadatas, embedding_class=AngularTwoDimensionalEmbeddings texts, metadatas=metadatas, embedding_class=AngularTwoDimensionalEmbeddings
@ -95,7 +95,7 @@ def test_cassandra_max_marginal_relevance_search() -> None:
} }
assert output_set == { assert output_set == {
("+0.25", 2), ("+0.25", 2),
("-0.125", 0), ("-0.124", 0),
} }
@ -105,9 +105,9 @@ def test_cassandra_add_extra() -> None:
metadatas = [{"page": i} for i in range(len(texts))] metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vectorstore_from_texts(texts, metadatas=metadatas) docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
docsearch.add_texts(texts, metadatas)
texts2 = ["foo2", "bar2", "baz2"] texts2 = ["foo2", "bar2", "baz2"]
docsearch.add_texts(texts2, metadatas) metadatas2 = [{"page": i + 3} for i in range(len(texts))]
docsearch.add_texts(texts2, metadatas2)
output = docsearch.similarity_search("foo", k=10) output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6 assert len(output) == 6
@ -127,9 +127,37 @@ def test_cassandra_no_drop() -> None:
assert len(output) == 6 assert len(output) == 6
def test_cassandra_delete() -> None:
    """Exercise single-ID delete, bulk delete, unknown-ID delete and clear."""
    texts = ["foo", "bar", "baz", "gni"]
    metadatas = [{"page": i} for i in range(len(texts))]
    # The store is created from an empty text list; rows are added below.
    docsearch = _vectorstore_from_texts([], metadatas=metadatas)

    def hit_count() -> int:
        # Number of documents returned for the fixed probe query.
        return len(docsearch.similarity_search("foo", k=10))

    ids = docsearch.add_texts(texts, metadatas)
    assert hit_count() == 4

    # Delete one row through the single-ID entry point.
    docsearch.delete_by_document_id(ids[0])
    assert hit_count() == 3

    # Delete two more rows through the list-based entry point.
    docsearch.delete(ids[1:3])
    assert hit_count() == 1

    # An ID that was never inserted must leave the remaining row in place.
    docsearch.delete(["not-existing"])
    assert hit_count() == 1

    docsearch.clear()
    assert hit_count() == 0
# if __name__ == "__main__": # if __name__ == "__main__":
# test_cassandra() # test_cassandra()
# test_cassandra_with_score() # test_cassandra_with_score()
# test_cassandra_max_marginal_relevance_search() # test_cassandra_max_marginal_relevance_search()
# test_cassandra_add_extra() # test_cassandra_add_extra()
# test_cassandra_no_drop() # test_cassandra_no_drop()
# test_cassandra_delete()

Loading…
Cancel
Save