mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
community: Cassandra Vector Store: extend metadata-related methods (#27078)
**Description:** this PR adds a set of methods to deal with metadata associated to the vector store entries. These, while essential to the Graph-related extension of the `Cassandra` vector store, are also useful in themselves. These are (all come in their sync+async versions): - `[a]delete_by_metadata_filter` - `[a]replace_metadata` - `[a]get_by_document_id` - `[a]metadata_search` Additionally, a `[a]similarity_search_with_embedding_id_by_vector` method is introduced to better serve the store's internal working (esp. related to reranking logic). **Issue:** no issue number, but now all Document's returned bear their `.id` consistently (as a consequence of a slight refactoring in how the raw entries read from DB are made back into `Document` instances). **Dependencies:** (no new deps: packaging comes through langchain-core already; `cassio` is now required to be version 0.1.10+) **Add tests and docs** Added integration tests for the relevant newly-introduced methods. (Docs will be updated in a separate PR). **Lint and test** Lint and (updated) test all pass. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
84c05b031d
commit
d05fdd97dd
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib.metadata
|
||||
import typing
|
||||
import uuid
|
||||
from typing import (
|
||||
@ -18,6 +19,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from packaging.version import Version # this is a lancghain-core dependency
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from cassandra.cluster import Session
|
||||
@ -30,6 +32,7 @@ from langchain_community.utilities.cassandra import SetupMode
|
||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
CVST = TypeVar("CVST", bound="Cassandra")
|
||||
MIN_CASSIO_VERSION = Version("0.1.10")
|
||||
|
||||
|
||||
class Cassandra(VectorStore):
|
||||
@ -110,6 +113,15 @@ class Cassandra(VectorStore):
|
||||
"Could not import cassio python package. "
|
||||
"Please install it with `pip install cassio`."
|
||||
)
|
||||
cassio_version = Version(importlib.metadata.version("cassio"))
|
||||
|
||||
if cassio_version is not None and cassio_version < MIN_CASSIO_VERSION:
|
||||
msg = (
|
||||
"Cassio version not supported. Please upgrade cassio "
|
||||
f"to version {MIN_CASSIO_VERSION} or higher."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
if not table_name:
|
||||
raise ValueError("Missing required parameter 'table_name'.")
|
||||
self.embedding = embedding
|
||||
@ -143,6 +155,9 @@ class Cassandra(VectorStore):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.session is None:
|
||||
self.session = self.table.session
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding
|
||||
@ -231,6 +246,70 @@ class Cassandra(VectorStore):
|
||||
await self.adelete_by_document_id(document_id)
|
||||
return True
|
||||
|
||||
def delete_by_metadata_filter(
|
||||
self,
|
||||
filter: dict[str, Any],
|
||||
*,
|
||||
batch_size: int = 50,
|
||||
) -> int:
|
||||
"""Delete all documents matching a certain metadata filtering condition.
|
||||
|
||||
This operation does not use the vector embeddings in any way, it simply
|
||||
removes all documents whose metadata match the provided condition.
|
||||
|
||||
Args:
|
||||
filter: Filter on the metadata to apply. The filter cannot be empty.
|
||||
batch_size: amount of deletions per each batch (until exhaustion of
|
||||
the matching documents).
|
||||
|
||||
Returns:
|
||||
A number expressing the amount of deleted documents.
|
||||
"""
|
||||
if not filter:
|
||||
msg = (
|
||||
"Method `delete_by_metadata_filter` does not accept an empty "
|
||||
"filter. Use the `clear()` method if you really want to empty "
|
||||
"the vector store."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
return self.table.find_and_delete_entries(
|
||||
metadata=filter,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
async def adelete_by_metadata_filter(
|
||||
self,
|
||||
filter: dict[str, Any],
|
||||
*,
|
||||
batch_size: int = 50,
|
||||
) -> int:
|
||||
"""Delete all documents matching a certain metadata filtering condition.
|
||||
|
||||
This operation does not use the vector embeddings in any way, it simply
|
||||
removes all documents whose metadata match the provided condition.
|
||||
|
||||
Args:
|
||||
filter: Filter on the metadata to apply. The filter cannot be empty.
|
||||
batch_size: amount of deletions per each batch (until exhaustion of
|
||||
the matching documents).
|
||||
|
||||
Returns:
|
||||
A number expressing the amount of deleted documents.
|
||||
"""
|
||||
if not filter:
|
||||
msg = (
|
||||
"Method `delete_by_metadata_filter` does not accept an empty "
|
||||
"filter. Use the `clear()` method if you really want to empty "
|
||||
"the vector store."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
return await self.table.afind_and_delete_entries(
|
||||
metadata=filter,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@ -333,6 +412,180 @@ class Cassandra(VectorStore):
|
||||
await asyncio.gather(*tasks)
|
||||
return ids
|
||||
|
||||
def replace_metadata(
|
||||
self,
|
||||
id_to_metadata: dict[str, dict],
|
||||
*,
|
||||
batch_size: int = 50,
|
||||
) -> None:
|
||||
"""Replace the metadata of documents.
|
||||
|
||||
For each document to update, identified by its ID, the new metadata
|
||||
dictionary completely replaces what is on the store. This includes
|
||||
passing empty metadata `{}` to erase the currently-stored information.
|
||||
|
||||
Args:
|
||||
id_to_metadata: map from the Document IDs to modify to the
|
||||
new metadata for updating.
|
||||
Keys in this dictionary that do not correspond to an existing
|
||||
document will not cause an error, rather will result in new
|
||||
rows being written into the Cassandra table but without an
|
||||
associated vector: hence unreachable through vector search.
|
||||
batch_size: Number of concurrent requests to send to the server.
|
||||
|
||||
Returns:
|
||||
None if the writes succeed (otherwise an error is raised).
|
||||
"""
|
||||
ids_and_metadatas = list(id_to_metadata.items())
|
||||
for i in range(0, len(ids_and_metadatas), batch_size):
|
||||
batch_i_m = ids_and_metadatas[i : i + batch_size]
|
||||
futures = [
|
||||
self.table.put_async(
|
||||
row_id=doc_id,
|
||||
metadata=doc_md,
|
||||
)
|
||||
for doc_id, doc_md in batch_i_m
|
||||
]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return
|
||||
|
||||
async def areplace_metadata(
|
||||
self,
|
||||
id_to_metadata: dict[str, dict],
|
||||
*,
|
||||
concurrency: int = 50,
|
||||
) -> None:
|
||||
"""Replace the metadata of documents.
|
||||
|
||||
For each document to update, identified by its ID, the new metadata
|
||||
dictionary completely replaces what is on the store. This includes
|
||||
passing empty metadata `{}` to erase the currently-stored information.
|
||||
|
||||
Args:
|
||||
id_to_metadata: map from the Document IDs to modify to the
|
||||
new metadata for updating.
|
||||
Keys in this dictionary that do not correspond to an existing
|
||||
document will not cause an error, rather will result in new
|
||||
rows being written into the Cassandra table but without an
|
||||
associated vector: hence unreachable through vector search.
|
||||
concurrency: Number of concurrent queries to the database.
|
||||
Defaults to 50.
|
||||
|
||||
Returns:
|
||||
None if the writes succeed (otherwise an error is raised).
|
||||
"""
|
||||
ids_and_metadatas = list(id_to_metadata.items())
|
||||
|
||||
sem = asyncio.Semaphore(concurrency)
|
||||
|
||||
async def send_concurrently(doc_id: str, doc_md: dict) -> None:
|
||||
async with sem:
|
||||
await self.table.aput(
|
||||
row_id=doc_id,
|
||||
metadata=doc_md,
|
||||
)
|
||||
|
||||
for doc_id, doc_md in ids_and_metadatas:
|
||||
tasks = [asyncio.create_task(send_concurrently(doc_id, doc_md))]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _row_to_document(row: Dict[str, Any]) -> Document:
|
||||
return Document(
|
||||
id=row["row_id"],
|
||||
page_content=row["body_blob"],
|
||||
metadata=row["metadata"],
|
||||
)
|
||||
|
||||
def get_by_document_id(self, document_id: str) -> Document | None:
|
||||
"""Get by document ID.
|
||||
|
||||
Args:
|
||||
document_id: the document ID to get.
|
||||
"""
|
||||
row = self.table.get(row_id=document_id)
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_document(row=row)
|
||||
|
||||
async def aget_by_document_id(self, document_id: str) -> Document | None:
|
||||
"""Get by document ID.
|
||||
|
||||
Args:
|
||||
document_id: the document ID to get.
|
||||
"""
|
||||
row = await self.table.aget(row_id=document_id)
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_document(row=row)
|
||||
|
||||
def metadata_search(
|
||||
self,
|
||||
metadata: dict[str, Any] = {}, # noqa: B006
|
||||
n: int = 5,
|
||||
) -> Iterable[Document]:
|
||||
"""Get documents via a metadata search.
|
||||
|
||||
Args:
|
||||
metadata: the metadata to query for.
|
||||
"""
|
||||
rows = self.table.find_entries(metadata=metadata, n=n)
|
||||
return [self._row_to_document(row=row) for row in rows if row]
|
||||
|
||||
async def ametadata_search(
|
||||
self,
|
||||
metadata: dict[str, Any] = {}, # noqa: B006
|
||||
n: int = 5,
|
||||
) -> Iterable[Document]:
|
||||
"""Get documents via a metadata search.
|
||||
|
||||
Args:
|
||||
metadata: the metadata to query for.
|
||||
"""
|
||||
rows = await self.table.afind_entries(metadata=metadata, n=n)
|
||||
return [self._row_to_document(row=row) for row in rows]
|
||||
|
||||
async def asimilarity_search_with_embedding_id_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
body_search: Optional[Union[str, List[str]]] = None,
|
||||
) -> List[Tuple[Document, List[float], str]]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: Filter on the metadata to apply.
|
||||
body_search: Document textual search terms to apply.
|
||||
Only supported by Astra DB at the moment.
|
||||
Returns:
|
||||
List of (Document, embedding, id), the most similar to the query vector.
|
||||
"""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if filter is not None:
|
||||
kwargs["metadata"] = filter
|
||||
if body_search is not None:
|
||||
kwargs["body_search"] = body_search
|
||||
|
||||
hits = await self.table.aann_search(
|
||||
vector=embedding,
|
||||
n=k,
|
||||
**kwargs,
|
||||
)
|
||||
return [
|
||||
(
|
||||
self._row_to_document(row=hit),
|
||||
hit["vector"],
|
||||
hit["row_id"],
|
||||
)
|
||||
for hit in hits
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _search_to_documents(
|
||||
hits: Iterable[Dict[str, Any]],
|
||||
@ -341,10 +594,7 @@ class Cassandra(VectorStore):
|
||||
# (1=most relevant), as required by this class' contract.
|
||||
return [
|
||||
(
|
||||
Document(
|
||||
page_content=hit["body_blob"],
|
||||
metadata=hit["metadata"],
|
||||
),
|
||||
Cassandra._row_to_document(row=hit),
|
||||
0.5 + 0.5 * hit["distance"],
|
||||
hit["row_id"],
|
||||
)
|
||||
@ -375,7 +625,6 @@ class Cassandra(VectorStore):
|
||||
kwargs["metadata"] = filter
|
||||
if body_search is not None:
|
||||
kwargs["body_search"] = body_search
|
||||
|
||||
hits = self.table.metric_ann_search(
|
||||
vector=embedding,
|
||||
n=k,
|
||||
@ -712,13 +961,7 @@ class Cassandra(VectorStore):
|
||||
for pf_index, pf_hit in enumerate(prefetch_hits)
|
||||
if pf_index in mmr_chosen_indices
|
||||
]
|
||||
return [
|
||||
Document(
|
||||
page_content=hit["body_blob"],
|
||||
metadata=hit["metadata"],
|
||||
)
|
||||
for hit in mmr_hits
|
||||
]
|
||||
return [Cassandra._row_to_document(row=hit) for hit in mmr_hits]
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
|
@ -17,6 +17,17 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
)
|
||||
|
||||
|
||||
def _strip_docs(documents: List[Document]) -> List[Document]:
|
||||
return [_strip_doc(doc) for doc in documents]
|
||||
|
||||
|
||||
def _strip_doc(document: Document) -> Document:
|
||||
return Document(
|
||||
page_content=document.page_content,
|
||||
metadata=document.metadata,
|
||||
)
|
||||
|
||||
|
||||
def _vectorstore_from_texts(
|
||||
texts: List[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
@ -110,9 +121,9 @@ async def test_cassandra() -> None:
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = _vectorstore_from_texts(texts)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])
|
||||
output = await docsearch.asimilarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])
|
||||
|
||||
|
||||
async def test_cassandra_with_score() -> None:
|
||||
@ -130,13 +141,13 @@ async def test_cassandra_with_score() -> None:
|
||||
output = docsearch.similarity_search_with_score("foo", k=3)
|
||||
docs = [o[0] for o in output]
|
||||
scores = [o[1] for o in output]
|
||||
assert docs == expected_docs
|
||||
assert _strip_docs(docs) == _strip_docs(expected_docs)
|
||||
assert scores[0] > scores[1] > scores[2]
|
||||
|
||||
output = await docsearch.asimilarity_search_with_score("foo", k=3)
|
||||
docs = [o[0] for o in output]
|
||||
scores = [o[1] for o in output]
|
||||
assert docs == expected_docs
|
||||
assert _strip_docs(docs) == _strip_docs(expected_docs)
|
||||
assert scores[0] > scores[1] > scores[2]
|
||||
|
||||
|
||||
@ -239,7 +250,7 @@ async def test_cassandra_no_drop_async() -> None:
|
||||
def test_cassandra_delete() -> None:
|
||||
"""Test delete methods from vector store."""
|
||||
texts = ["foo", "bar", "baz", "gni"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
|
||||
docsearch = _vectorstore_from_texts([], metadatas=metadatas)
|
||||
|
||||
ids = docsearch.add_texts(texts, metadatas)
|
||||
@ -263,11 +274,21 @@ def test_cassandra_delete() -> None:
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 0
|
||||
|
||||
docsearch.add_texts(texts, metadatas)
|
||||
num_deleted = docsearch.delete_by_metadata_filter({"mod2": 0}, batch_size=1)
|
||||
assert num_deleted == 2
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 2
|
||||
docsearch.clear()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
docsearch.delete_by_metadata_filter({})
|
||||
|
||||
|
||||
async def test_cassandra_adelete() -> None:
|
||||
"""Test delete methods from vector store."""
|
||||
texts = ["foo", "bar", "baz", "gni"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
|
||||
docsearch = await _vectorstore_from_texts_async([], metadatas=metadatas)
|
||||
|
||||
ids = await docsearch.aadd_texts(texts, metadatas)
|
||||
@ -291,6 +312,16 @@ async def test_cassandra_adelete() -> None:
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 0
|
||||
|
||||
await docsearch.aadd_texts(texts, metadatas)
|
||||
num_deleted = await docsearch.adelete_by_metadata_filter({"mod2": 0}, batch_size=1)
|
||||
assert num_deleted == 2
|
||||
output = await docsearch.asimilarity_search("foo", k=10)
|
||||
assert len(output) == 2
|
||||
await docsearch.aclear()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await docsearch.adelete_by_metadata_filter({})
|
||||
|
||||
|
||||
def test_cassandra_metadata_indexing() -> None:
|
||||
"""Test comparing metadata indexing policies."""
|
||||
@ -316,3 +347,107 @@ def test_cassandra_metadata_indexing() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
# "Non-indexed metadata fields cannot be used in queries."
|
||||
vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2)
|
||||
|
||||
|
||||
def test_cassandra_replace_metadata() -> None:
|
||||
"""Test of replacing metadata."""
|
||||
N_DOCS = 100
|
||||
REPLACE_RATIO = 2 # one in ... will have replaced metadata
|
||||
BATCH_SIZE = 3
|
||||
|
||||
vstore_f1 = _vectorstore_from_texts(
|
||||
texts=[],
|
||||
metadata_indexing=("allowlist", ["field1", "field2"]),
|
||||
table_name="vector_test_table_indexing",
|
||||
)
|
||||
orig_documents = [
|
||||
Document(
|
||||
page_content=f"doc_{doc_i}",
|
||||
id=f"doc_id_{doc_i}",
|
||||
metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
|
||||
)
|
||||
for doc_i in range(N_DOCS)
|
||||
]
|
||||
vstore_f1.add_documents(orig_documents)
|
||||
|
||||
ids_to_replace = [
|
||||
f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
|
||||
]
|
||||
|
||||
# various kinds of replacement at play here:
|
||||
def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
|
||||
if mode == 0:
|
||||
return {}
|
||||
elif mode == 1:
|
||||
return {"field2": f"NEW_{doc_id}"}
|
||||
elif mode == 2:
|
||||
return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
|
||||
else:
|
||||
return {"ofherf2": "post"}
|
||||
|
||||
ids_to_new_md = {
|
||||
doc_id: _make_new_md(rep_i % 4, doc_id)
|
||||
for rep_i, doc_id in enumerate(ids_to_replace)
|
||||
}
|
||||
|
||||
vstore_f1.replace_metadata(ids_to_new_md, batch_size=BATCH_SIZE)
|
||||
# thorough check
|
||||
expected_id_to_metadata: dict[str, dict] = {
|
||||
**{(document.id or ""): document.metadata for document in orig_documents},
|
||||
**ids_to_new_md,
|
||||
}
|
||||
for hit in vstore_f1.similarity_search("doc", k=N_DOCS + 1):
|
||||
assert hit.id is not None
|
||||
assert hit.metadata == expected_id_to_metadata[hit.id]
|
||||
|
||||
|
||||
async def test_cassandra_areplace_metadata() -> None:
|
||||
"""Test of replacing metadata."""
|
||||
N_DOCS = 100
|
||||
REPLACE_RATIO = 2 # one in ... will have replaced metadata
|
||||
BATCH_SIZE = 3
|
||||
|
||||
vstore_f1 = _vectorstore_from_texts(
|
||||
texts=[],
|
||||
metadata_indexing=("allowlist", ["field1", "field2"]),
|
||||
table_name="vector_test_table_indexing",
|
||||
)
|
||||
orig_documents = [
|
||||
Document(
|
||||
page_content=f"doc_{doc_i}",
|
||||
id=f"doc_id_{doc_i}",
|
||||
metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
|
||||
)
|
||||
for doc_i in range(N_DOCS)
|
||||
]
|
||||
await vstore_f1.aadd_documents(orig_documents)
|
||||
|
||||
ids_to_replace = [
|
||||
f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
|
||||
]
|
||||
|
||||
# various kinds of replacement at play here:
|
||||
def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
|
||||
if mode == 0:
|
||||
return {}
|
||||
elif mode == 1:
|
||||
return {"field2": f"NEW_{doc_id}"}
|
||||
elif mode == 2:
|
||||
return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
|
||||
else:
|
||||
return {"ofherf2": "post"}
|
||||
|
||||
ids_to_new_md = {
|
||||
doc_id: _make_new_md(rep_i % 4, doc_id)
|
||||
for rep_i, doc_id in enumerate(ids_to_replace)
|
||||
}
|
||||
|
||||
await vstore_f1.areplace_metadata(ids_to_new_md, concurrency=BATCH_SIZE)
|
||||
# thorough check
|
||||
expected_id_to_metadata: dict[str, dict] = {
|
||||
**{(document.id or ""): document.metadata for document in orig_documents},
|
||||
**ids_to_new_md,
|
||||
}
|
||||
for hit in await vstore_f1.asimilarity_search("doc", k=N_DOCS + 1):
|
||||
assert hit.id is not None
|
||||
assert hit.metadata == expected_id_to_metadata[hit.id]
|
||||
|
Loading…
Reference in New Issue
Block a user