mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
community[minor]: Add metadata filtering support for neo4j vector (#20001)
This commit is contained in:
parent
b52b78478f
commit
df25829f33
@ -28,6 +28,35 @@ DISTANCE_MAPPING = {
|
|||||||
DistanceStrategy.COSINE: "cosine",
|
DistanceStrategy.COSINE: "cosine",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Filter operators that are rendered by substituting a Cypher comparison
# operator between the node property and the query parameter.
COMPARISONS_TO_NATIVE = {
    "$eq": "=",
    "$ne": "<>",
    "$lt": "<",
    "$lte": "<=",
    "$gt": ">",
    "$gte": ">=",
}

# Operators that need a hand-written Cypher snippet instead of a plain
# comparison.
SPECIAL_CASED_OPERATORS = {"$in", "$nin", "$between"}

# Operators that work on text values.
TEXT_OPERATORS = {"$like", "$ilike"}

# Operators that combine several sub-filters.
LOGICAL_OPERATORS = {"$and", "$or"}

# Every operator key that is accepted anywhere in a metadata filter.
SUPPORTED_OPERATORS = (
    set(COMPARISONS_TO_NATIVE)
    | TEXT_OPERATORS
    | LOGICAL_OPERATORS
    | SPECIAL_CASED_OPERATORS
)
|
||||||
|
|
||||||
|
|
||||||
class SearchType(str, enum.Enum):
|
class SearchType(str, enum.Enum):
|
||||||
"""Enumerator of the Distance strategies."""
|
"""Enumerator of the Distance strategies."""
|
||||||
@ -133,6 +162,240 @@ def dict_to_yaml_str(input_dict: Dict, indent: int = 0) -> str:
|
|||||||
return yaml_str
|
return yaml_str
|
||||||
|
|
||||||
|
|
||||||
|
def combine_queries(
    input_queries: List[Tuple[str, Dict[str, Any]]], operator: str
) -> Tuple[str, Dict[str, Any]]:
    """Join Cypher filter fragments into one expression with unique parameters.

    Each fragment comes with its own parameter dictionary, and different
    fragments may reuse the same parameter names. Every parameter is renamed
    to ``<name>_<n>``, where ``n`` counts how many times ``<name>`` has been
    seen so far across all fragments, and the ``$<name>`` placeholders in the
    fragment text are rewritten to match.

    Args:
        input_queries: ``(fragment, params)`` pairs to combine.
        operator: Text placed between fragments (for example ``"AND"`` or
            ``"OR"``).

    Returns:
        A tuple of the combined expression, with every fragment wrapped in
        parentheses, and the merged dictionary of renamed parameters. An
        empty input yields ``("", {})``.
    """
    import re

    combined_params: Dict[str, Any] = {}
    param_counter: Dict[str, int] = {}
    fragments: List[str] = []

    for query, params in input_queries:
        # Old name -> new unique name, for this fragment only.
        renamed: Dict[str, str] = {}
        for param, value in params.items():
            param_counter[param] = param_counter.get(param, 0) + 1
            new_param_name = f"{param}_{param_counter[param]}"
            renamed[param] = new_param_name
            combined_params[new_param_name] = value

        new_query = query
        if renamed:
            # Rewrite all placeholders in ONE pass and only as whole tokens.
            # A plain str.replace of "$param_1" would also hit the prefix of
            # "$param_10", and replacing names one after another could
            # re-rename a placeholder that was already rewritten.
            alternatives = "|".join(
                re.escape(name) for name in sorted(renamed, key=len, reverse=True)
            )
            pattern = r"\$(" + alternatives + r")(?![A-Za-z0-9_])"
            new_query = re.sub(
                pattern, lambda match: "$" + renamed[match.group(1)], query
            )

        fragments.append(f"({new_query})")

    # Join the parenthesised fragments with the requested operator.
    return f" {operator} ".join(fragments), combined_params
|
||||||
|
|
||||||
|
|
||||||
|
def collect_params(
    input_data: List[Tuple[str, Dict[str, Any]]],
) -> Tuple[List[str], Dict[str, Any]]:
    """Split ``(query_part, params)`` pairs into a list and one merged dict.

    Args:
        input_data: Pairs of a query snippet and the parameters it uses.
            Parameter values may be of any type.

    Returns:
        A tuple of the query snippets in input order and a single dictionary
        holding the parameters of all pairs. If two pairs use the same
        parameter name, the value from the later pair wins.
    """
    query_parts: List[str] = []
    params: Dict[str, Any] = {}

    for query_part, param in input_data:
        query_parts.append(query_part)
        params.update(param)

    return (query_parts, params)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_field_filter(
|
||||||
|
field: str, value: Any, param_number: int = 1
|
||||||
|
) -> Tuple[str, Dict]:
|
||||||
|
"""Create a filter for a specific field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: name of field
|
||||||
|
value: value to filter
|
||||||
|
If provided as is then this will be an equality filter
|
||||||
|
If provided as a dictionary then this will be a filter, the key
|
||||||
|
will be the operator and the value will be the value to filter by
|
||||||
|
param_number: sequence number of parameters used to map between param
|
||||||
|
dict and Cypher snippet
|
||||||
|
|
||||||
|
Returns a tuple of
|
||||||
|
- Cypher filter snippet
|
||||||
|
- Dictionary with parameters used in filter snippet
|
||||||
|
"""
|
||||||
|
if not isinstance(field, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"field should be a string but got: {type(field)} with value: {field}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if field.startswith("$"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid filter condition. Expected a field but got an operator: "
|
||||||
|
f"{field}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
|
||||||
|
if not field.isidentifier():
|
||||||
|
raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.")
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# This is a filter specification
|
||||||
|
if len(value) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid filter condition. Expected a value which "
|
||||||
|
"is a dictionary with a single key that corresponds to an operator "
|
||||||
|
f"but got a dictionary with {len(value)} keys. The first few "
|
||||||
|
f"keys are: {list(value.keys())[:3]}"
|
||||||
|
)
|
||||||
|
operator, filter_value = list(value.items())[0]
|
||||||
|
# Verify that that operator is an operator
|
||||||
|
if operator not in SUPPORTED_OPERATORS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid operator: {operator}. "
|
||||||
|
f"Expected one of {SUPPORTED_OPERATORS}"
|
||||||
|
)
|
||||||
|
else: # Then we assume an equality operator
|
||||||
|
operator = "$eq"
|
||||||
|
filter_value = value
|
||||||
|
|
||||||
|
if operator in COMPARISONS_TO_NATIVE:
|
||||||
|
# Then we implement an equality filter
|
||||||
|
# native is trusted input
|
||||||
|
native = COMPARISONS_TO_NATIVE[operator]
|
||||||
|
query_snippet = f"n.`{field}` {native} $param_{param_number}"
|
||||||
|
query_param = {f"param_{param_number}": filter_value}
|
||||||
|
return (query_snippet, query_param)
|
||||||
|
elif operator == "$between":
|
||||||
|
low, high = filter_value
|
||||||
|
query_snippet = (
|
||||||
|
f"$param_{param_number}_low <= n.`{field}` <= $param_{param_number}_high"
|
||||||
|
)
|
||||||
|
query_param = {
|
||||||
|
f"param_{param_number}_low": low,
|
||||||
|
f"param_{param_number}_high": high,
|
||||||
|
}
|
||||||
|
return (query_snippet, query_param)
|
||||||
|
|
||||||
|
elif operator in {"$in", "$nin", "$like", "$ilike"}:
|
||||||
|
# We'll do force coercion to text
|
||||||
|
if operator in {"$in", "$nin"}:
|
||||||
|
for val in filter_value:
|
||||||
|
if not isinstance(val, (str, int, float)):
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Unsupported type: {type(val)} for value: {val}"
|
||||||
|
)
|
||||||
|
if operator in {"$in"}:
|
||||||
|
query_snippet = f"n.`{field}` IN $param_{param_number}"
|
||||||
|
query_param = {f"param_{param_number}": filter_value}
|
||||||
|
return (query_snippet, query_param)
|
||||||
|
elif operator in {"$nin"}:
|
||||||
|
query_snippet = f"n.`{field}` NOT IN $param_{param_number}"
|
||||||
|
query_param = {f"param_{param_number}": filter_value}
|
||||||
|
return (query_snippet, query_param)
|
||||||
|
elif operator in {"$like"}:
|
||||||
|
query_snippet = f"n.`{field}` CONTAINS $param_{param_number}"
|
||||||
|
query_param = {f"param_{param_number}": filter_value.rstrip("%")}
|
||||||
|
return (query_snippet, query_param)
|
||||||
|
elif operator in {"$ilike"}:
|
||||||
|
query_snippet = f"toLower(n.`{field}`) CONTAINS $param_{param_number}"
|
||||||
|
query_param = {f"param_{param_number}": filter_value.rstrip("%")}
|
||||||
|
return (query_snippet, query_param)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]:
    """Translate a metadata filter dictionary into a Cypher predicate.

    Two shapes are accepted:

    - a dictionary with a single key that is either a field name or one of
      the logical operators ``$and`` / ``$or`` (matched case-insensitively)
      mapped to a non-empty list of nested filter dictionaries;
    - a dictionary with several keys, all of which must be field names; the
      per-field conditions are joined with ``AND``.

    Args:
        filter: The filter specification.

    Returns:
        A tuple of the Cypher predicate text and the dictionary of
        parameters it refers to.

    Raises:
        ValueError: If ``filter`` is not a dictionary, is empty, uses an
            operator other than ``$and``/``$or`` as a key, mixes operators
            with fields, or gives ``$and``/``$or`` something other than a
            non-empty list.
    """
    if not isinstance(filter, dict):
        # Without this guard a non-dict input would silently return None.
        raise ValueError(
            f"Invalid filter condition. Expected a dictionary but got "
            f"{type(filter)} with value: {filter}"
        )
    if not filter:
        raise ValueError("Got an empty dictionary for filters.")

    if len(filter) == 1:
        key, value = next(iter(filter.items()))

        # Anything that is not a `$`-prefixed string is treated as a field;
        # _handle_field_filter rejects non-string field names itself.
        if not (isinstance(key, str) and key.startswith("$")):
            return _handle_field_filter(key, value)

        # The only operators allowed at this level are $and and $or.
        logical = key.lower()
        if logical not in ("$and", "$or"):
            raise ValueError(
                f"Invalid filter condition. Expected $and or $or but got: {key}"
            )
        if not isinstance(value, list):
            raise ValueError(
                f"Expected a list, but got {type(value)} for value: {value}"
            )
        if not value:
            # An empty list would produce an empty predicate and therefore
            # an invalid WHERE clause downstream.
            raise ValueError(
                f"Invalid filter condition. Expected a non-empty list for "
                f"{key} but got an empty list"
            )
        cypher_operator = "AND" if logical == "$and" else "OR"
        return combine_queries(
            [construct_metadata_filter(el) for el in value], cypher_operator
        )

    # Several keys: all of them have to be fields (they cannot be operators).
    for key in filter:
        if isinstance(key, str) and key.startswith("$"):
            raise ValueError(
                f"Invalid filter condition. Expected a field but got: {key}"
            )

    # The enumerate index keeps parameter names distinct across the fields.
    snippets, params = collect_params(
        [
            _handle_field_filter(k, v, index)
            for index, (k, v) in enumerate(filter.items())
        ]
    )
    return " AND ".join(snippets), params
