From 328d0c99f24b02c9f5cc15b2832fab0372450742 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Wed, 5 Jun 2024 17:23:26 +0200 Subject: [PATCH] community[minor]: Add support for metadata indexing policy in Cassandra vector store (#22548) This PR adds a constructor `metadata_indexing` parameter to the Cassandra vector store to allow optional fine-tuning of which fields of the metadata are to be indexed. This is a feature supported by the underlying CassIO library. Indexing modes of "all", "none" or deny- and allow-list based choices are available. The rationale is, in some cases it's advisable to programmatically exclude some portions of the metadata from the index if one knows in advance they won't ever be used at search-time. This keeps the index more lightweight and performant and avoids limitations on the length of _indexed_ strings. I added an integration test of the feature. I also added the possibility of running the integration test with Cassandra on an arbitrary IP address (e.g. Dockerized), via `CASSANDRA_CONTACT_POINTS=10.1.1.5,10.1.1.6 poetry run pytest [...]` or similar. While I was at it, I added a line to the `.gitignore` since the mypy _test_ cache was not ignored yet. My X (Twitter) handle: @rsprrs. 
--- .gitignore | 1 + .../vectorstores/cassandra.py | 24 +++++++- .../vectorstores/test_cassandra.py | 57 +++++++++++++++++-- 3 files changed, 75 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 1b5b4c52c9..b73165db55 100644 --- a/.gitignore +++ b/.gitignore @@ -133,6 +133,7 @@ env.bak/ # mypy .mypy_cache/ +.mypy_cache_test/ .dmypy.json dmypy.json diff --git a/libs/community/langchain_community/vectorstores/cassandra.py b/libs/community/langchain_community/vectorstores/cassandra.py index 1c4abfbb8e..85f6460d01 100644 --- a/libs/community/langchain_community/vectorstores/cassandra.py +++ b/libs/community/langchain_community/vectorstores/cassandra.py @@ -59,6 +59,7 @@ class Cassandra(VectorStore): *, body_index_options: Optional[List[Tuple[str, Any]]] = None, setup_mode: SetupMode = SetupMode.SYNC, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", ) -> None: """Apache Cassandra(R) for vector-store workloads. @@ -83,13 +84,24 @@ class Cassandra(VectorStore): embedding: Embedding function to use. session: Cassandra driver session. If not provided, it is resolved from cassio. - keyspace: Cassandra key space. If not provided, it is resolved from cassio. + keyspace: Cassandra keyspace. If not provided, it is resolved from cassio. table_name: Cassandra table (required). ttl_seconds: Optional time-to-live for the added texts. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] setup_mode: mode used to create the Cassandra table (SYNC, ASYNC or OFF). + metadata_indexing: Optional specification of a metadata indexing policy, + i.e. to fine-tune which of the metadata fields are indexed. + It can be a string ("all" or "none"), or a 2-tuple. The following + means that all fields except 'f1', 'f2' ... are NOT indexed: + metadata_indexing=("allowlist", ["f1", "f2", ...]) + The following means all fields EXCEPT 'g1', 'g2', ... 
are indexed: + metadata_indexing=("denylist", ["g1", "g2", ...]) + The default is to index every metadata field. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). """ try: from cassio.table import MetadataVectorCassandraTable @@ -125,7 +137,7 @@ class Cassandra(VectorStore): keyspace=keyspace, table=table_name, vector_dimension=embedding_dimension, - metadata_indexing="all", + metadata_indexing=metadata_indexing, primary_key_type="TEXT", skip_provisioning=setup_mode == SetupMode.OFF, **kwargs, @@ -885,6 +897,7 @@ class Cassandra(VectorStore): batch_size: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from raw texts. @@ -915,6 +928,7 @@ class Cassandra(VectorStore): table_name=table_name, ttl_seconds=ttl_seconds, body_index_options=body_index_options, + metadata_indexing=metadata_indexing, ) store.add_texts( texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size @@ -935,6 +949,7 @@ class Cassandra(VectorStore): concurrency: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from raw texts. 
@@ -966,6 +981,7 @@ class Cassandra(VectorStore): ttl_seconds=ttl_seconds, setup_mode=SetupMode.ASYNC, body_index_options=body_index_options, + metadata_indexing=metadata_indexing, ) await store.aadd_texts( texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency @@ -985,6 +1001,7 @@ class Cassandra(VectorStore): batch_size: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from a document list. @@ -1020,6 +1037,7 @@ class Cassandra(VectorStore): batch_size=batch_size, ttl_seconds=ttl_seconds, body_index_options=body_index_options, + metadata_indexing=metadata_indexing, **kwargs, ) @@ -1036,6 +1054,7 @@ class Cassandra(VectorStore): concurrency: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from a document list. 
@@ -1071,6 +1090,7 @@ class Cassandra(VectorStore): concurrency=concurrency, ttl_seconds=ttl_seconds, body_index_options=body_index_options, + metadata_indexing=metadata_indexing, **kwargs, ) diff --git a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py index 12c3a0bdf7..4e566f3cd7 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py @@ -1,8 +1,10 @@ """Test Cassandra functionality.""" import asyncio +import os import time -from typing import List, Optional, Type +from typing import Iterable, List, Optional, Tuple, Type, Union +import pytest from langchain_core.documents import Document from langchain_community.vectorstores import Cassandra @@ -19,13 +21,22 @@ def _vectorstore_from_texts( metadatas: Optional[List[dict]] = None, embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings, drop: bool = True, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", + table_name: str = "vector_test_table", ) -> Cassandra: from cassandra.cluster import Cluster keyspace = "vector_test_keyspace" - table_name = "vector_test_table" # get db connection - cluster = Cluster() + if "CASSANDRA_CONTACT_POINTS" in os.environ: + contact_points = [ + cp.strip() + for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",") + if cp.strip() + ] + else: + contact_points = None + cluster = Cluster(contact_points) session = cluster.connect() # ensure keyspace exists session.execute( @@ -45,6 +56,7 @@ def _vectorstore_from_texts( session=session, keyspace=keyspace, table_name=table_name, + metadata_indexing=metadata_indexing, ) @@ -53,13 +65,22 @@ async def _vectorstore_from_texts_async( metadatas: Optional[List[dict]] = None, embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings, drop: bool = True, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", + 
table_name: str = "vector_test_table", ) -> Cassandra: from cassandra.cluster import Cluster keyspace = "vector_test_keyspace" - table_name = "vector_test_table" # get db connection - cluster = Cluster() + if "CASSANDRA_CONTACT_POINTS" in os.environ: + contact_points = [ + cp.strip() + for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",") + if cp.strip() + ] + else: + contact_points = None + cluster = Cluster(contact_points) session = cluster.connect() # ensure keyspace exists session.execute( @@ -268,3 +289,29 @@ async def test_cassandra_adelete() -> None: await asyncio.sleep(0.3) output = docsearch.similarity_search("foo", k=10) assert len(output) == 0 + + +def test_cassandra_metadata_indexing() -> None: + """Test comparing metadata indexing policies.""" + texts = ["foo"] + metadatas = [{"field1": "a", "field2": "b"}] + vstore_all = _vectorstore_from_texts(texts, metadatas=metadatas) + vstore_f1 = _vectorstore_from_texts( + texts, + metadatas=metadatas, + metadata_indexing=("allowlist", ["field1"]), + table_name="vector_test_table_indexing", + ) + + output_all = vstore_all.similarity_search("bar", k=2) + output_f1 = vstore_f1.similarity_search("bar", filter={"field1": "a"}, k=2) + output_f1_no = vstore_f1.similarity_search("bar", filter={"field1": "Z"}, k=2) + assert len(output_all) == 1 + assert output_all[0].metadata == metadatas[0] + assert len(output_f1) == 1 + assert output_f1[0].metadata == metadatas[0] + assert len(output_f1_no) == 0 + + with pytest.raises(ValueError): + # "Non-indexed metadata fields cannot be used in queries." + vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2)