community[minor]: Add support for metadata indexing policy in Cassandra vector store (#22548)

This PR adds a constructor `metadata_indexing` parameter to the
Cassandra vector store to allow optional fine-tuning of which fields of
the metadata are to be indexed.

This is a feature supported by the underlying CassIO library. The
available indexing modes are "all", "none", or deny-list- and
allow-list-based choices.

The rationale is, in some cases it's advisable to programmatically
exclude some portions of the metadata from the index if one knows in
advance they won't ever be used at search-time. This keeps the index
more lightweight and performant and avoids limitations on the length of
_indexed_ strings.

I added an integration test of the feature. I also added the possibility
of running the integration test with Cassandra on an arbitrary IP
address (e.g. Dockerized), via
`CASSANDRA_CONTACT_POINTS=10.1.1.5,10.1.1.6 poetry run pytest [...]` or
similar.

While I was at it, I added a line to the `.gitignore` since the mypy
_test_ cache was not ignored yet.

My X (Twitter) handle: @rsprrs.
This commit is contained in:
Stefano Lottini 2024-06-05 17:23:26 +02:00 committed by GitHub
parent c3d4126eb1
commit 328d0c99f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 75 additions and 7 deletions

1
.gitignore vendored
View File

@ -133,6 +133,7 @@ env.bak/
# mypy
.mypy_cache/
.mypy_cache_test/
.dmypy.json
dmypy.json

View File

@ -59,6 +59,7 @@ class Cassandra(VectorStore):
*,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
setup_mode: SetupMode = SetupMode.SYNC,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
) -> None:
"""Apache Cassandra(R) for vector-store workloads.
@ -83,13 +84,24 @@ class Cassandra(VectorStore):
embedding: Embedding function to use.
session: Cassandra driver session. If not provided, it is resolved from
cassio.
keyspace: Cassandra key space. If not provided, it is resolved from cassio.
keyspace: Cassandra keyspace. If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ttl_seconds: Optional time-to-live for the added texts.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
setup_mode: mode used to create the Cassandra table (SYNC,
ASYNC or OFF).
metadata_indexing: Optional specification of a metadata indexing policy,
i.e. to fine-tune which of the metadata fields are indexed.
It can be a string ("all" or "none"), or a 2-tuple. The following
means that all fields except 'f1', 'f2' ... are NOT indexed:
metadata_indexing=("allowlist", ["f1", "f2", ...])
The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
metadata_indexing=("denylist", ["g1", "g2", ...])
The default is to index every metadata field.
Note: if you plan to have massive unique text metadata entries,
consider not indexing them for performance
(and to overcome max-length limitations).
"""
try:
from cassio.table import MetadataVectorCassandraTable
@ -125,7 +137,7 @@ class Cassandra(VectorStore):
keyspace=keyspace,
table=table_name,
vector_dimension=embedding_dimension,
metadata_indexing="all",
metadata_indexing=metadata_indexing,
primary_key_type="TEXT",
skip_provisioning=setup_mode == SetupMode.OFF,
**kwargs,
@ -885,6 +897,7 @@ class Cassandra(VectorStore):
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from raw texts.
@ -915,6 +928,7 @@ class Cassandra(VectorStore):
table_name=table_name,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
)
store.add_texts(
texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size
@ -935,6 +949,7 @@ class Cassandra(VectorStore):
concurrency: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from raw texts.
@ -966,6 +981,7 @@ class Cassandra(VectorStore):
ttl_seconds=ttl_seconds,
setup_mode=SetupMode.ASYNC,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
)
await store.aadd_texts(
texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency
@ -985,6 +1001,7 @@ class Cassandra(VectorStore):
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from a document list.
@ -1020,6 +1037,7 @@ class Cassandra(VectorStore):
batch_size=batch_size,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)
@ -1036,6 +1054,7 @@ class Cassandra(VectorStore):
concurrency: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from a document list.
@ -1071,6 +1090,7 @@ class Cassandra(VectorStore):
concurrency=concurrency,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)

