diff --git a/.gitignore b/.gitignore index 1b5b4c52c9..b73165db55 100644 --- a/.gitignore +++ b/.gitignore @@ -133,6 +133,7 @@ env.bak/ # mypy .mypy_cache/ +.mypy_cache_test/ .dmypy.json dmypy.json diff --git a/libs/community/langchain_community/vectorstores/cassandra.py b/libs/community/langchain_community/vectorstores/cassandra.py index 1c4abfbb8e..85f6460d01 100644 --- a/libs/community/langchain_community/vectorstores/cassandra.py +++ b/libs/community/langchain_community/vectorstores/cassandra.py @@ -59,6 +59,7 @@ class Cassandra(VectorStore): *, body_index_options: Optional[List[Tuple[str, Any]]] = None, setup_mode: SetupMode = SetupMode.SYNC, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", ) -> None: """Apache Cassandra(R) for vector-store workloads. @@ -83,13 +84,24 @@ class Cassandra(VectorStore): embedding: Embedding function to use. session: Cassandra driver session. If not provided, it is resolved from cassio. - keyspace: Cassandra key space. If not provided, it is resolved from cassio. + keyspace: Cassandra keyspace. If not provided, it is resolved from cassio. table_name: Cassandra table (required). ttl_seconds: Optional time-to-live for the added texts. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] setup_mode: mode used to create the Cassandra table (SYNC, ASYNC or OFF). + metadata_indexing: Optional specification of a metadata indexing policy, + i.e. to fine-tune which of the metadata fields are indexed. + It can be a string ("all" or "none"), or a 2-tuple. The following + means that all fields except 'f1', 'f2' ... are NOT indexed: + metadata_indexing=("allowlist", ["f1", "f2", ...]) + The following means all fields EXCEPT 'g1', 'g2', ... are indexed: + metadata_indexing=("denylist", ["g1", "g2", ...]) + The default is to index every metadata field. 
+ Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). """ try: from cassio.table import MetadataVectorCassandraTable @@ -125,7 +137,7 @@ class Cassandra(VectorStore): keyspace=keyspace, table=table_name, vector_dimension=embedding_dimension, - metadata_indexing="all", + metadata_indexing=metadata_indexing, primary_key_type="TEXT", skip_provisioning=setup_mode == SetupMode.OFF, **kwargs, @@ -885,6 +897,7 @@ class Cassandra(VectorStore): batch_size: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from raw texts. @@ -915,6 +928,7 @@ class Cassandra(VectorStore): table_name=table_name, ttl_seconds=ttl_seconds, body_index_options=body_index_options, + metadata_indexing=metadata_indexing, ) store.add_texts( texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size @@ -935,6 +949,7 @@ class Cassandra(VectorStore): concurrency: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from raw texts. @@ -966,6 +981,7 @@ class Cassandra(VectorStore): ttl_seconds=ttl_seconds, setup_mode=SetupMode.ASYNC, body_index_options=body_index_options, + metadata_indexing=metadata_indexing, ) await store.aadd_texts( texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency @@ -985,6 +1001,7 @@ class Cassandra(VectorStore): batch_size: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from a document list. 
@@ -1020,6 +1037,7 @@ class Cassandra(VectorStore): batch_size=batch_size, ttl_seconds=ttl_seconds, body_index_options=body_index_options, + metadata_indexing=metadata_indexing, **kwargs, ) @@ -1036,6 +1054,7 @@ class Cassandra(VectorStore): concurrency: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: """Create a Cassandra vectorstore from a document list. @@ -1071,6 +1090,7 @@ class Cassandra(VectorStore): concurrency=concurrency, ttl_seconds=ttl_seconds, body_index_options=body_index_options, + metadata_indexing=metadata_indexing, **kwargs, ) diff --git a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py index 12c3a0bdf7..4e566f3cd7 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py @@ -1,8 +1,10 @@ """Test Cassandra functionality.""" import asyncio +import os import time -from typing import List, Optional, Type +from typing import Iterable, List, Optional, Tuple, Type, Union +import pytest from langchain_core.documents import Document from langchain_community.vectorstores import Cassandra @@ -19,13 +21,22 @@ def _vectorstore_from_texts( metadatas: Optional[List[dict]] = None, embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings, drop: bool = True, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", + table_name: str = "vector_test_table", ) -> Cassandra: from cassandra.cluster import Cluster keyspace = "vector_test_keyspace" - table_name = "vector_test_table" # get db connection - cluster = Cluster() + if "CASSANDRA_CONTACT_POINTS" in os.environ: + contact_points = [ + cp.strip() + for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",") + if cp.strip() + ] + else: + contact_points = None + 
cluster = Cluster(contact_points) session = cluster.connect() # ensure keyspace exists session.execute( @@ -45,6 +56,7 @@ def _vectorstore_from_texts( session=session, keyspace=keyspace, table_name=table_name, + metadata_indexing=metadata_indexing, ) @@ -53,13 +65,22 @@ async def _vectorstore_from_texts_async( metadatas: Optional[List[dict]] = None, embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings, drop: bool = True, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", + table_name: str = "vector_test_table", ) -> Cassandra: from cassandra.cluster import Cluster keyspace = "vector_test_keyspace" - table_name = "vector_test_table" # get db connection - cluster = Cluster() + if "CASSANDRA_CONTACT_POINTS" in os.environ: + contact_points = [ + cp.strip() + for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",") + if cp.strip() + ] + else: + contact_points = None + cluster = Cluster(contact_points) session = cluster.connect() # ensure keyspace exists session.execute( @@ -268,3 +289,29 @@ async def test_cassandra_adelete() -> None: await asyncio.sleep(0.3) output = docsearch.similarity_search("foo", k=10) assert len(output) == 0 + + +def test_cassandra_metadata_indexing() -> None: + """Test comparing metadata indexing policies.""" + texts = ["foo"] + metadatas = [{"field1": "a", "field2": "b"}] + vstore_all = _vectorstore_from_texts(texts, metadatas=metadatas) + vstore_f1 = _vectorstore_from_texts( + texts, + metadatas=metadatas, + metadata_indexing=("allowlist", ["field1"]), + table_name="vector_test_table_indexing", + ) + + output_all = vstore_all.similarity_search("bar", k=2) + output_f1 = vstore_f1.similarity_search("bar", filter={"field1": "a"}, k=2) + output_f1_no = vstore_f1.similarity_search("bar", filter={"field1": "Z"}, k=2) + assert len(output_all) == 1 + assert output_all[0].metadata == metadatas[0] + assert len(output_f1) == 1 + assert output_f1[0].metadata == metadatas[0] + assert len(output_f1_no) == 0 + + with 
pytest.raises(ValueError): + # "Non-indexed metadata fields cannot be used in queries." + vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2)