diff --git a/libs/community/langchain_community/vectorstores/neo4j_vector.py b/libs/community/langchain_community/vectorstores/neo4j_vector.py index 05f4b3b38b..4fd6cee51d 100644 --- a/libs/community/langchain_community/vectorstores/neo4j_vector.py +++ b/libs/community/langchain_community/vectorstores/neo4j_vector.py @@ -28,6 +28,35 @@ DISTANCE_MAPPING = { DistanceStrategy.COSINE: "cosine", } +COMPARISONS_TO_NATIVE = { + "$eq": "=", + "$ne": "<>", + "$lt": "<", + "$lte": "<=", + "$gt": ">", + "$gte": ">=", +} + +SPECIAL_CASED_OPERATORS = { + "$in", + "$nin", + "$between", +} + +TEXT_OPERATORS = { + "$like", + "$ilike", +} + +LOGICAL_OPERATORS = {"$and", "$or"} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE) + .union(TEXT_OPERATORS) + .union(LOGICAL_OPERATORS) + .union(SPECIAL_CASED_OPERATORS) +) + class SearchType(str, enum.Enum): """Enumerator of the Distance strategies.""" @@ -133,6 +162,240 @@ def dict_to_yaml_str(input_dict: Dict, indent: int = 0) -> str: return yaml_str +def combine_queries( + input_queries: List[Tuple[str, Dict[str, Any]]], operator: str +) -> Tuple[str, Dict[str, Any]]: + # Initialize variables to hold the combined query and parameters + combined_query: str = "" + combined_params: Dict = {} + param_counter: Dict = {} + + for query, params in input_queries: + # Process each query fragment and its parameters + new_query = query + for param, value in params.items(): + # Update the parameter name to ensure uniqueness + if param in param_counter: + param_counter[param] += 1 + else: + param_counter[param] = 1 + new_param_name = f"{param}_{param_counter[param]}" + + # Replace the parameter in the query fragment + new_query = new_query.replace(f"${param}", f"${new_param_name}") + # Add the parameter to the combined parameters dictionary + combined_params[new_param_name] = value + + # Combine the query fragments with an AND operator + if combined_query: + combined_query += f" {operator} " + combined_query += f"({new_query})" + + return combined_query, combined_params + + +def collect_params( + input_data: List[Tuple[str, Dict[str, str]]], +) -> Tuple[List[str], Dict[str, Any]]: + """ + Transform the input data into the desired format. + + Args: + - input_data (list of tuples): Input data to transform. + Each tuple contains a string and a dictionary. + + Returns: + - tuple: A tuple containing a list of strings and a dictionary. + """ + # Initialize variables to hold the output parts + query_parts = [] + params = {} + + # Loop through each item in the input data + for query_part, param in input_data: + # Append the query part to the list + query_parts.append(query_part) + # Update the params dictionary with the param dictionary + params.update(param) + + # Return the transformed data + return (query_parts, params) + + +def _handle_field_filter( + field: str, value: Any, param_number: int = 1 +) -> Tuple[str, Dict]: + """Create a filter for a specific field. + + Args: + field: name of field + value: value to filter + If provided as is then this will be an equality filter + If provided as a dictionary then this will be a filter, the key + will be the operator and the value will be the value to filter by + param_number: sequence number of parameters used to map between param + dict and Cypher snippet + + Returns a tuple of + - Cypher filter snippet + - Dictionary with parameters used in filter snippet + """ + if not isinstance(field, str): + raise ValueError( + f"field should be a string but got: {type(field)} with value: {field}" + ) + + if field.startswith("$"): + raise ValueError( + f"Invalid filter condition. Expected a field but got an operator: " + f"{field}" + ) + + # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters + if not field.isidentifier(): + raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") + + if isinstance(value, dict): + # This is a filter specification + if len(value) != 1: + raise ValueError( + "Invalid filter condition. Expected a value which " + "is a dictionary with a single key that corresponds to an operator " + f"but got a dictionary with {len(value)} keys. The first few " + f"keys are: {list(value.keys())[:3]}" + ) + operator, filter_value = list(value.items())[0] + # Verify that that operator is an operator + if operator not in SUPPORTED_OPERATORS: + raise ValueError( + f"Invalid operator: {operator}. " + f"Expected one of {SUPPORTED_OPERATORS}" + ) + else: # Then we assume an equality operator + operator = "$eq" + filter_value = value + + if operator in COMPARISONS_TO_NATIVE: + # Then we implement an equality filter + # native is trusted input + native = COMPARISONS_TO_NATIVE[operator] + query_snippet = f"n.`{field}` {native} $param_{param_number}" + query_param = {f"param_{param_number}": filter_value} + return (query_snippet, query_param) + elif operator == "$between": + low, high = filter_value + query_snippet = ( + f"$param_{param_number}_low <= n.`{field}` <= $param_{param_number}_high" + ) + query_param = { + f"param_{param_number}_low": low, + f"param_{param_number}_high": high, + } + return (query_snippet, query_param) + + elif operator in {"$in", "$nin", "$like", "$ilike"}: + # We'll do force coercion to text + if operator in {"$in", "$nin"}: + for val in filter_value: + if not isinstance(val, (str, int, float)): + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + if operator in {"$in"}: + query_snippet = f"n.`{field}` IN $param_{param_number}" + query_param = {f"param_{param_number}": filter_value} + return (query_snippet, query_param) + elif operator in {"$nin"}: + query_snippet = f"n.`{field}` NOT IN $param_{param_number}" + query_param = {f"param_{param_number}": filter_value} + return (query_snippet, query_param) + elif operator in {"$like"}: + query_snippet = f"n.`{field}` CONTAINS $param_{param_number}" + query_param = {f"param_{param_number}": filter_value.rstrip("%")} + return (query_snippet, query_param) + elif operator in {"$ilike"}: + query_snippet = f"toLower(n.`{field}`) CONTAINS $param_{param_number}" + query_param = {f"param_{param_number}": filter_value.rstrip("%")} + return (query_snippet, query_param) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + +def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]: + if isinstance(filter, dict): + if len(filter) == 1: + # The only operators allowed at the top level are $AND and $OR + # First check if an operator or a field + key, value = list(filter.items())[0] + if key.startswith("$"): + # Then it's an operator + if key.lower() not in ["$and", "$or"]: + raise ValueError( + f"Invalid filter condition. Expected $and or $or " + f"but got: {key}" + ) + else: + # Then it's a field + return _handle_field_filter(key, filter[key]) + + # Here we handle the $and and $or operators + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + if key.lower() == "$and": + and_ = combine_queries( + [construct_metadata_filter(el) for el in value], "AND" + ) + if len(and_) >= 1: + return and_ + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + elif key.lower() == "$or": + or_ = combine_queries( + [construct_metadata_filter(el) for el in value], "OR" + ) + if len(or_) >= 1: + return or_ + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + else: + raise ValueError( + f"Invalid filter condition. Expected $and or $or " f"but got: {key}" + ) + elif len(filter) > 1: + # Then all keys have to be fields (they cannot be operators) + for key in filter.keys(): + if key.startswith("$"): + raise ValueError( + f"Invalid filter condition. Expected a field but got: {key}" + ) + # These should all be fields and combined using an $and operator + and_multiple = collect_params( + [ + _handle_field_filter(k, v, index) + for index, (k, v) in enumerate(filter.items()) + ] + ) + if len(and_multiple) >= 1: + return " AND ".join(and_multiple[0]), and_multiple[1] + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + else: + raise ValueError("Got an empty dictionary for filters.") + + class Neo4jVector(VectorStore): """`Neo4j` vector index. @@ -243,6 +506,7 @@ class Neo4jVector(VectorStore): ) # Verify if the version support vector index + self._is_enterprise = False self.verify_version() # Verify that required values are not null @@ -318,7 +582,8 @@ class Neo4jVector(VectorStore): indexing. Raises a ValueError if the connected Neo4j version is not supported. """ - version = self.query("CALL dbms.components()")[0]["versions"][0] + db_data = self.query("CALL dbms.components()") + version = db_data[0]["versions"][0] if "aura" in version: version_tuple = tuple(map(int, version.split("-")[0].split("."))) + (0,) else: @@ -331,6 +596,15 @@ class Neo4jVector(VectorStore): "Version index is only supported in Neo4j version 5.11 or greater" ) + # Flag for metadata filtering + metadata_target_version = (5, 18, 0) + if version_tuple < metadata_target_version: + self.support_metadata_filter = False + else: + self.support_metadata_filter = True + # Flag for enterprise + self._is_enterprise = True if db_data[0]["edition"] == "enterprise" else False + def retrieve_existing_index(self) -> Optional[int]: """ Check if the vector index exists in the Neo4j database @@ -583,6 +857,7 @@ class Neo4jVector(VectorStore): query: str, k: int = 4, params: Dict[str, Any] = {}, + filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """Run similarity search with Neo4jVector. @@ -596,7 +871,12 @@ class Neo4jVector(VectorStore): """ embedding = self.embedding.embed_query(text=query) return self.similarity_search_by_vector( - embedding=embedding, k=k, query=query, params=params, **kwargs + embedding=embedding, + k=k, + query=query, + params=params, + filter=filter, + **kwargs, ) def similarity_search_with_score( @@ -604,6 +884,7 @@ class Neo4jVector(VectorStore): query: str, k: int = 4, params: Dict[str, Any] = {}, + filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -617,7 +898,12 @@ class Neo4jVector(VectorStore): """ embedding = self.embedding.embed_query(query) docs = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, query=query, params=params, **kwargs + embedding=embedding, + k=k, + query=query, + params=params, + filter=filter, + **kwargs, ) return docs @@ -625,6 +911,7 @@ class Neo4jVector(VectorStore): self, embedding: List[float], k: int = 4, + filter: Optional[Dict[str, Any]] = None, params: Dict[str, Any] = {}, **kwargs: Any, ) -> List[Tuple[Document, float]]: @@ -646,6 +933,33 @@ class Neo4jVector(VectorStore): List[Tuple[Document, float]]: A list of tuples, each containing a Document object and its similarity score. """ + if filter: + # Verify that 5.18 or later is used + if not self.support_metadata_filter: + raise ValueError( + "Metadata filtering is only supported in " + "Neo4j version 5.18 or greater" + ) + # Metadata filtering and hybrid doesn't work + if self.search_type == SearchType.HYBRID: + raise ValueError( + "Metadata filtering can't be use in combination with " + "a hybrid search approach" + ) + parallel_query = "CYPHER runtime = parallel " if self._is_enterprise else "" + base_index_query = parallel_query + f"MATCH (n:`{self.node_label}`) WHERE " + base_cosine_query = ( + " WITH n as node, vector.similarity.cosine(" + f"n.`{self.embedding_node_property}`, " + "$embedding) AS score ORDER BY score DESC LIMIT toInteger($k) " + ) + filter_snippets, filter_params = construct_metadata_filter(filter) + index_query = base_index_query + filter_snippets + base_cosine_query + + else: + index_query = _get_search_index_query(self.search_type) + filter_params = {} + default_retrieval = ( f"RETURN node.`{self.text_node_property}` AS text, score, " f"node {{.*, `{self.text_node_property}`: Null, " @@ -656,7 +970,7 @@ class Neo4jVector(VectorStore): self.retrieval_query if self.retrieval_query else default_retrieval ) - read_query = _get_search_index_query(self.search_type) + retrieval_query + read_query = index_query + retrieval_query parameters = { "index": self.index_name, "k": k, @@ -664,6 +978,7 @@ class Neo4jVector(VectorStore): "keyword_index": self.keyword_index_name, "query": remove_lucene_chars(kwargs["query"]), **params, + **filter_params, } results = self.query(read_query, params=parameters) @@ -688,6 +1003,7 @@ class Neo4jVector(VectorStore): self, embedding: List[float], k: int = 4, + filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. @@ -700,7 +1016,7 @@ class Neo4jVector(VectorStore): List of Documents most similar to the query vector. """ docs_and_scores = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, **kwargs + embedding=embedding, k=k, filter=filter, **kwargs ) return [doc for doc, _ in docs_and_scores] diff --git a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py b/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py index 8ef132b489..de68c59631 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py +++ b/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py @@ -1,6 +1,6 @@ """Test Neo4jVector functionality.""" import os -from typing import List +from typing import Any, Dict, List, cast from langchain_core.documents import Document @@ -11,6 +11,13 @@ from langchain_community.vectorstores.neo4j_vector import ( ) from langchain_community.vectorstores.utils import DistanceStrategy from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings +from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import ( + DOCUMENTS, + TYPE_1_FILTERING_TEST_CASES, + TYPE_2_FILTERING_TEST_CASES, + TYPE_3_FILTERING_TEST_CASES, + TYPE_4_FILTERING_TEST_CASES, +) url = os.environ.get("NEO4J_URL", "bolt://localhost:7687") username = os.environ.get("NEO4J_USERNAME", "neo4j") @@ -721,6 +728,8 @@ def test_index_fetching() -> None: index_0_store = fetch_store(index_0_str) assert index_0_store.index_name == index_0_str + drop_vector_indexes(index_1_store) + drop_vector_indexes(index_0_store) def test_retrieval_params() -> None: @@ -741,6 +750,7 @@ def test_retrieval_params() -> None: Document(page_content="test", metadata={"test": "test1"}), Document(page_content="test", metadata={"test": "test1"}), ] + drop_vector_indexes(docsearch) def test_retrieval_dictionary() -> None: @@ -767,3 +777,38 @@ def test_retrieval_dictionary() -> None: ] output = docsearch.similarity_search("Foo", k=1) assert output == expected_output + drop_vector_indexes(docsearch) + + +def test_metadata_filters_type1() -> None: + """Test metadata filters""" + docsearch = Neo4jVector.from_documents( + DOCUMENTS, + embedding=FakeEmbeddings(), + pre_delete_collection=True, + ) + # We don't test type 5, because LIKE has very SQL specific examples + for example in ( + TYPE_1_FILTERING_TEST_CASES + + TYPE_2_FILTERING_TEST_CASES + + TYPE_3_FILTERING_TEST_CASES + + TYPE_4_FILTERING_TEST_CASES + ): + filter_dict = cast(Dict[str, Any], example[0]) + output = docsearch.similarity_search("Foo", filter=filter_dict) + indices = cast(List[int], example[1]) + adjusted_indices = [index - 1 for index in indices] + expected_output = [DOCUMENTS[index] for index in adjusted_indices] + # We don't return id properties from similarity search by default + # Also remove any key where the value is None + for doc in expected_output: + if "id" in doc.metadata: + del doc.metadata["id"] + keys_with_none = [ + key for key, value in doc.metadata.items() if value is None + ] + for key in keys_with_none: + del doc.metadata[key] + + assert output == expected_output + drop_vector_indexes(docsearch)