community: Cassandra Vector Store: extend metadata-related methods (#27078)

**Description:** this PR adds a set of methods for working with the metadata
associated with the vector store entries. While essential to the
Graph-related extension of the `Cassandra` vector store, these are also
useful in their own right. The new methods (each in a sync and an async
version) are:

- `[a]delete_by_metadata_filter`
- `[a]replace_metadata`
- `[a]get_by_document_id`
- `[a]metadata_search`

Additionally, a `[a]similarity_search_with_embedding_id_by_vector`
method is introduced to better serve the store's internal workings
(especially the reranking logic).
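
For illustration, a minimal usage sketch of the new sync methods (hedged: `vstore` stands for an already-initialized `Cassandra` vector store, and all IDs/metadata values are placeholders):

```python
from langchain_core.documents import Document

# `vstore` is assumed to be an existing, initialized Cassandra vector store.
vstore.add_documents(
    [Document(id="doc_1", page_content="foo", metadata={"topic": "a"})]
)

# Retrieve a single entry by its ID (returns None if absent).
doc = vstore.get_by_document_id("doc_1")

# Fetch up to n documents matching a metadata filter; no vectors involved.
docs = vstore.metadata_search(metadata={"topic": "a"}, n=5)

# Completely replace the metadata of selected documents ({} erases it).
vstore.replace_metadata({"doc_1": {"topic": "b"}})

# Delete everything matching a (non-empty) metadata filter;
# the return value is the number of deleted documents.
num_deleted = vstore.delete_by_metadata_filter({"topic": "b"})
```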

**Issue:** no issue number, but all returned `Document`s now consistently
bear their `.id` (a consequence of a slight refactoring of how the raw
entries read from the DB are turned back into `Document` instances).
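
A quick illustration (same hypothetical `vstore` as above):

```python
# Every hit now carries its .id, so results can be fed straight back
# into get_by_document_id / replace_metadata round-trips.
hits = vstore.similarity_search("foo", k=3)
assert all(hit.id is not None for hit in hits)
```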

**Dependencies:** no new dependencies: `packaging` already comes through
langchain-core; `cassio` is now required at version 0.1.10 or later.


**Add tests and docs**
Added integration tests for the newly introduced methods.
(Docs will be updated in a separate PR).

**Lint and test** Lint and the (updated) tests all pass.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Commit d05fdd97dd (parent 84c05b031d), authored by Stefano Lottini on
2024-10-09 08:41:34 +02:00, committed by GitHub.
2 changed files with 396 additions and 18 deletions.

File: libs/community/langchain_community/vectorstores/cassandra.py

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import importlib.metadata
import typing
import uuid
from typing import (
@@ -18,6 +19,7 @@ from typing import (
)
import numpy as np
from packaging.version import Version  # this is a langchain-core dependency
if typing.TYPE_CHECKING:
from cassandra.cluster import Session
@@ -30,6 +32,7 @@ from langchain_community.utilities.cassandra import SetupMode
from langchain_community.vectorstores.utils import maximal_marginal_relevance
CVST = TypeVar("CVST", bound="Cassandra")
MIN_CASSIO_VERSION = Version("0.1.10")
class Cassandra(VectorStore):
@@ -110,6 +113,15 @@ class Cassandra(VectorStore):
"Could not import cassio python package. "
"Please install it with `pip install cassio`."
)
cassio_version = Version(importlib.metadata.version("cassio"))
if cassio_version < MIN_CASSIO_VERSION:
msg = (
"Cassio version not supported. Please upgrade cassio "
f"to version {MIN_CASSIO_VERSION} or higher."
)
raise ImportError(msg)
if not table_name:
raise ValueError("Missing required parameter 'table_name'.")
self.embedding = embedding
@@ -143,6 +155,9 @@ class Cassandra(VectorStore):
**kwargs,
)
if self.session is None:
self.session = self.table.session
@property
def embeddings(self) -> Embeddings:
return self.embedding
@@ -231,6 +246,70 @@ class Cassandra(VectorStore):
await self.adelete_by_document_id(document_id)
return True
def delete_by_metadata_filter(
self,
filter: dict[str, Any],
*,
batch_size: int = 50,
) -> int:
"""Delete all documents matching a certain metadata filtering condition.
This operation does not use the vector embeddings in any way; it simply
removes all documents whose metadata match the provided condition.
Args:
filter: Filter on the metadata to apply. The filter cannot be empty.
batch_size: number of deletions per batch (repeated until the
matching documents are exhausted).
Returns:
The number of deleted documents.
"""
if not filter:
msg = (
"Method `delete_by_metadata_filter` does not accept an empty "
"filter. Use the `clear()` method if you really want to empty "
"the vector store."
)
raise ValueError(msg)
return self.table.find_and_delete_entries(
metadata=filter,
batch_size=batch_size,
)
async def adelete_by_metadata_filter(
self,
filter: dict[str, Any],
*,
batch_size: int = 50,
) -> int:
"""Delete all documents matching a certain metadata filtering condition.
This operation does not use the vector embeddings in any way; it simply
removes all documents whose metadata match the provided condition.
Args:
filter: Filter on the metadata to apply. The filter cannot be empty.
batch_size: number of deletions per batch (repeated until the
matching documents are exhausted).
Returns:
The number of deleted documents.
"""
if not filter:
msg = (
"Method `delete_by_metadata_filter` does not accept an empty "
"filter. Use the `clear()` method if you really want to empty "
"the vector store."
)
raise ValueError(msg)
return await self.table.afind_and_delete_entries(
metadata=filter,
batch_size=batch_size,
)
def add_texts(
self,
texts: Iterable[str],
@@ -333,6 +412,180 @@ class Cassandra(VectorStore):
await asyncio.gather(*tasks)
return ids
def replace_metadata(
self,
id_to_metadata: dict[str, dict],
*,
batch_size: int = 50,
) -> None:
"""Replace the metadata of documents.
For each document to update, identified by its ID, the new metadata
dictionary completely replaces what is in the store. This includes
passing empty metadata `{}` to erase the currently-stored information.
Args:
id_to_metadata: map from the IDs of the documents to modify to
their new metadata.
Keys in this dictionary that do not correspond to an existing
document will not cause an error; rather, they result in new
rows being written into the Cassandra table, but without an
associated vector: hence unreachable through vector search.
batch_size: number of concurrent write requests sent to the
server per batch.
Returns:
None if the writes succeed (otherwise an error is raised).
"""
ids_and_metadatas = list(id_to_metadata.items())
for i in range(0, len(ids_and_metadatas), batch_size):
batch_i_m = ids_and_metadatas[i : i + batch_size]
futures = [
self.table.put_async(
row_id=doc_id,
metadata=doc_md,
)
for doc_id, doc_md in batch_i_m
]
for future in futures:
future.result()
return
async def areplace_metadata(
self,
id_to_metadata: dict[str, dict],
*,
concurrency: int = 50,
) -> None:
"""Replace the metadata of documents.
For each document to update, identified by its ID, the new metadata
dictionary completely replaces what is in the store. This includes
passing empty metadata `{}` to erase the currently-stored information.
Args:
id_to_metadata: map from the IDs of the documents to modify to
their new metadata.
Keys in this dictionary that do not correspond to an existing
document will not cause an error; rather, they result in new
rows being written into the Cassandra table, but without an
associated vector: hence unreachable through vector search.
concurrency: Number of concurrent queries to the database.
Defaults to 50.
Returns:
None if the writes succeed (otherwise an error is raised).
"""
ids_and_metadatas = list(id_to_metadata.items())
sem = asyncio.Semaphore(concurrency)
async def send_concurrently(doc_id: str, doc_md: dict) -> None:
async with sem:
await self.table.aput(
row_id=doc_id,
metadata=doc_md,
)
# schedule all writes at once; the semaphore caps the actual concurrency
tasks = [
asyncio.create_task(send_concurrently(doc_id, doc_md))
for doc_id, doc_md in ids_and_metadatas
]
await asyncio.gather(*tasks)
return
@staticmethod
def _row_to_document(row: Dict[str, Any]) -> Document:
return Document(
id=row["row_id"],
page_content=row["body_blob"],
metadata=row["metadata"],
)
def get_by_document_id(self, document_id: str) -> Document | None:
"""Get by document ID.
Args:
document_id: the document ID to get.
"""
row = self.table.get(row_id=document_id)
if row is None:
return None
return self._row_to_document(row=row)
async def aget_by_document_id(self, document_id: str) -> Document | None:
"""Get by document ID.
Args:
document_id: the document ID to get.
"""
row = await self.table.aget(row_id=document_id)
if row is None:
return None
return self._row_to_document(row=row)
def metadata_search(
self,
metadata: dict[str, Any] = {}, # noqa: B006
n: int = 5,
) -> Iterable[Document]:
"""Get documents via a metadata search.
Args:
metadata: the metadata to query for.
n: the maximum number of documents to return.
"""
rows = self.table.find_entries(metadata=metadata, n=n)
return [self._row_to_document(row=row) for row in rows if row]
async def ametadata_search(
self,
metadata: dict[str, Any] = {}, # noqa: B006
n: int = 5,
) -> Iterable[Document]:
"""Get documents via a metadata search.
Args:
metadata: the metadata to query for.
n: the maximum number of documents to return.
"""
rows = await self.table.afind_entries(metadata=metadata, n=n)
return [self._row_to_document(row=row) for row in rows]
async def asimilarity_search_with_embedding_id_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
body_search: Optional[Union[str, List[str]]] = None,
) -> List[Tuple[Document, List[float], str]]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
body_search: Document textual search terms to apply.
Only supported by Astra DB at the moment.
Returns:
List of (Document, embedding, id), the most similar to the query vector.
"""
kwargs: Dict[str, Any] = {}
if filter is not None:
kwargs["metadata"] = filter
if body_search is not None:
kwargs["body_search"] = body_search
hits = await self.table.aann_search(
vector=embedding,
n=k,
**kwargs,
)
return [
(
self._row_to_document(row=hit),
hit["vector"],
hit["row_id"],
)
for hit in hits
]
@staticmethod
def _search_to_documents(
hits: Iterable[Dict[str, Any]],
@@ -341,10 +594,7 @@ class Cassandra(VectorStore):
# (1=most relevant), as required by this class' contract.
return [
(
Document(
page_content=hit["body_blob"],
metadata=hit["metadata"],
),
Cassandra._row_to_document(row=hit),
0.5 + 0.5 * hit["distance"],
hit["row_id"],
)
@@ -375,7 +625,6 @@ class Cassandra(VectorStore):
kwargs["metadata"] = filter
if body_search is not None:
kwargs["body_search"] = body_search
hits = self.table.metric_ann_search(
vector=embedding,
n=k,
@@ -712,13 +961,7 @@ class Cassandra(VectorStore):
for pf_index, pf_hit in enumerate(prefetch_hits)
if pf_index in mmr_chosen_indices
]
return [
Document(
page_content=hit["body_blob"],
metadata=hit["metadata"],
)
for hit in mmr_hits
]
return [Cassandra._row_to_document(row=hit) for hit in mmr_hits]
def max_marginal_relevance_search_by_vector(
self,

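As a hedged sketch of the async variants shown in the diff above (illustrative only: `vstore` is an existing `Cassandra` store and `query_vector` a precomputed embedding vector):

```python
import asyncio


async def demo(vstore, query_vector):
    # Async metadata maintenance, mirroring the sync methods.
    await vstore.areplace_metadata({"doc_1": {"topic": "b"}}, concurrency=10)
    n_deleted = await vstore.adelete_by_metadata_filter({"topic": "b"}, batch_size=20)
    print(f"deleted {n_deleted} documents")

    # Internal-facing ANN search returning (Document, embedding, id) triples,
    # e.g. to feed a downstream reranking step.
    triples = await vstore.asimilarity_search_with_embedding_id_by_vector(
        embedding=query_vector, k=4
    )
    for hit_doc, _hit_embedding, hit_id in triples:
        assert hit_doc.id == hit_id  # the Document and the bare id agree


# e.g.: asyncio.run(demo(vstore, query_vector))
```
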
File: libs/community/tests/integration_tests/vectorstores/test_cassandra.py

@@ -17,6 +17,17 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
)
def _strip_docs(documents: List[Document]) -> List[Document]:
return [_strip_doc(doc) for doc in documents]
def _strip_doc(document: Document) -> Document:
return Document(
page_content=document.page_content,
metadata=document.metadata,
)
def _vectorstore_from_texts(
texts: List[str],
metadatas: Optional[List[dict]] = None,
@@ -110,9 +121,9 @@ async def test_cassandra() -> None:
texts = ["foo", "bar", "baz"]
docsearch = _vectorstore_from_texts(texts)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])
output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])
async def test_cassandra_with_score() -> None:
@@ -130,13 +141,13 @@ async def test_cassandra_with_score() -> None:
output = docsearch.similarity_search_with_score("foo", k=3)
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert docs == expected_docs
assert _strip_docs(docs) == _strip_docs(expected_docs)
assert scores[0] > scores[1] > scores[2]
output = await docsearch.asimilarity_search_with_score("foo", k=3)
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert docs == expected_docs
assert _strip_docs(docs) == _strip_docs(expected_docs)
assert scores[0] > scores[1] > scores[2]
@@ -239,7 +250,7 @@ async def test_cassandra_no_drop_async() -> None:
def test_cassandra_delete() -> None:
"""Test delete methods from vector store."""
texts = ["foo", "bar", "baz", "gni"]
metadatas = [{"page": i} for i in range(len(texts))]
metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
docsearch = _vectorstore_from_texts([], metadatas=metadatas)
ids = docsearch.add_texts(texts, metadatas)
@@ -263,11 +274,21 @@ def test_cassandra_delete() -> None:
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 0
docsearch.add_texts(texts, metadatas)
num_deleted = docsearch.delete_by_metadata_filter({"mod2": 0}, batch_size=1)
assert num_deleted == 2
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 2
docsearch.clear()
with pytest.raises(ValueError):
docsearch.delete_by_metadata_filter({})
async def test_cassandra_adelete() -> None:
"""Test delete methods from vector store."""
texts = ["foo", "bar", "baz", "gni"]
metadatas = [{"page": i} for i in range(len(texts))]
metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
docsearch = await _vectorstore_from_texts_async([], metadatas=metadatas)
ids = await docsearch.aadd_texts(texts, metadatas)
@@ -291,6 +312,16 @@ async def test_cassandra_adelete() -> None:
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 0
await docsearch.aadd_texts(texts, metadatas)
num_deleted = await docsearch.adelete_by_metadata_filter({"mod2": 0}, batch_size=1)
assert num_deleted == 2
output = await docsearch.asimilarity_search("foo", k=10)
assert len(output) == 2
await docsearch.aclear()
with pytest.raises(ValueError):
await docsearch.adelete_by_metadata_filter({})
def test_cassandra_metadata_indexing() -> None:
"""Test comparing metadata indexing policies."""
@@ -316,3 +347,107 @@ def test_cassandra_metadata_indexing() -> None:
with pytest.raises(ValueError):
# "Non-indexed metadata fields cannot be used in queries."
vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2)
def test_cassandra_replace_metadata() -> None:
"""Test of replacing metadata."""
N_DOCS = 100
REPLACE_RATIO = 2  # one in REPLACE_RATIO docs will have its metadata replaced
BATCH_SIZE = 3
vstore_f1 = _vectorstore_from_texts(
texts=[],
metadata_indexing=("allowlist", ["field1", "field2"]),
table_name="vector_test_table_indexing",
)
orig_documents = [
Document(
page_content=f"doc_{doc_i}",
id=f"doc_id_{doc_i}",
metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
)
for doc_i in range(N_DOCS)
]
vstore_f1.add_documents(orig_documents)
ids_to_replace = [
f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
]
# various kinds of replacement at play here:
def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
if mode == 0:
return {}
elif mode == 1:
return {"field2": f"NEW_{doc_id}"}
elif mode == 2:
return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
else:
return {"ofherf2": "post"}
ids_to_new_md = {
doc_id: _make_new_md(rep_i % 4, doc_id)
for rep_i, doc_id in enumerate(ids_to_replace)
}
vstore_f1.replace_metadata(ids_to_new_md, batch_size=BATCH_SIZE)
# thorough check
expected_id_to_metadata: dict[str, dict] = {
**{(document.id or ""): document.metadata for document in orig_documents},
**ids_to_new_md,
}
for hit in vstore_f1.similarity_search("doc", k=N_DOCS + 1):
assert hit.id is not None
assert hit.metadata == expected_id_to_metadata[hit.id]
async def test_cassandra_areplace_metadata() -> None:
"""Test of replacing metadata."""
N_DOCS = 100
REPLACE_RATIO = 2  # one in REPLACE_RATIO docs will have its metadata replaced
BATCH_SIZE = 3
vstore_f1 = _vectorstore_from_texts(
texts=[],
metadata_indexing=("allowlist", ["field1", "field2"]),
table_name="vector_test_table_indexing",
)
orig_documents = [
Document(
page_content=f"doc_{doc_i}",
id=f"doc_id_{doc_i}",
metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
)
for doc_i in range(N_DOCS)
]
await vstore_f1.aadd_documents(orig_documents)
ids_to_replace = [
f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
]
# various kinds of replacement at play here:
def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
if mode == 0:
return {}
elif mode == 1:
return {"field2": f"NEW_{doc_id}"}
elif mode == 2:
return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
else:
return {"ofherf2": "post"}
ids_to_new_md = {
doc_id: _make_new_md(rep_i % 4, doc_id)
for rep_i, doc_id in enumerate(ids_to_replace)
}
await vstore_f1.areplace_metadata(ids_to_new_md, concurrency=BATCH_SIZE)
# thorough check
expected_id_to_metadata: dict[str, dict] = {
**{(document.id or ""): document.metadata for document in orig_documents},
**ids_to_new_md,
}
for hit in await vstore_f1.asimilarity_search("doc", k=N_DOCS + 1):
assert hit.id is not None
assert hit.metadata == expected_id_to_metadata[hit.id]