community[minor]: Add metadata filtering support for neo4j vector (#20001)

This commit is contained in:
Tomaz Bratanic 2024-04-04 17:37:06 +02:00 committed by GitHub
parent b52b78478f
commit df25829f33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 367 additions and 6 deletions

View File

@ -28,6 +28,35 @@ DISTANCE_MAPPING = {
DistanceStrategy.COSINE: "cosine", DistanceStrategy.COSINE: "cosine",
} }
COMPARISONS_TO_NATIVE = {
"$eq": "=",
"$ne": "<>",
"$lt": "<",
"$lte": "<=",
"$gt": ">",
"$gte": ">=",
}
SPECIAL_CASED_OPERATORS = {
"$in",
"$nin",
"$between",
}
TEXT_OPERATORS = {
"$like",
"$ilike",
}
LOGICAL_OPERATORS = {"$and", "$or"}
SUPPORTED_OPERATORS = (
set(COMPARISONS_TO_NATIVE)
.union(TEXT_OPERATORS)
.union(LOGICAL_OPERATORS)
.union(SPECIAL_CASED_OPERATORS)
)
class SearchType(str, enum.Enum): class SearchType(str, enum.Enum):
"""Enumerator of the Distance strategies.""" """Enumerator of the Distance strategies."""
@ -133,6 +162,240 @@ def dict_to_yaml_str(input_dict: Dict, indent: int = 0) -> str:
return yaml_str return yaml_str
def combine_queries(
input_queries: List[Tuple[str, Dict[str, Any]]], operator: str
) -> Tuple[str, Dict[str, Any]]:
# Initialize variables to hold the combined query and parameters
combined_query: str = ""
combined_params: Dict = {}
param_counter: Dict = {}
for query, params in input_queries:
# Process each query fragment and its parameters
new_query = query
for param, value in params.items():
# Update the parameter name to ensure uniqueness
if param in param_counter:
param_counter[param] += 1
else:
param_counter[param] = 1
new_param_name = f"{param}_{param_counter[param]}"
# Replace the parameter in the query fragment
new_query = new_query.replace(f"${param}", f"${new_param_name}")
# Add the parameter to the combined parameters dictionary
combined_params[new_param_name] = value
# Combine the query fragments with an AND operator
if combined_query:
combined_query += f" {operator} "
combined_query += f"({new_query})"
return combined_query, combined_params
def collect_params(
input_data: List[Tuple[str, Dict[str, str]]],
) -> Tuple[List[str], Dict[str, Any]]:
"""
Transform the input data into the desired format.
Args:
- input_data (list of tuples): Input data to transform.
Each tuple contains a string and a dictionary.
Returns:
- tuple: A tuple containing a list of strings and a dictionary.
"""
# Initialize variables to hold the output parts
query_parts = []
params = {}
# Loop through each item in the input data
for query_part, param in input_data:
# Append the query part to the list
query_parts.append(query_part)
# Update the params dictionary with the param dictionary
params.update(param)
# Return the transformed data
return (query_parts, params)
def _handle_field_filter(
field: str, value: Any, param_number: int = 1
) -> Tuple[str, Dict]:
"""Create a filter for a specific field.
Args:
field: name of field
value: value to filter
If provided as is then this will be an equality filter
If provided as a dictionary then this will be a filter, the key
will be the operator and the value will be the value to filter by
param_number: sequence number of parameters used to map between param
dict and Cypher snippet
Returns a tuple of
- Cypher filter snippet
- Dictionary with parameters used in filter snippet
"""
if not isinstance(field, str):
raise ValueError(
f"field should be a string but got: {type(field)} with value: {field}"
)
if field.startswith("$"):
raise ValueError(
f"Invalid filter condition. Expected a field but got an operator: "
f"{field}"
)
# Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
if not field.isidentifier():
raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.")
if isinstance(value, dict):
# This is a filter specification
if len(value) != 1:
raise ValueError(
"Invalid filter condition. Expected a value which "
"is a dictionary with a single key that corresponds to an operator "
f"but got a dictionary with {len(value)} keys. The first few "
f"keys are: {list(value.keys())[:3]}"
)
operator, filter_value = list(value.items())[0]
# Verify that that operator is an operator
if operator not in SUPPORTED_OPERATORS:
raise ValueError(
f"Invalid operator: {operator}. "
f"Expected one of {SUPPORTED_OPERATORS}"
)
else: # Then we assume an equality operator
operator = "$eq"
filter_value = value
if operator in COMPARISONS_TO_NATIVE:
# Then we implement an equality filter
# native is trusted input
native = COMPARISONS_TO_NATIVE[operator]
query_snippet = f"n.`{field}` {native} $param_{param_number}"
query_param = {f"param_{param_number}": filter_value}
return (query_snippet, query_param)
elif operator == "$between":
low, high = filter_value
query_snippet = (
f"$param_{param_number}_low <= n.`{field}` <= $param_{param_number}_high"
)
query_param = {
f"param_{param_number}_low": low,
f"param_{param_number}_high": high,
}
return (query_snippet, query_param)
elif operator in {"$in", "$nin", "$like", "$ilike"}:
# We'll do force coercion to text
if operator in {"$in", "$nin"}:
for val in filter_value:
if not isinstance(val, (str, int, float)):
raise NotImplementedError(
f"Unsupported type: {type(val)} for value: {val}"
)
if operator in {"$in"}:
query_snippet = f"n.`{field}` IN $param_{param_number}"
query_param = {f"param_{param_number}": filter_value}
return (query_snippet, query_param)
elif operator in {"$nin"}:
query_snippet = f"n.`{field}` NOT IN $param_{param_number}"
query_param = {f"param_{param_number}": filter_value}
return (query_snippet, query_param)
elif operator in {"$like"}:
query_snippet = f"n.`{field}` CONTAINS $param_{param_number}"
query_param = {f"param_{param_number}": filter_value.rstrip("%")}
return (query_snippet, query_param)
elif operator in {"$ilike"}:
query_snippet = f"toLower(n.`{field}`) CONTAINS $param_{param_number}"
query_param = {f"param_{param_number}": filter_value.rstrip("%")}
return (query_snippet, query_param)
else:
raise NotImplementedError()
else:
raise NotImplementedError()
def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]:
if isinstance(filter, dict):
if len(filter) == 1:
# The only operators allowed at the top level are $AND and $OR
# First check if an operator or a field
key, value = list(filter.items())[0]
if key.startswith("$"):
# Then it's an operator
if key.lower() not in ["$and", "$or"]:
raise ValueError(
f"Invalid filter condition. Expected $and or $or "
f"but got: {key}"
)
else:
# Then it's a field
return _handle_field_filter(key, filter[key])
# Here we handle the $and and $or operators
if not isinstance(value, list):
raise ValueError(
f"Expected a list, but got {type(value)} for value: {value}"
)
if key.lower() == "$and":
and_ = combine_queries(
[construct_metadata_filter(el) for el in value], "AND"
)
if len(and_) >= 1:
return and_
else:
raise ValueError(
"Invalid filter condition. Expected a dictionary "
"but got an empty dictionary"
)
elif key.lower() == "$or":
or_ = combine_queries(
[construct_metadata_filter(el) for el in value], "OR"
)
if len(or_) >= 1:
return or_
else:
raise ValueError(
"Invalid filter condition. Expected a dictionary "
"but got an empty dictionary"
)
else:
raise ValueError(
f"Invalid filter condition. Expected $and or $or " f"but got: {key}"
)
elif len(filter) > 1:
# Then all keys have to be fields (they cannot be operators)
for key in filter.keys():
if key.startswith("$"):
raise ValueError(
f"Invalid filter condition. Expected a field but got: {key}"
)
# These should all be fields and combined using an $and operator
and_multiple = collect_params(
[
_handle_field_filter(k, v, index)
for index, (k, v) in enumerate(filter.items())
]
)
if len(and_multiple) >= 1:
return " AND ".join(and_multiple[0]), and_multiple[1]
else:
raise ValueError(
"Invalid filter condition. Expected a dictionary "
"but got an empty dictionary"
)
else:
raise ValueError("Got an empty dictionary for filters.")
class Neo4jVector(VectorStore): class Neo4jVector(VectorStore):
"""`Neo4j` vector index. """`Neo4j` vector index.
@ -243,6 +506,7 @@ class Neo4jVector(VectorStore):
) )
# Verify if the version support vector index # Verify if the version support vector index
self._is_enterprise = False
self.verify_version() self.verify_version()
# Verify that required values are not null # Verify that required values are not null
@ -318,7 +582,8 @@ class Neo4jVector(VectorStore):
indexing. Raises a ValueError if the connected Neo4j version is indexing. Raises a ValueError if the connected Neo4j version is
not supported. not supported.
""" """
version = self.query("CALL dbms.components()")[0]["versions"][0] db_data = self.query("CALL dbms.components()")
version = db_data[0]["versions"][0]
if "aura" in version: if "aura" in version:
version_tuple = tuple(map(int, version.split("-")[0].split("."))) + (0,) version_tuple = tuple(map(int, version.split("-")[0].split("."))) + (0,)
else: else:
@ -331,6 +596,15 @@ class Neo4jVector(VectorStore):
"Version index is only supported in Neo4j version 5.11 or greater" "Version index is only supported in Neo4j version 5.11 or greater"
) )
# Flag for metadata filtering
metadata_target_version = (5, 18, 0)
if version_tuple < metadata_target_version:
self.support_metadata_filter = False
else:
self.support_metadata_filter = True
# Flag for enterprise
self._is_enterprise = True if db_data[0]["edition"] == "enterprise" else False
def retrieve_existing_index(self) -> Optional[int]: def retrieve_existing_index(self) -> Optional[int]:
""" """
Check if the vector index exists in the Neo4j database Check if the vector index exists in the Neo4j database
@ -583,6 +857,7 @@ class Neo4jVector(VectorStore):
query: str, query: str,
k: int = 4, k: int = 4,
params: Dict[str, Any] = {}, params: Dict[str, Any] = {},
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Run similarity search with Neo4jVector. """Run similarity search with Neo4jVector.
@ -596,7 +871,12 @@ class Neo4jVector(VectorStore):
""" """
embedding = self.embedding.embed_query(text=query) embedding = self.embedding.embed_query(text=query)
return self.similarity_search_by_vector( return self.similarity_search_by_vector(
embedding=embedding, k=k, query=query, params=params, **kwargs embedding=embedding,
k=k,
query=query,
params=params,
filter=filter,
**kwargs,
) )
def similarity_search_with_score( def similarity_search_with_score(
@ -604,6 +884,7 @@ class Neo4jVector(VectorStore):
query: str, query: str,
k: int = 4, k: int = 4,
params: Dict[str, Any] = {}, params: Dict[str, Any] = {},
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs most similar to query. """Return docs most similar to query.
@ -617,7 +898,12 @@ class Neo4jVector(VectorStore):
""" """
embedding = self.embedding.embed_query(query) embedding = self.embedding.embed_query(query)
docs = self.similarity_search_with_score_by_vector( docs = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, query=query, params=params, **kwargs embedding=embedding,
k=k,
query=query,
params=params,
filter=filter,
**kwargs,
) )
return docs return docs
@ -625,6 +911,7 @@ class Neo4jVector(VectorStore):
self, self,
embedding: List[float], embedding: List[float],
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None,
params: Dict[str, Any] = {}, params: Dict[str, Any] = {},
**kwargs: Any, **kwargs: Any,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
@ -646,6 +933,33 @@ class Neo4jVector(VectorStore):
List[Tuple[Document, float]]: A list of tuples, each containing List[Tuple[Document, float]]: A list of tuples, each containing
a Document object and its similarity score. a Document object and its similarity score.
""" """
if filter:
# Verify that 5.18 or later is used
if not self.support_metadata_filter:
raise ValueError(
"Metadata filtering is only supported in "
"Neo4j version 5.18 or greater"
)
# Metadata filtering and hybrid doesn't work
if self.search_type == SearchType.HYBRID:
raise ValueError(
"Metadata filtering can't be use in combination with "
"a hybrid search approach"
)
parallel_query = "CYPHER runtime = parallel " if self._is_enterprise else ""
base_index_query = parallel_query + f"MATCH (n:`{self.node_label}`) WHERE "
base_cosine_query = (
" WITH n as node, vector.similarity.cosine("
f"n.`{self.embedding_node_property}`, "
"$embedding) AS score ORDER BY score DESC LIMIT toInteger($k) "
)
filter_snippets, filter_params = construct_metadata_filter(filter)
index_query = base_index_query + filter_snippets + base_cosine_query
else:
index_query = _get_search_index_query(self.search_type)
filter_params = {}
default_retrieval = ( default_retrieval = (
f"RETURN node.`{self.text_node_property}` AS text, score, " f"RETURN node.`{self.text_node_property}` AS text, score, "
f"node {{.*, `{self.text_node_property}`: Null, " f"node {{.*, `{self.text_node_property}`: Null, "
@ -656,7 +970,7 @@ class Neo4jVector(VectorStore):
self.retrieval_query if self.retrieval_query else default_retrieval self.retrieval_query if self.retrieval_query else default_retrieval
) )
read_query = _get_search_index_query(self.search_type) + retrieval_query read_query = index_query + retrieval_query
parameters = { parameters = {
"index": self.index_name, "index": self.index_name,
"k": k, "k": k,
@ -664,6 +978,7 @@ class Neo4jVector(VectorStore):
"keyword_index": self.keyword_index_name, "keyword_index": self.keyword_index_name,
"query": remove_lucene_chars(kwargs["query"]), "query": remove_lucene_chars(kwargs["query"]),
**params, **params,
**filter_params,
} }
results = self.query(read_query, params=parameters) results = self.query(read_query, params=parameters)
@ -688,6 +1003,7 @@ class Neo4jVector(VectorStore):
self, self,
embedding: List[float], embedding: List[float],
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to embedding vector. """Return docs most similar to embedding vector.
@ -700,7 +1016,7 @@ class Neo4jVector(VectorStore):
List of Documents most similar to the query vector. List of Documents most similar to the query vector.
""" """
docs_and_scores = self.similarity_search_with_score_by_vector( docs_and_scores = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, **kwargs embedding=embedding, k=k, filter=filter, **kwargs
) )
return [doc for doc, _ in docs_and_scores] return [doc for doc, _ in docs_and_scores]

View File

@ -1,6 +1,6 @@
"""Test Neo4jVector functionality.""" """Test Neo4jVector functionality."""
import os import os
from typing import List from typing import Any, Dict, List, cast
from langchain_core.documents import Document from langchain_core.documents import Document
@ -11,6 +11,13 @@ from langchain_community.vectorstores.neo4j_vector import (
) )
from langchain_community.vectorstores.utils import DistanceStrategy from langchain_community.vectorstores.utils import DistanceStrategy
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
DOCUMENTS,
TYPE_1_FILTERING_TEST_CASES,
TYPE_2_FILTERING_TEST_CASES,
TYPE_3_FILTERING_TEST_CASES,
TYPE_4_FILTERING_TEST_CASES,
)
url = os.environ.get("NEO4J_URL", "bolt://localhost:7687") url = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j") username = os.environ.get("NEO4J_USERNAME", "neo4j")
@ -721,6 +728,8 @@ def test_index_fetching() -> None:
index_0_store = fetch_store(index_0_str) index_0_store = fetch_store(index_0_str)
assert index_0_store.index_name == index_0_str assert index_0_store.index_name == index_0_str
drop_vector_indexes(index_1_store)
drop_vector_indexes(index_0_store)
def test_retrieval_params() -> None: def test_retrieval_params() -> None:
@ -741,6 +750,7 @@ def test_retrieval_params() -> None:
Document(page_content="test", metadata={"test": "test1"}), Document(page_content="test", metadata={"test": "test1"}),
Document(page_content="test", metadata={"test": "test1"}), Document(page_content="test", metadata={"test": "test1"}),
] ]
drop_vector_indexes(docsearch)
def test_retrieval_dictionary() -> None: def test_retrieval_dictionary() -> None:
@ -767,3 +777,38 @@ def test_retrieval_dictionary() -> None:
] ]
output = docsearch.similarity_search("Foo", k=1) output = docsearch.similarity_search("Foo", k=1)
assert output == expected_output assert output == expected_output
drop_vector_indexes(docsearch)
def test_metadata_filters_type1() -> None:
"""Test metadata filters"""
docsearch = Neo4jVector.from_documents(
DOCUMENTS,
embedding=FakeEmbeddings(),
pre_delete_collection=True,
)
# We don't test type 5, because LIKE has very SQL specific examples
for example in (
TYPE_1_FILTERING_TEST_CASES
+ TYPE_2_FILTERING_TEST_CASES
+ TYPE_3_FILTERING_TEST_CASES
+ TYPE_4_FILTERING_TEST_CASES
):
filter_dict = cast(Dict[str, Any], example[0])
output = docsearch.similarity_search("Foo", filter=filter_dict)
indices = cast(List[int], example[1])
adjusted_indices = [index - 1 for index in indices]
expected_output = [DOCUMENTS[index] for index in adjusted_indices]
# We don't return id properties from similarity search by default
# Also remove any key where the value is None
for doc in expected_output:
if "id" in doc.metadata:
del doc.metadata["id"]
keys_with_none = [
key for key, value in doc.metadata.items() if value is None
]
for key in keys_with_none:
del doc.metadata[key]
assert output == expected_output
drop_vector_indexes(docsearch)