mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
community[minor]: Add metadata filtering support for neo4j vector (#20001)
This commit is contained in:
parent
b52b78478f
commit
df25829f33
@ -28,6 +28,35 @@ DISTANCE_MAPPING = {
|
||||
DistanceStrategy.COSINE: "cosine",
|
||||
}
|
||||
|
||||
# Mapping from Mongo-style comparison operators to native Cypher operators.
COMPARISONS_TO_NATIVE = {
    "$eq": "=",
    "$ne": "<>",
    "$lt": "<",
    "$lte": "<=",
    "$gt": ">",
    "$gte": ">=",
}

# Operators that need bespoke Cypher rather than a one-token native operator.
SPECIAL_CASED_OPERATORS = {"$in", "$nin", "$between"}

# Operators translated into CONTAINS-based text matching.
TEXT_OPERATORS = {"$like", "$ilike"}

# Boolean connectives allowed at the top level of a filter.
LOGICAL_OPERATORS = {"$and", "$or"}

# The full set of operators the metadata-filter translator understands.
SUPPORTED_OPERATORS = (
    set(COMPARISONS_TO_NATIVE)
    | TEXT_OPERATORS
    | LOGICAL_OPERATORS
    | SPECIAL_CASED_OPERATORS
)
|
||||
|
||||
|
||||
class SearchType(str, enum.Enum):
|
||||
"""Enumerator of the Distance strategies."""
|
||||
@ -133,6 +162,240 @@ def dict_to_yaml_str(input_dict: Dict, indent: int = 0) -> str:
|
||||
return yaml_str
|
||||
|
||||
|
||||
def combine_queries(
    input_queries: List[Tuple[str, Dict[str, Any]]], operator: str
) -> Tuple[str, Dict[str, Any]]:
    """Combine multiple Cypher filter fragments with a logical operator.

    Args:
        input_queries: List of ``(cypher_snippet, parameters)`` tuples.
        operator: Logical operator (``"AND"`` / ``"OR"``) joining the
            fragments.

    Returns:
        A tuple of the combined Cypher snippet and the merged parameter
        dictionary. Parameter names are suffixed with a per-name counter so
        that fragments reusing the same parameter name do not clash.
    """
    import re

    combined_query: str = ""
    combined_params: Dict = {}
    param_counter: Dict = {}

    for query, params in input_queries:
        # Process each query fragment and its parameters
        new_query = query
        for param, value in params.items():
            # Suffix with a per-name counter to guarantee uniqueness across
            # fragments.
            param_counter[param] = param_counter.get(param, 0) + 1
            new_param_name = f"{param}_{param_counter[param]}"

            # Use a word boundary so renaming e.g. $p cannot clobber the
            # prefix of another parameter such as $p2 (plain str.replace did).
            new_query = re.sub(
                rf"\${re.escape(param)}\b", f"${new_param_name}", new_query
            )
            # Record the renamed parameter in the combined dictionary
            combined_params[new_param_name] = value

        # Join fragments with the requested logical operator
        if combined_query:
            combined_query += f" {operator} "
        combined_query += f"({new_query})"

    return combined_query, combined_params
|
||||
|
||||
|
||||
def collect_params(
    input_data: List[Tuple[str, Dict[str, str]]],
) -> Tuple[List[str], Dict[str, Any]]:
    """Split ``(snippet, params)`` tuples into parallel collections.

    Args:
        input_data: Tuples each holding a Cypher snippet and the parameter
            dictionary that snippet references.

    Returns:
        A tuple of the list of snippets and one merged parameter dictionary.
    """
    # Gather all snippets in order.
    query_parts = [snippet for snippet, _ in input_data]

    # Merge every fragment's parameters into a single dictionary.
    params: Dict[str, Any] = {}
    for _, fragment_params in input_data:
        params.update(fragment_params)

    return (query_parts, params)
|
||||
|
||||
|
||||
def _handle_field_filter(
    field: str, value: Any, param_number: int = 1
) -> Tuple[str, Dict]:
    """Create a Cypher filter snippet for a single metadata field.

    Args:
        field: name of field
        value: value to filter
            If provided as is then this will be an equality filter
            If provided as a dictionary then this will be a filter, the key
            will be the operator and the value will be the value to filter by
        param_number: sequence number of parameters used to map between param
            dict and Cypher snippet

    Returns a tuple of
        - Cypher filter snippet
        - Dictionary with parameters used in filter snippet
    """
    if not isinstance(field, str):
        raise ValueError(
            f"field should be a string but got: {type(field)} with value: {field}"
        )

    if field.startswith("$"):
        raise ValueError(
            f"Invalid filter condition. Expected a field but got an operator: "
            f"{field}"
        )

    # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
    if not field.isidentifier():
        raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.")

    if isinstance(value, dict):
        # This is a filter specification: exactly one operator key is allowed
        if len(value) != 1:
            raise ValueError(
                "Invalid filter condition. Expected a value which "
                "is a dictionary with a single key that corresponds to an operator "
                f"but got a dictionary with {len(value)} keys. The first few "
                f"keys are: {list(value.keys())[:3]}"
            )
        operator, filter_value = next(iter(value.items()))
        # Verify that the operator is one we support
        if operator not in SUPPORTED_OPERATORS:
            raise ValueError(
                f"Invalid operator: {operator}. "
                f"Expected one of {SUPPORTED_OPERATORS}"
            )
    else:  # Then we assume an equality operator
        operator = "$eq"
        filter_value = value

    if operator in COMPARISONS_TO_NATIVE:
        # Direct translation to a native Cypher comparison operator.
        # native is trusted input (it comes from our own mapping, not the user)
        native = COMPARISONS_TO_NATIVE[operator]
        query_snippet = f"n.`{field}` {native} $param_{param_number}"
        query_param = {f"param_{param_number}": filter_value}
        return (query_snippet, query_param)
    elif operator == "$between":
        try:
            low, high = filter_value
        except (TypeError, ValueError):
            # Surface a clear error instead of an opaque unpacking failure
            raise ValueError(
                f"Expected a two-element sequence for $between, got: {filter_value}"
            ) from None
        query_snippet = (
            f"$param_{param_number}_low <= n.`{field}` <= $param_{param_number}_high"
        )
        query_param = {
            f"param_{param_number}_low": low,
            f"param_{param_number}_high": high,
        }
        return (query_snippet, query_param)
    elif operator in {"$in", "$nin", "$like", "$ilike"}:
        if operator in {"$in", "$nin"}:
            # Only scalar members are supported in (NOT) IN lists
            for val in filter_value:
                if not isinstance(val, (str, int, float)):
                    raise NotImplementedError(
                        f"Unsupported type: {type(val)} for value: {val}"
                    )
        if operator == "$in":
            query_snippet = f"n.`{field}` IN $param_{param_number}"
            query_param = {f"param_{param_number}": filter_value}
            return (query_snippet, query_param)
        elif operator == "$nin":
            query_snippet = f"n.`{field}` NOT IN $param_{param_number}"
            query_param = {f"param_{param_number}": filter_value}
            return (query_snippet, query_param)
        elif operator == "$like":
            # Map SQL-style trailing-wildcard LIKE onto Cypher CONTAINS
            query_snippet = f"n.`{field}` CONTAINS $param_{param_number}"
            query_param = {f"param_{param_number}": filter_value.rstrip("%")}
            return (query_snippet, query_param)
        else:  # $ilike
            # Lowercase BOTH sides: the original only lowered the property,
            # so an uppercase pattern could never match (case-insensitivity
            # bug fix).
            query_snippet = f"toLower(n.`{field}`) CONTAINS $param_{param_number}"
            query_param = {
                f"param_{param_number}": filter_value.rstrip("%").lower()
            }
            return (query_snippet, query_param)
    else:
        raise NotImplementedError(f"Operator {operator} is not implemented")
|
||||
|
||||
|
||||
def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]:
    """Translate a metadata filter dictionary into a Cypher filter snippet.

    Args:
        filter: Filter specification. Either a mapping of field names to
            values / operator-dicts (multiple fields are implicitly AND-ed),
            or a single top-level ``"$and"`` / ``"$or"`` key whose value is a
            list of sub-filters.

    Returns:
        A tuple of (Cypher filter snippet, parameter dictionary).

    Raises:
        ValueError: If the filter is empty, not a dictionary, mixes operators
            with fields at the top level, or uses an unsupported top-level
            operator.
    """
    if not isinstance(filter, dict):
        # Fail fast: the original fell off the end and returned None here,
        # which callers then failed to unpack with an opaque TypeError.
        raise ValueError(
            f"Expected a dictionary for filters, but got: {type(filter)}"
        )
    if not filter:
        raise ValueError("Got an empty dictionary for filters.")

    if len(filter) == 1:
        # The only operators allowed at the top level are $and and $or
        key, value = next(iter(filter.items()))
        if not key.startswith("$"):
            # A single plain field -> simple field filter
            return _handle_field_filter(key, filter[key])
        if key.lower() not in ("$and", "$or"):
            raise ValueError(
                f"Invalid filter condition. Expected $and or $or "
                f"but got: {key}"
            )
        # $and / $or take a list of sub-filters as their operand
        if not isinstance(value, list):
            raise ValueError(
                f"Expected a list, but got {type(value)} for value: {value}"
            )
        if not value:
            # An empty operand list would produce an empty — and therefore
            # syntactically invalid — Cypher snippet downstream.
            raise ValueError(
                f"Expected a non-empty list of sub-filters for {key}"
            )
        operator = "AND" if key.lower() == "$and" else "OR"
        return combine_queries(
            [construct_metadata_filter(el) for el in value], operator
        )

    # Multiple keys: all of them must be fields (they cannot be operators)
    for key in filter.keys():
        if key.startswith("$"):
            raise ValueError(
                f"Invalid filter condition. Expected a field but got: {key}"
            )
    # These are all fields, combined with an implicit AND; distinct param
    # numbers keep the generated parameter names unique per field.
    snippets, params = collect_params(
        [
            _handle_field_filter(k, v, index)
            for index, (k, v) in enumerate(filter.items())
        ]
    )
    return " AND ".join(snippets), params