View File

@ -1,8 +1,10 @@
"""Test Cassandra functionality."""
import asyncio
import os
import time
from typing import List, Optional, Type
from typing import Iterable, List, Optional, Tuple, Type, Union
import pytest
from langchain_core.documents import Document
from langchain_community.vectorstores import Cassandra
@ -19,13 +21,22 @@ def _vectorstore_from_texts(
metadatas: Optional[List[dict]] = None,
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
drop: bool = True,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
table_name: str = "vector_test_table",
) -> Cassandra:
from cassandra.cluster import Cluster
keyspace = "vector_test_keyspace"
table_name = "vector_test_table"
# get db connection
cluster = Cluster()
if "CASSANDRA_CONTACT_POINTS" in os.environ:
contact_points = [
cp.strip()
for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
if cp.strip()
]
else:
contact_points = None
cluster = Cluster(contact_points)
session = cluster.connect()
# ensure keyspace exists
session.execute(
@ -45,6 +56,7 @@ def _vectorstore_from_texts(
session=session,
keyspace=keyspace,
table_name=table_name,
metadata_indexing=metadata_indexing,
)
@ -53,13 +65,22 @@ async def _vectorstore_from_texts_async(
metadatas: Optional[List[dict]] = None,
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
drop: bool = True,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
table_name: str = "vector_test_table",
) -> Cassandra:
from cassandra.cluster import Cluster
keyspace = "vector_test_keyspace"
table_name = "vector_test_table"
# get db connection
cluster = Cluster()
if "CASSANDRA_CONTACT_POINTS" in os.environ:
contact_points = [
cp.strip()
for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
if cp.strip()
]
else:
contact_points = None
cluster = Cluster(contact_points)
session = cluster.connect()
# ensure keyspace exists
session.execute(
@ -268,3 +289,29 @@ async def test_cassandra_adelete() -> None:
await asyncio.sleep(0.3)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 0
def test_cassandra_metadata_indexing() -> None:
    """Compare a default-indexed store with an allowlist-indexed store.

    Two stores are built from the same single document: one through the
    helper's default ``metadata_indexing`` value, and one on a separate
    table with an allowlist containing only ``field1``. The test checks
    that an unfiltered search on the first store and a ``field1`` filter
    on the second both return the document with its full metadata, that a
    non-matching ``field1`` value returns nothing, and that filtering on
    ``field2`` (absent from the allowlist) raises ``ValueError``.
    """
    sample_texts = ["foo"]
    sample_metadatas = [{"field1": "a", "field2": "b"}]
    expected_metadata = sample_metadatas[0]

    # Store created with the helper's default indexing policy.
    store_default = _vectorstore_from_texts(sample_texts, metadatas=sample_metadatas)
    # Store created on its own table, with only 'field1' in the allowlist.
    store_field1_only = _vectorstore_from_texts(
        sample_texts,
        metadatas=sample_metadatas,
        table_name="vector_test_table_indexing",
        metadata_indexing=("allowlist", ["field1"]),
    )

    # Run every search first, in the original order, then verify the results.
    hits_default = store_default.similarity_search("bar", k=2)
    hits_matching = store_field1_only.similarity_search(
        "bar", k=2, filter={"field1": "a"}
    )
    hits_non_matching = store_field1_only.similarity_search(
        "bar", k=2, filter={"field1": "Z"}
    )

    assert len(hits_default) == 1
    assert hits_default[0].metadata == expected_metadata
    assert len(hits_matching) == 1
    assert hits_matching[0].metadata == expected_metadata
    assert len(hits_non_matching) == 0

    # "Non-indexed metadata fields cannot be used in queries."
    with pytest.raises(ValueError):
        store_field1_only.similarity_search("bar", k=2, filter={"field2": "b"})