Second Attempt - Add concurrent insertion of vector rows in the Cassandra Vector Store (#7017)

Retrying with the same improvements as in #6772, this time taking care not to
mess up the branches.

@rlancemartin I am opening a fresh PR from a branch with a new name. This
should do it. Thank you for your help!

---------

Co-authored-by: Jonathan Ellis <jbellis@datastax.com>
Co-authored-by: rlm <pexpresss31@gmail.com>
pull/7028/head
Stefano Lottini 1 year ago committed by GitHub
parent 3bfe7cf467
commit 8d2281a8ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,10 +1,10 @@
# Cassandra
>[Apache Cassandra®](https://cassandra.apache.org/) is a free and open-source, distributed, wide-column
>[Apache Cassandra®](https://cassandra.apache.org/) is a free and open-source, distributed, wide-column
> store, NoSQL database management system designed to handle large amounts of data across many commodity servers,
> providing high availability with no single point of failure. Cassandra offers support for clusters spanning
> providing high availability with no single point of failure. Cassandra offers support for clusters spanning
> multiple datacenters, with asynchronous masterless replication allowing low latency operations for all clients.
> Cassandra was designed to implement a combination of _Amazon's Dynamo_ distributed storage and replication
> Cassandra was designed to implement a combination of _Amazon's Dynamo_ distributed storage and replication
> techniques combined with _Google's Bigtable_ data and storage engine model.
## Installation and Setup
@ -16,6 +16,16 @@ pip install cassio
## Vector Store
See a [usage example](/docs/modules/data_connection/vectorstores/integrations/cassandra.html).
```python
from langchain.memory import CassandraChatMessageHistory
```
## Memory
See a [usage example](/docs/modules/memory/integrations/cassandra_chat_message_history.html).

@ -23,7 +23,7 @@
},
"outputs": [],
"source": [
"!pip install \"cassio>=0.0.5\""
"!pip install \"cassio>=0.0.7\""
]
},
{
@ -44,14 +44,16 @@
"import os\n",
"import getpass\n",
"\n",
"database_mode = (input('\\n(L)ocal Cassandra or (A)stra DB? ')).upper()\n",
"database_mode = (input('\\n(C)assandra or (A)stra DB? ')).upper()\n",
"\n",
"keyspace_name = input('\\nKeyspace name? ')\n",
"\n",
"if database_mode == 'A':\n",
" ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
" #\n",
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')"
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')\n",
"elif database_mode == 'C':\n",
" CASSANDRA_CONTACT_POINTS = input('Contact points? (comma-separated, empty for localhost) ').strip()"
]
},
{
@ -72,8 +74,15 @@
"from cassandra.cluster import Cluster\n",
"from cassandra.auth import PlainTextAuthProvider\n",
"\n",
"if database_mode == 'L':\n",
" cluster = Cluster()\n",
"if database_mode == 'C':\n",
" if CASSANDRA_CONTACT_POINTS:\n",
" cluster = Cluster([\n",
" cp.strip()\n",
" for cp in CASSANDRA_CONTACT_POINTS.split(',')\n",
" if cp.strip()\n",
" ])\n",
" else:\n",
" cluster = Cluster()\n",
" session = cluster.connect()\n",
"elif database_mode == 'A':\n",
" ASTRA_DB_CLIENT_ID = \"token\"\n",
@ -261,7 +270,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.6"
}
},
"nbformat": 4,