|
||||
|
||||
|
||||
class Neo4jVector(VectorStore):
|
||||
"""`Neo4j` vector index.
|
||||
|
||||
@ -243,6 +506,7 @@ class Neo4jVector(VectorStore):
|
||||
)
|
||||
|
||||
# Verify if the version support vector index
|
||||
self._is_enterprise = False
|
||||
self.verify_version()
|
||||
|
||||
# Verify that required values are not null
|
||||
@ -318,7 +582,8 @@ class Neo4jVector(VectorStore):
|
||||
indexing. Raises a ValueError if the connected Neo4j version is
|
||||
not supported.
|
||||
"""
|
||||
version = self.query("CALL dbms.components()")[0]["versions"][0]
|
||||
db_data = self.query("CALL dbms.components()")
|
||||
version = db_data[0]["versions"][0]
|
||||
if "aura" in version:
|
||||
version_tuple = tuple(map(int, version.split("-")[0].split("."))) + (0,)
|
||||
else:
|
||||
@ -331,6 +596,15 @@ class Neo4jVector(VectorStore):
|
||||
"Version index is only supported in Neo4j version 5.11 or greater"
|
||||
)
|
||||
|
||||
# Flag for metadata filtering
|
||||
metadata_target_version = (5, 18, 0)
|
||||
if version_tuple < metadata_target_version:
|
||||
self.support_metadata_filter = False
|
||||
else:
|
||||
self.support_metadata_filter = True
|
||||
# Flag for enterprise
|
||||
self._is_enterprise = True if db_data[0]["edition"] == "enterprise" else False
|
||||
|
||||
def retrieve_existing_index(self) -> Optional[int]:
|
||||
"""
|
||||
Check if the vector index exists in the Neo4j database
|
||||
@ -583,6 +857,7 @@ class Neo4jVector(VectorStore):
|
||||
query: str,
|
||||
k: int = 4,
|
||||
params: Dict[str, Any] = {},
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Run similarity search with Neo4jVector.
|
||||
@ -596,7 +871,12 @@ class Neo4jVector(VectorStore):
|
||||
"""
|
||||
embedding = self.embedding.embed_query(text=query)
|
||||
return self.similarity_search_by_vector(
|
||||
embedding=embedding, k=k, query=query, params=params, **kwargs
|
||||
embedding=embedding,
|
||||
k=k,
|
||||
query=query,
|
||||
params=params,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def similarity_search_with_score(
|
||||
@ -604,6 +884,7 @@ class Neo4jVector(VectorStore):
|
||||
query: str,
|
||||
k: int = 4,
|
||||
params: Dict[str, Any] = {},
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
@ -617,7 +898,12 @@ class Neo4jVector(VectorStore):
|
||||
"""
|
||||
embedding = self.embedding.embed_query(query)
|
||||
docs = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, query=query, params=params, **kwargs
|
||||
embedding=embedding,
|
||||
k=k,
|
||||
query=query,
|
||||
params=params,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
return docs
|
||||
|
||||
@ -625,6 +911,7 @@ class Neo4jVector(VectorStore):
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
params: Dict[str, Any] = {},
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
@ -646,6 +933,33 @@ class Neo4jVector(VectorStore):
|
||||
List[Tuple[Document, float]]: A list of tuples, each containing
|
||||
a Document object and its similarity score.
|
||||
"""
|
||||
if filter:
|
||||
# Verify that 5.18 or later is used
|
||||
if not self.support_metadata_filter:
|
||||
raise ValueError(
|
||||
"Metadata filtering is only supported in "
|
||||
"Neo4j version 5.18 or greater"
|
||||
)
|
||||
# Metadata filtering and hybrid doesn't work
|
||||
if self.search_type == SearchType.HYBRID:
|
||||
raise ValueError(
|
||||
"Metadata filtering can't be use in combination with "
|
||||
"a hybrid search approach"
|
||||
)
|
||||
parallel_query = "CYPHER runtime = parallel " if self._is_enterprise else ""
|
||||
base_index_query = parallel_query + f"MATCH (n:`{self.node_label}`) WHERE "
|
||||
base_cosine_query = (
|
||||
" WITH n as node, vector.similarity.cosine("
|
||||
f"n.`{self.embedding_node_property}`, "
|
||||
"$embedding) AS score ORDER BY score DESC LIMIT toInteger($k) "
|
||||
)
|
||||
filter_snippets, filter_params = construct_metadata_filter(filter)
|
||||
index_query = base_index_query + filter_snippets + base_cosine_query
|
||||
|
||||
else:
|
||||
index_query = _get_search_index_query(self.search_type)
|
||||
filter_params = {}
|
||||
|
||||
default_retrieval = (
|
||||
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
||||
f"node {{.*, `{self.text_node_property}`: Null, "
|
||||
@ -656,7 +970,7 @@ class Neo4jVector(VectorStore):
|
||||
self.retrieval_query if self.retrieval_query else default_retrieval
|
||||
)
|
||||
|
||||
read_query = _get_search_index_query(self.search_type) + retrieval_query
|
||||
read_query = index_query + retrieval_query
|
||||
parameters = {
|
||||
"index": self.index_name,
|
||||
"k": k,
|
||||
@ -664,6 +978,7 @@ class Neo4jVector(VectorStore):
|
||||
"keyword_index": self.keyword_index_name,
|
||||
"query": remove_lucene_chars(kwargs["query"]),
|
||||
**params,
|
||||
**filter_params,
|
||||
}
|
||||
|
||||
results = self.query(read_query, params=parameters)
|
||||
@ -688,6 +1003,7 @@ class Neo4jVector(VectorStore):
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
@ -700,7 +1016,7 @@ class Neo4jVector(VectorStore):
|
||||
List of Documents most similar to the query vector.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, **kwargs
|
||||
embedding=embedding, k=k, filter=filter, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Test Neo4jVector functionality."""
|
||||
import os
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, cast
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@ -11,6 +11,13 @@ from langchain_community.vectorstores.neo4j_vector import (
|
||||
)
|
||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
|
||||
DOCUMENTS,
|
||||
TYPE_1_FILTERING_TEST_CASES,
|
||||
TYPE_2_FILTERING_TEST_CASES,
|
||||
TYPE_3_FILTERING_TEST_CASES,
|
||||
TYPE_4_FILTERING_TEST_CASES,
|
||||
)
|
||||
|
||||
url = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
|
||||
username = os.environ.get("NEO4J_USERNAME", "neo4j")
|
||||
@ -721,6 +728,8 @@ def test_index_fetching() -> None:
|
||||
|
||||
index_0_store = fetch_store(index_0_str)
|
||||
assert index_0_store.index_name == index_0_str
|
||||
drop_vector_indexes(index_1_store)
|
||||
drop_vector_indexes(index_0_store)
|
||||
|
||||
|
||||
def test_retrieval_params() -> None:
|
||||
@ -741,6 +750,7 @@ def test_retrieval_params() -> None:
|
||||
Document(page_content="test", metadata={"test": "test1"}),
|
||||
Document(page_content="test", metadata={"test": "test1"}),
|
||||
]
|
||||
drop_vector_indexes(docsearch)
|
||||
|
||||
|
||||
def test_retrieval_dictionary() -> None:
|
||||
@ -767,3 +777,38 @@ def test_retrieval_dictionary() -> None:
|
||||
]
|
||||
output = docsearch.similarity_search("Foo", k=1)
|
||||
assert output == expected_output
|
||||
drop_vector_indexes(docsearch)
|
||||
|
||||
|
||||
def test_metadata_filters_type1() -> None:
    """Test metadata filters.

    Runs every type 1-4 filtering test case against a freshly built
    Neo4jVector store and checks that similarity search with ``filter``
    returns exactly the expected fixture documents.

    NOTE(review): the inner loop deletes keys from ``doc.metadata`` of the
    shared ``DOCUMENTS`` fixture objects in place — confirm no later test
    relies on the removed "id" / None-valued keys.
    """
    docsearch = Neo4jVector.from_documents(
        DOCUMENTS,
        embedding=FakeEmbeddings(),
        pre_delete_collection=True,
    )
    # We don't test type 5, because LIKE has very SQL specific examples
    for example in (
        TYPE_1_FILTERING_TEST_CASES
        + TYPE_2_FILTERING_TEST_CASES
        + TYPE_3_FILTERING_TEST_CASES
        + TYPE_4_FILTERING_TEST_CASES
    ):
        filter_dict = cast(Dict[str, Any], example[0])
        output = docsearch.similarity_search("Foo", filter=filter_dict)
        # Test-case indices are 1-based; DOCUMENTS is 0-based
        indices = cast(List[int], example[1])
        adjusted_indices = [index - 1 for index in indices]
        expected_output = [DOCUMENTS[index] for index in adjusted_indices]
        # We don't return id properties from similarity search by default
        # Also remove any key where the value is None
        for doc in expected_output:
            if "id" in doc.metadata:
                del doc.metadata["id"]
            keys_with_none = [
                key for key, value in doc.metadata.items() if value is None
            ]
            for key in keys_with_none:
                del doc.metadata[key]

        assert output == expected_output
    drop_vector_indexes(docsearch)
|
||||
|
Loading…
Reference in New Issue
Block a user