|
||||||
|
|
||||||
|
|
||||||
class Neo4jVector(VectorStore):
|
class Neo4jVector(VectorStore):
|
||||||
"""`Neo4j` vector index.
|
"""`Neo4j` vector index.
|
||||||
|
|
||||||
@ -243,6 +506,7 @@ class Neo4jVector(VectorStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Verify if the version support vector index
|
# Verify if the version support vector index
|
||||||
|
self._is_enterprise = False
|
||||||
self.verify_version()
|
self.verify_version()
|
||||||
|
|
||||||
# Verify that required values are not null
|
# Verify that required values are not null
|
||||||
@ -318,7 +582,8 @@ class Neo4jVector(VectorStore):
|
|||||||
indexing. Raises a ValueError if the connected Neo4j version is
|
indexing. Raises a ValueError if the connected Neo4j version is
|
||||||
not supported.
|
not supported.
|
||||||
"""
|
"""
|
||||||
version = self.query("CALL dbms.components()")[0]["versions"][0]
|
db_data = self.query("CALL dbms.components()")
|
||||||
|
version = db_data[0]["versions"][0]
|
||||||
if "aura" in version:
|
if "aura" in version:
|
||||||
version_tuple = tuple(map(int, version.split("-")[0].split("."))) + (0,)
|
version_tuple = tuple(map(int, version.split("-")[0].split("."))) + (0,)
|
||||||
else:
|
else:
|
||||||
@ -331,6 +596,15 @@ class Neo4jVector(VectorStore):
|
|||||||
"Version index is only supported in Neo4j version 5.11 or greater"
|
"Version index is only supported in Neo4j version 5.11 or greater"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Flag for metadata filtering
|
||||||
|
metadata_target_version = (5, 18, 0)
|
||||||
|
if version_tuple < metadata_target_version:
|
||||||
|
self.support_metadata_filter = False
|
||||||
|
else:
|
||||||
|
self.support_metadata_filter = True
|
||||||
|
# Flag for enterprise
|
||||||
|
self._is_enterprise = True if db_data[0]["edition"] == "enterprise" else False
|
||||||
|
|
||||||
def retrieve_existing_index(self) -> Optional[int]:
|
def retrieve_existing_index(self) -> Optional[int]:
|
||||||
"""
|
"""
|
||||||
Check if the vector index exists in the Neo4j database
|
Check if the vector index exists in the Neo4j database
|
||||||
@ -583,6 +857,7 @@ class Neo4jVector(VectorStore):
|
|||||||
query: str,
|
query: str,
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
params: Dict[str, Any] = {},
|
params: Dict[str, Any] = {},
|
||||||
|
filter: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Run similarity search with Neo4jVector.
|
"""Run similarity search with Neo4jVector.
|
||||||
@ -596,7 +871,12 @@ class Neo4jVector(VectorStore):
|
|||||||
"""
|
"""
|
||||||
embedding = self.embedding.embed_query(text=query)
|
embedding = self.embedding.embed_query(text=query)
|
||||||
return self.similarity_search_by_vector(
|
return self.similarity_search_by_vector(
|
||||||
embedding=embedding, k=k, query=query, params=params, **kwargs
|
embedding=embedding,
|
||||||
|
k=k,
|
||||||
|
query=query,
|
||||||
|
params=params,
|
||||||
|
filter=filter,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def similarity_search_with_score(
|
def similarity_search_with_score(
|
||||||
@ -604,6 +884,7 @@ class Neo4jVector(VectorStore):
|
|||||||
query: str,
|
query: str,
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
params: Dict[str, Any] = {},
|
params: Dict[str, Any] = {},
|
||||||
|
filter: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
"""Return docs most similar to query.
|
"""Return docs most similar to query.
|
||||||
@ -617,7 +898,12 @@ class Neo4jVector(VectorStore):
|
|||||||
"""
|
"""
|
||||||
embedding = self.embedding.embed_query(query)
|
embedding = self.embedding.embed_query(query)
|
||||||
docs = self.similarity_search_with_score_by_vector(
|
docs = self.similarity_search_with_score_by_vector(
|
||||||
embedding=embedding, k=k, query=query, params=params, **kwargs
|
embedding=embedding,
|
||||||
|
k=k,
|
||||||
|
query=query,
|
||||||
|
params=params,
|
||||||
|
filter=filter,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@ -625,6 +911,7 @@ class Neo4jVector(VectorStore):
|
|||||||
self,
|
self,
|
||||||
embedding: List[float],
|
embedding: List[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
|
filter: Optional[Dict[str, Any]] = None,
|
||||||
params: Dict[str, Any] = {},
|
params: Dict[str, Any] = {},
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
@ -646,6 +933,33 @@ class Neo4jVector(VectorStore):
|
|||||||
List[Tuple[Document, float]]: A list of tuples, each containing
|
List[Tuple[Document, float]]: A list of tuples, each containing
|
||||||
a Document object and its similarity score.
|
a Document object and its similarity score.
|
||||||
"""
|
"""
|
||||||
|
if filter:
|
||||||
|
# Verify that 5.18 or later is used
|
||||||
|
if not self.support_metadata_filter:
|
||||||
|
raise ValueError(
|
||||||
|
"Metadata filtering is only supported in "
|
||||||
|
"Neo4j version 5.18 or greater"
|
||||||
|
)
|
||||||
|
# Metadata filtering and hybrid doesn't work
|
||||||
|
if self.search_type == SearchType.HYBRID:
|
||||||
|
raise ValueError(
|
||||||
|
"Metadata filtering can't be use in combination with "
|
||||||
|
"a hybrid search approach"
|
||||||
|
)
|
||||||
|
parallel_query = "CYPHER runtime = parallel " if self._is_enterprise else ""
|
||||||
|
base_index_query = parallel_query + f"MATCH (n:`{self.node_label}`) WHERE "
|
||||||
|
base_cosine_query = (
|
||||||
|
" WITH n as node, vector.similarity.cosine("
|
||||||
|
f"n.`{self.embedding_node_property}`, "
|
||||||
|
"$embedding) AS score ORDER BY score DESC LIMIT toInteger($k) "
|
||||||
|
)
|
||||||
|
filter_snippets, filter_params = construct_metadata_filter(filter)
|
||||||
|
index_query = base_index_query + filter_snippets + base_cosine_query
|
||||||
|
|
||||||
|
else:
|
||||||
|
index_query = _get_search_index_query(self.search_type)
|
||||||
|
filter_params = {}
|
||||||
|
|
||||||
default_retrieval = (
|
default_retrieval = (
|
||||||
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
||||||
f"node {{.*, `{self.text_node_property}`: Null, "
|
f"node {{.*, `{self.text_node_property}`: Null, "
|
||||||
@ -656,7 +970,7 @@ class Neo4jVector(VectorStore):
|
|||||||
self.retrieval_query if self.retrieval_query else default_retrieval
|
self.retrieval_query if self.retrieval_query else default_retrieval
|
||||||
)
|
)
|
||||||
|
|
||||||
read_query = _get_search_index_query(self.search_type) + retrieval_query
|
read_query = index_query + retrieval_query
|
||||||
parameters = {
|
parameters = {
|
||||||
"index": self.index_name,
|
"index": self.index_name,
|
||||||
"k": k,
|
"k": k,
|
||||||
@ -664,6 +978,7 @@ class Neo4jVector(VectorStore):
|
|||||||
"keyword_index": self.keyword_index_name,
|
"keyword_index": self.keyword_index_name,
|
||||||
"query": remove_lucene_chars(kwargs["query"]),
|
"query": remove_lucene_chars(kwargs["query"]),
|
||||||
**params,
|
**params,
|
||||||
|
**filter_params,
|
||||||
}
|
}
|
||||||
|
|
||||||
results = self.query(read_query, params=parameters)
|
results = self.query(read_query, params=parameters)
|
||||||
@ -688,6 +1003,7 @@ class Neo4jVector(VectorStore):
|
|||||||
self,
|
self,
|
||||||
embedding: List[float],
|
embedding: List[float],
|
||||||
k: int = 4,
|
k: int = 4,
|
||||||
|
filter: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return docs most similar to embedding vector.
|
"""Return docs most similar to embedding vector.
|
||||||
@ -700,7 +1016,7 @@ class Neo4jVector(VectorStore):
|
|||||||
List of Documents most similar to the query vector.
|
List of Documents most similar to the query vector.
|
||||||
"""
|
"""
|
||||||
docs_and_scores = self.similarity_search_with_score_by_vector(
|
docs_and_scores = self.similarity_search_with_score_by_vector(
|
||||||
embedding=embedding, k=k, **kwargs
|
embedding=embedding, k=k, filter=filter, **kwargs
|
||||||
)
|
)
|
||||||
return [doc for doc, _ in docs_and_scores]
|
return [doc for doc, _ in docs_and_scores]
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Test Neo4jVector functionality."""
|
"""Test Neo4jVector functionality."""
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import Any, Dict, List, cast
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
@ -11,6 +11,13 @@ from langchain_community.vectorstores.neo4j_vector import (
|
|||||||
)
|
)
|
||||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
|
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
|
||||||
|
DOCUMENTS,
|
||||||
|
TYPE_1_FILTERING_TEST_CASES,
|
||||||
|
TYPE_2_FILTERING_TEST_CASES,
|
||||||
|
TYPE_3_FILTERING_TEST_CASES,
|
||||||
|
TYPE_4_FILTERING_TEST_CASES,
|
||||||
|
)
|
||||||
|
|
||||||
url = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
|
url = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
|
||||||
username = os.environ.get("NEO4J_USERNAME", "neo4j")
|
username = os.environ.get("NEO4J_USERNAME", "neo4j")
|
||||||
@ -721,6 +728,8 @@ def test_index_fetching() -> None:
|
|||||||
|
|
||||||
index_0_store = fetch_store(index_0_str)
|
index_0_store = fetch_store(index_0_str)
|
||||||
assert index_0_store.index_name == index_0_str
|
assert index_0_store.index_name == index_0_str
|
||||||
|
drop_vector_indexes(index_1_store)
|
||||||
|
drop_vector_indexes(index_0_store)
|
||||||
|
|
||||||
|
|
||||||
def test_retrieval_params() -> None:
|
def test_retrieval_params() -> None:
|
||||||
@ -741,6 +750,7 @@ def test_retrieval_params() -> None:
|
|||||||
Document(page_content="test", metadata={"test": "test1"}),
|
Document(page_content="test", metadata={"test": "test1"}),
|
||||||
Document(page_content="test", metadata={"test": "test1"}),
|
Document(page_content="test", metadata={"test": "test1"}),
|
||||||
]
|
]
|
||||||
|
drop_vector_indexes(docsearch)
|
||||||
|
|
||||||
|
|
||||||
def test_retrieval_dictionary() -> None:
|
def test_retrieval_dictionary() -> None:
|
||||||
@ -767,3 +777,38 @@ def test_retrieval_dictionary() -> None:
|
|||||||
]
|
]
|
||||||
output = docsearch.similarity_search("Foo", k=1)
|
output = docsearch.similarity_search("Foo", k=1)
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
drop_vector_indexes(docsearch)
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_filters_type1() -> None:
    """Test metadata filters"""
    docsearch = Neo4jVector.from_documents(
        DOCUMENTS,
        embedding=FakeEmbeddings(),
        pre_delete_collection=True,
    )
    # We don't test type 5, because LIKE has very SQL specific examples
    for example in (
        TYPE_1_FILTERING_TEST_CASES
        + TYPE_2_FILTERING_TEST_CASES
        + TYPE_3_FILTERING_TEST_CASES
        + TYPE_4_FILTERING_TEST_CASES
    ):
        filter_dict = cast(Dict[str, Any], example[0])
        output = docsearch.similarity_search("Foo", filter=filter_dict)
        # The test cases refer to documents by 1-based position.
        indices = cast(List[int], example[1])
        # We don't return id properties from similarity search by default,
        # and keys whose value is None are not expected back either.
        # Build fresh Document objects instead of deleting keys in place:
        # DOCUMENTS is a shared, imported fixture and must not be altered.
        expected_output = [
            Document(
                page_content=DOCUMENTS[index - 1].page_content,
                metadata={
                    key: value
                    for key, value in DOCUMENTS[index - 1].metadata.items()
                    if key != "id" and value is not None
                },
            )
            for index in indices
        ]

        assert output == expected_output
    drop_vector_indexes(docsearch)
|
||||||
|
Loading…
Reference in New Issue
Block a user