mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
community[minor]: Add support for metadata indexing policy in Cassandra vector store (#22548)
This PR adds a constructor `metadata_indexing` parameter to the Cassandra vector store to allow optional fine-tuning of which fields of the metadata are to be indexed. This is a feature supported by the underlying CassIO library. Indexing mode of "all", "none" or deny- and allow-list based choices are available. The rationale is, in some cases it's advisable to programmatically exclude some portions of the metadata from the index if one knows in advance they won't ever be used at search-time. this keeps the index more lightweight and performant and avoids limitations on the length of _indexed_ strings. I added a integration test of the feature. I also added the possibility of running the integration test with Cassandra on an arbitrary IP address (e.g. Dockerized), via `CASSANDRA_CONTACT_POINTS=10.1.1.5,10.1.1.6 poetry run pytest [...]` or similar. While I was at it, I added a line to the `.gitignore` since the mypy _test_ cache was not ignored yet. My X (Twitter) handle: @rsprrs.
This commit is contained in:
parent
c3d4126eb1
commit
328d0c99f2
1
.gitignore
vendored
1
.gitignore
vendored
@ -133,6 +133,7 @@ env.bak/
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.mypy_cache_test/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
|
@ -59,6 +59,7 @@ class Cassandra(VectorStore):
|
||||
*,
|
||||
body_index_options: Optional[List[Tuple[str, Any]]] = None,
|
||||
setup_mode: SetupMode = SetupMode.SYNC,
|
||||
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
|
||||
) -> None:
|
||||
"""Apache Cassandra(R) for vector-store workloads.
|
||||
|
||||
@ -83,13 +84,24 @@ class Cassandra(VectorStore):
|
||||
embedding: Embedding function to use.
|
||||
session: Cassandra driver session. If not provided, it is resolved from
|
||||
cassio.
|
||||
keyspace: Cassandra key space. If not provided, it is resolved from cassio.
|
||||
keyspace: Cassandra keyspace. If not provided, it is resolved from cassio.
|
||||
table_name: Cassandra table (required).
|
||||
ttl_seconds: Optional time-to-live for the added texts.
|
||||
body_index_options: Optional options used to create the body index.
|
||||
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
|
||||
setup_mode: mode used to create the Cassandra table (SYNC,
|
||||
ASYNC or OFF).
|
||||
metadata_indexing: Optional specification of a metadata indexing policy,
|
||||
i.e. to fine-tune which of the metadata fields are indexed.
|
||||
It can be a string ("all" or "none"), or a 2-tuple. The following
|
||||
means that all fields except 'f1', 'f2' ... are NOT indexed:
|
||||
metadata_indexing=("allowlist", ["f1", "f2", ...])
|
||||
The following means all fields EXCEPT 'g1', 'g2', ... are indexed:
|
||||
metadata_indexing("denylist", ["g1", "g2", ...])
|
||||
The default is to index every metadata field.
|
||||
Note: if you plan to have massive unique text metadata entries,
|
||||
consider not indexing them for performance
|
||||
(and to overcome max-length limitations).
|
||||
"""
|
||||
try:
|
||||
from cassio.table import MetadataVectorCassandraTable
|
||||
@ -125,7 +137,7 @@ class Cassandra(VectorStore):
|
||||
keyspace=keyspace,
|
||||
table=table_name,
|
||||
vector_dimension=embedding_dimension,
|
||||
metadata_indexing="all",
|
||||
metadata_indexing=metadata_indexing,
|
||||
primary_key_type="TEXT",
|
||||
skip_provisioning=setup_mode == SetupMode.OFF,
|
||||
**kwargs,
|
||||
@ -885,6 +897,7 @@ class Cassandra(VectorStore):
|
||||
batch_size: int = 16,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
body_index_options: Optional[List[Tuple[str, Any]]] = None,
|
||||
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
|
||||
**kwargs: Any,
|
||||
) -> CVST:
|
||||
"""Create a Cassandra vectorstore from raw texts.
|
||||
@ -915,6 +928,7 @@ class Cassandra(VectorStore):
|
||||
table_name=table_name,
|
||||
ttl_seconds=ttl_seconds,
|
||||
body_index_options=body_index_options,
|
||||
metadata_indexing=metadata_indexing,
|
||||
)
|
||||
store.add_texts(
|
||||
texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size
|
||||
@ -935,6 +949,7 @@ class Cassandra(VectorStore):
|
||||
concurrency: int = 16,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
body_index_options: Optional[List[Tuple[str, Any]]] = None,
|
||||
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
|
||||
**kwargs: Any,
|
||||
) -> CVST:
|
||||
"""Create a Cassandra vectorstore from raw texts.
|
||||
@ -966,6 +981,7 @@ class Cassandra(VectorStore):
|
||||
ttl_seconds=ttl_seconds,
|
||||
setup_mode=SetupMode.ASYNC,
|
||||
body_index_options=body_index_options,
|
||||
metadata_indexing=metadata_indexing,
|
||||
)
|
||||
await store.aadd_texts(
|
||||
texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency
|
||||
@ -985,6 +1001,7 @@ class Cassandra(VectorStore):
|
||||
batch_size: int = 16,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
body_index_options: Optional[List[Tuple[str, Any]]] = None,
|
||||
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
|
||||
**kwargs: Any,
|
||||
) -> CVST:
|
||||
"""Create a Cassandra vectorstore from a document list.
|
||||
@ -1020,6 +1037,7 @@ class Cassandra(VectorStore):
|
||||
batch_size=batch_size,
|
||||
ttl_seconds=ttl_seconds,
|
||||
body_index_options=body_index_options,
|
||||
metadata_indexing=metadata_indexing,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1036,6 +1054,7 @@ class Cassandra(VectorStore):
|
||||
concurrency: int = 16,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
body_index_options: Optional[List[Tuple[str, Any]]] = None,
|
||||
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
|
||||
**kwargs: Any,
|
||||
) -> CVST:
|
||||
"""Create a Cassandra vectorstore from a document list.
|
||||
@ -1071,6 +1090,7 @@ class Cassandra(VectorStore):
|
||||
concurrency=concurrency,
|
||||
ttl_seconds=ttl_seconds,
|
||||
body_index_options=body_index_options,
|
||||
metadata_indexing=metadata_indexing,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
"""Test Cassandra functionality."""
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Type
|
||||
from typing import Iterable, List, Optional, Tuple, Type, Union
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.vectorstores import Cassandra
|
||||
@ -19,13 +21,22 @@ def _vectorstore_from_texts(
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
|
||||
drop: bool = True,
|
||||
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
|
||||
table_name: str = "vector_test_table",
|
||||
) -> Cassandra:
|
||||
from cassandra.cluster import Cluster
|
||||
|
||||
keyspace = "vector_test_keyspace"
|
||||
table_name = "vector_test_table"
|
||||
# get db connection
|
||||
cluster = Cluster()
|
||||
if "CASSANDRA_CONTACT_POINTS" in os.environ:
|
||||
contact_points = [
|
||||
cp.strip()
|
||||
for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
|
||||
if cp.strip()
|
||||
]
|
||||
else:
|
||||
contact_points = None
|
||||
cluster = Cluster(contact_points)
|
||||
session = cluster.connect()
|
||||
# ensure keyspace exists
|
||||
session.execute(
|
||||
@ -45,6 +56,7 @@ def _vectorstore_from_texts(
|
||||
session=session,
|
||||
keyspace=keyspace,
|
||||
table_name=table_name,
|
||||
metadata_indexing=metadata_indexing,
|
||||
)
|
||||
|
||||
|
||||
@ -53,13 +65,22 @@ async def _vectorstore_from_texts_async(
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
|
||||
drop: bool = True,
|
||||
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
|
||||
table_name: str = "vector_test_table",
|
||||
) -> Cassandra:
|
||||
from cassandra.cluster import Cluster
|
||||
|
||||
keyspace = "vector_test_keyspace"
|
||||
table_name = "vector_test_table"
|
||||
# get db connection
|
||||
cluster = Cluster()
|
||||
if "CASSANDRA_CONTACT_POINTS" in os.environ:
|
||||
contact_points = [
|
||||
cp.strip()
|
||||
for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
|
||||
if cp.strip()
|
||||
]
|
||||
else:
|
||||
contact_points = None
|
||||
cluster = Cluster(contact_points)
|
||||
session = cluster.connect()
|
||||
# ensure keyspace exists
|
||||
session.execute(
|
||||
@ -268,3 +289,29 @@ async def test_cassandra_adelete() -> None:
|
||||
await asyncio.sleep(0.3)
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 0
|
||||
|
||||
|
||||
def test_cassandra_metadata_indexing() -> None:
|
||||
"""Test comparing metadata indexing policies."""
|
||||
texts = ["foo"]
|
||||
metadatas = [{"field1": "a", "field2": "b"}]
|
||||
vstore_all = _vectorstore_from_texts(texts, metadatas=metadatas)
|
||||
vstore_f1 = _vectorstore_from_texts(
|
||||
texts,
|
||||
metadatas=metadatas,
|
||||
metadata_indexing=("allowlist", ["field1"]),
|
||||
table_name="vector_test_table_indexing",
|
||||
)
|
||||
|
||||
output_all = vstore_all.similarity_search("bar", k=2)
|
||||
output_f1 = vstore_f1.similarity_search("bar", filter={"field1": "a"}, k=2)
|
||||
output_f1_no = vstore_f1.similarity_search("bar", filter={"field1": "Z"}, k=2)
|
||||
assert len(output_all) == 1
|
||||
assert output_all[0].metadata == metadatas[0]
|
||||
assert len(output_f1) == 1
|
||||
assert output_f1[0].metadata == metadatas[0]
|
||||
assert len(output_f1_no) == 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# "Non-indexed metadata fields cannot be used in queries."
|
||||
vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2)
|
||||
|
Loading…
Reference in New Issue
Block a user