@ -1,8 +1,8 @@
"""Wrapper around Cassandra vector-store capabilities, based on cassIO."""
from __future__ import annotations
import hashlib
import typing
import uuid
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar
import numpy as np
@ -17,14 +17,6 @@ from langchain.vectorstores.utils import maximal_marginal_relevance
CVST = TypeVar("CVST", bound="Cassandra")
# a positive number of seconds to expire entries, or None for no expiration.
CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS = None
def _hash(_input: str) -> str:
"""Use a deterministic hashing approach."""
return hashlib.md5(_input.encode()).hexdigest()
class Cassandra(VectorStore):
"""Wrapper around Cassandra embeddings platform.
@ -46,7 +38,7 @@ class Cassandra(VectorStore):
_embedding_dimension: int | None
def _getEmbeddingDimension(self) -> int:
def _get_embedding_dimension(self) -> int:
if self._embedding_dimension is None:
self._embedding_dimension = len(
self.embedding.embed_query("This is a sample sentence.")
@ -59,7 +51,7 @@ class Cassandra(VectorStore):
session: Session,
keyspace: str,
table_name: str,
ttl_seconds: int | None = CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS,
ttl_seconds: Optional[int] = None,
) -> None:
try:
from cassio.vector import VectorTable
@ -81,8 +73,8 @@ class Cassandra(VectorStore):
session=session,
keyspace=keyspace,
table=table_name,
embedding_dimension=self._getEmbeddingDimension(),
auto_id=False, # the `add_texts` contract admits user-provided ids
embedding_dimension=self._get_embedding_dimension(),
primary_key_type="TEXT",
)
def delete_collection(self) -> None:
@ -99,11 +91,27 @@ class Cassandra(VectorStore):
def delete_by_document_id(self, document_id: str) -> None:
    """Delete the single entry identified by ``document_id``.

    Delegates to ``self.table.delete`` and passes its result through
    unchanged.

    Args:
        document_id: Identifier of the entry to remove.
    """
    return self.table.delete(document_id)
def delete(self, ids: List[str]) -> Optional[bool]:
    """Delete by vector ID.

    Each id is handed, in order, to ``delete_by_document_id``.

    Args:
        ids: List of ids to delete.

    Returns:
        Optional[bool]: True once every id has been processed. This
        implementation never returns False or None; a failure would
        surface as an exception raised by ``delete_by_document_id``.
    """
    for document_id in ids:
        self.delete_by_document_id(document_id)
    return True
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
@ -112,33 +120,39 @@ class Cassandra(VectorStore):
texts (Iterable[str]): Texts to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of IDs.
batch_size (int): Number of concurrent requests to send to the server.
ttl_seconds (Optional[int], optional): Optional time-to-live
for the added texts.
Returns:
List[str]: List of IDs of the added texts.
"""
_texts = list(texts) # lest it be a generator or something
if ids is None:
# unless otherwise specified, we have deterministic IDs:
# re-inserting an existing document will not create a duplicate.
# (and effectively update the metadata)
ids = [_hash(text) for text in _texts]
ids = [uuid.uuid4().hex for _ in _texts]
if metadatas is None:
metadatas = [{} for _ in _texts]
#
ttl_seconds = kwargs.get("ttl_seconds", self.ttl_seconds)
ttl_seconds = ttl_seconds or self.ttl_seconds
#
embedding_vectors = self.embedding.embed_documents(_texts)
for text, embedding_vector, text_id, metadata in zip(
_texts, embedding_vectors, ids, metadatas
):
self.table.put(
document=text,
embedding_vector=embedding_vector,
document_id=text_id,
metadata=metadata,
ttl_seconds=ttl_seconds,
)
#
for i in range(0, len(_texts), batch_size):
batch_texts = _texts[i : i + batch_size]
batch_embedding_vectors = embedding_vectors[i : i + batch_size]
batch_ids = ids[i : i + batch_size]
batch_metadatas = metadatas[i : i + batch_size]
futures = [
self.table.put_async(
text, embedding_vector, text_id, metadata, ttl_seconds
)
for text, embedding_vector, text_id, metadata in zip(
batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
)
]
for future in futures:
future.result()
return ids
# id-returning search facilities
@ -181,7 +195,6 @@ class Cassandra(VectorStore):
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float, str]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector(
@ -219,12 +232,10 @@ class Cassandra(VectorStore):
k: int = 4,
**kwargs: Any,
) -> List[Document]:
#
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_by_vector(
embedding_vector,
k,
**kwargs,
)
def similarity_search_by_vector(
@ -245,7 +256,6 @@ class Cassandra(VectorStore):
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_by_vector(
@ -266,7 +276,6 @@ class Cassandra(VectorStore):
return self.similarity_search_with_score(
query,
k,
**kwargs,
)
def max_marginal_relevance_search_by_vector(
@ -352,6 +361,7 @@ class Cassandra(VectorStore):
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
batch_size: int = 16,
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from raw texts.
@ -378,6 +388,7 @@ class Cassandra(VectorStore):
cls: Type[CVST],
documents: List[Document],
embedding: Embeddings,
batch_size: int = 16,
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from a document list.

552
poetry.lock generated

File diff suppressed because it is too large Load Diff

@ -113,7 +113,7 @@ esprima = {version = "^4.0.1", optional = true}
openllm = {version = ">=0.1.19", optional = true}
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
psychicapi = {version = "^0.8.0", optional = true}
cassio = {version = "^0.0.6", optional = true}
cassio = {version = "^0.0.7", optional = true}
[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"
@ -188,7 +188,7 @@ gptcache = "^0.1.9"
promptlayer = "^0.1.80"
tair = "^1.3.3"
wikipedia = "^1"
cassio = "^0.0.6"
cassio = "^0.0.7"
arxiv = "^1.4"
mastodon-py = "^1.8.1"
momento = "^1.5.0"

@ -84,7 +84,7 @@ def test_cassandra_max_marginal_relevance_search() -> None:
With fetch_k==3 and k==2, when query is at (1, ),
one expects that v2 and v0 are returned (in some order).
"""
texts = ["-0.125", "+0.125", "+0.25", "+1.0"]
texts = ["-0.124", "+0.127", "+0.25", "+1.0"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vectorstore_from_texts(
texts, metadatas=metadatas, embedding_class=AngularTwoDimensionalEmbeddings
@ -95,7 +95,7 @@ def test_cassandra_max_marginal_relevance_search() -> None:
}
assert output_set == {
("+0.25", 2),
("-0.125", 0),
("-0.124", 0),
}
@ -105,9 +105,9 @@ def test_cassandra_add_extra() -> None:
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
docsearch.add_texts(texts, metadatas)
texts2 = ["foo2", "bar2", "baz2"]
docsearch.add_texts(texts2, metadatas)
metadatas2 = [{"page": i + 3} for i in range(len(texts))]
docsearch.add_texts(texts2, metadatas2)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
@ -127,9 +127,37 @@ def test_cassandra_no_drop() -> None:
assert len(output) == 6
def test_cassandra_delete() -> None:
    """Test delete methods from vector store."""
    texts = ["foo", "bar", "baz", "gni"]
    metadatas = [{"page": i} for i in range(len(texts))]
    # No texts are passed at creation; the four rows are inserted explicitly
    # so that the ids returned by add_texts can drive the deletions below.
    docsearch = _vectorstore_from_texts([], metadatas=metadatas)
    ids = docsearch.add_texts(texts, metadatas)
    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 4

    # Single-id deletion removes exactly one row.
    docsearch.delete_by_document_id(ids[0])
    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 3

    # Multi-id deletion: ids[1:3] covers two rows, leaving one.
    docsearch.delete(ids[1:3])
    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 1

    # Deleting an unknown id must leave the remaining row untouched.
    docsearch.delete(["not-existing"])
    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 1

    # clear() is expected to leave no rows behind.
    docsearch.clear()
    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 0
# if __name__ == "__main__":
# test_cassandra()
# test_cassandra_with_score()
# test_cassandra_max_marginal_relevance_search()
# test_cassandra_add_extra()
# test_cassandra_no_drop()
# test_cassandra_delete()

Loading…
Cancel
Save