|
|
|
@ -28,6 +28,35 @@ DISTANCE_MAPPING = {
|
|
|
|
|
DistanceStrategy.COSINE: "cosine",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Filter operators that translate one-to-one into a Cypher comparison operator.
COMPARISONS_TO_NATIVE = {
    "$eq": "=",
    "$ne": "<>",
    "$lt": "<",
    "$lte": "<=",
    "$gt": ">",
    "$gte": ">=",
}

# Operators that need a dedicated Cypher snippet instead of a plain comparison.
SPECIAL_CASED_OPERATORS = {"$in", "$nin", "$between"}

# Operators that perform text matching.
TEXT_OPERATORS = {"$like", "$ilike"}

# Operators that combine several sub-filters.
LOGICAL_OPERATORS = {"$and", "$or"}

# Every operator name accepted in a filter specification.
SUPPORTED_OPERATORS = (
    set(COMPARISONS_TO_NATIVE)
    | TEXT_OPERATORS
    | LOGICAL_OPERATORS
    | SPECIAL_CASED_OPERATORS
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SearchType(str, enum.Enum):
|
|
|
|
|
"""Enumerator of the Distance strategies."""
|
|
|
|
@ -133,6 +162,240 @@ def dict_to_yaml_str(input_dict: Dict, indent: int = 0) -> str:
|
|
|
|
|
return yaml_str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def combine_queries(
    input_queries: List[Tuple[str, Dict[str, Any]]], operator: str
) -> Tuple[str, Dict[str, Any]]:
    """Join query fragments with ``operator`` and make their parameters unique.

    Every parameter of every fragment is renamed to ``<name>_<n>``, where
    ``n`` counts how many times ``<name>`` has been seen so far across all
    fragments. The ``$<name>`` placeholders inside the fragment are rewritten
    accordingly, each fragment is wrapped in parentheses, and the fragments
    are joined with ``" <operator> "``.

    Args:
        input_queries: ``(query_fragment, parameters)`` pairs.
        operator: keyword placed between fragments (e.g. ``"AND"``/``"OR"``).

    Returns:
        A tuple of the combined query string and the merged parameter
        dictionary keyed by the new, unique parameter names. An empty input
        yields ``("", {})``.

    Example:
        >>> combine_queries([("a = $p", {"p": 1}), ("b = $p", {"p": 2})], "AND")
        ('(a = $p_1) AND (b = $p_2)', {'p_1': 1, 'p_2': 2})
    """
    # Local import: the module's import block is not part of this block.
    import re

    # A placeholder is "$" followed by a complete identifier. Matching the
    # whole token keeps "$param_1" from also rewriting "$param_10", and doing
    # all renames in one pass keeps a freshly renamed placeholder from being
    # renamed again by a later parameter of the same fragment.
    placeholder = re.compile(r"\$(\w+)")

    combined_params: Dict[str, Any] = {}
    param_counter: Dict[str, int] = {}
    wrapped_fragments: List[str] = []

    for query, params in input_queries:
        renames: Dict[str, str] = {}
        for param, value in params.items():
            param_counter[param] = param_counter.get(param, 0) + 1
            new_param_name = f"{param}_{param_counter[param]}"
            renames[param] = new_param_name
            combined_params[new_param_name] = value

        # Placeholders that are not parameters of this fragment are left as is.
        new_query = placeholder.sub(
            lambda match, table=renames: "$"
            + table.get(match.group(1), match.group(1)),
            query,
        )
        wrapped_fragments.append(f"({new_query})")

    # Join the fragments with the requested logical operator.
    return f" {operator} ".join(wrapped_fragments), combined_params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collect_params(
    input_data: List[Tuple[str, Dict[str, str]]],
) -> Tuple[List[str], Dict[str, Any]]:
    """
    Split (query part, parameters) pairs into a list and one merged dictionary.

    Args:
    - input_data (list of tuples): Pairs to split. Each tuple holds a query
      part (string) and the dictionary of parameters it uses.

    Returns:
    - tuple: The query parts in input order, and a single dictionary holding
      the parameters of all pairs. When two pairs share a key, the value of
      the later pair wins.
    """
    snippets: List[str] = []
    merged: Dict[str, Any] = {}

    for snippet, snippet_params in input_data:
        snippets += [snippet]
        merged.update(snippet_params)

    return snippets, merged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _handle_field_filter(
    field: str, value: Any, param_number: int = 1
) -> Tuple[str, Dict]:
    """Create a filter for a specific field.

    Args:
        field: name of field
        value: value to filter
            If provided as is then this will be an equality filter
            If provided as a dictionary then this will be a filter, the key
            will be the operator and the value will be the value to filter by
        param_number: sequence number of parameters used to map between param
            dict and Cypher snippet

    Returns a tuple of
        - Cypher filter snippet
        - Dictionary with parameters used in filter snippet

    Raises:
        ValueError: if ``field`` is not a string, starts with ``$``, or is not
            a valid identifier; if ``value`` is a dictionary that does not
            have exactly one key; or if that key is not a supported operator.
        NotImplementedError: if an ``$in``/``$nin`` element is not a ``str``,
            ``int`` or ``float``, or if the operator is supported in general
            but has no field-level translation here.
    """
    if not isinstance(field, str):
        raise ValueError(
            f"field should be a string but got: {type(field)} with value: {field}"
        )

    if field.startswith("$"):
        raise ValueError(
            f"Invalid filter condition. Expected a field but got an operator: "
            f"{field}"
        )

    # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
    if not field.isidentifier():
        raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.")

    if isinstance(value, dict):
        # This is a filter specification
        if len(value) != 1:
            raise ValueError(
                "Invalid filter condition. Expected a value which "
                "is a dictionary with a single key that corresponds to an operator "
                f"but got a dictionary with {len(value)} keys. The first few "
                f"keys are: {list(value.keys())[:3]}"
            )
        operator, filter_value = next(iter(value.items()))
        # Verify that the key really is an operator
        if operator not in SUPPORTED_OPERATORS:
            raise ValueError(
                f"Invalid operator: {operator}. "
                f"Expected one of {SUPPORTED_OPERATORS}"
            )
    else:  # Then we assume an equality operator
        operator, filter_value = "$eq", value

    param_name = f"param_{param_number}"
    # The field name was validated as an identifier above, so it is safe to
    # interpolate inside backticks; filter values are always passed as
    # parameters, never interpolated into the snippet.
    prop = f"n.`{field}`"

    if operator in COMPARISONS_TO_NATIVE:
        # native is trusted input
        native = COMPARISONS_TO_NATIVE[operator]
        return (f"{prop} {native} ${param_name}", {param_name: filter_value})

    if operator == "$between":
        low, high = filter_value
        return (
            f"${param_name}_low <= {prop} <= ${param_name}_high",
            {f"{param_name}_low": low, f"{param_name}_high": high},
        )

    if operator in {"$in", "$nin"}:
        for val in filter_value:
            if not isinstance(val, (str, int, float)):
                raise NotImplementedError(
                    f"Unsupported type: {type(val)} for value: {val}"
                )
        # NOTE(review): confirm that the infix form `x NOT IN list` is accepted
        # by the targeted Cypher version.
        keyword = "IN" if operator == "$in" else "NOT IN"
        return (f"{prop} {keyword} ${param_name}", {param_name: filter_value})

    if operator in {"$like", "$ilike"}:
        # Only a trailing "%" wildcard is removed; the match is a substring
        # test (CONTAINS), not a full SQL LIKE pattern.
        pattern = filter_value.rstrip("%")
        if operator == "$like":
            return (f"{prop} CONTAINS ${param_name}", {param_name: pattern})
        # Case-insensitive match: lower-case BOTH sides. Lower-casing only the
        # property would make any pattern with an uppercase character
        # impossible to match.
        return (
            f"toLower({prop}) CONTAINS toLower(${param_name})",
            {param_name: pattern},
        )

    # Reached for operators that are in SUPPORTED_OPERATORS but have no
    # field-level translation (the logical operators $and / $or).
    raise NotImplementedError(
        f"Operator {operator} is not supported as a field-level filter"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]:
    """Translate a metadata filter dictionary into a Cypher snippet and parameters.

    Accepted shapes:
        - ``{"field": value_or_operator_dict}``: a single field filter,
          delegated to ``_handle_field_filter``.
        - ``{"f1": ..., "f2": ...}``: several field filters joined with AND.
          Operators are not allowed as keys in this form.
        - ``{"$and": [...]}`` / ``{"$or": [...]}``: a non-empty list of nested
          filters, each translated recursively and merged with
          ``combine_queries``. The operator key is matched case-insensitively.

    Args:
        filter: the filter specification.

    Returns:
        A tuple of the Cypher boolean expression and the dictionary of
        parameters it references.

    Raises:
        ValueError: if ``filter`` is not a dictionary or is empty, if a
            top-level operator is not ``$and``/``$or``, if the operand of a
            logical operator is not a non-empty list, or if an operator
            appears as a key next to other keys. Errors raised by
            ``_handle_field_filter`` propagate unchanged.
    """
    if not isinstance(filter, dict):
        # Without this guard the function would return None and the caller
        # would fail later with an opaque unpacking error.
        raise ValueError(
            f"Expected a dictionary, but got {type(filter)} for value: {filter}"
        )

    if not filter:
        raise ValueError("Got an empty dictionary for filters.")

    if len(filter) > 1:
        # Then all keys have to be fields (they cannot be operators).
        # Non-string keys are passed on and rejected by _handle_field_filter.
        for key in filter:
            if isinstance(key, str) and key.startswith("$"):
                raise ValueError(
                    f"Invalid filter condition. Expected a field but got: {key}"
                )
        # The enumeration index gives each field its own parameter name.
        snippets, params = collect_params(
            [
                _handle_field_filter(k, v, index)
                for index, (k, v) in enumerate(filter.items())
            ]
        )
        return " AND ".join(snippets), params

    # Exactly one key: either a logical operator or a single field.
    key, value = next(iter(filter.items()))
    if not (isinstance(key, str) and key.startswith("$")):
        return _handle_field_filter(key, value)

    # The only operators allowed at the top level are $and and $or
    cypher_keyword = {"$and": "AND", "$or": "OR"}.get(key.lower())
    if cypher_keyword is None:
        raise ValueError(
            f"Invalid filter condition. Expected $and or $or but got: {key}"
        )

    if not isinstance(value, list):
        raise ValueError(f"Expected a list, but got {type(value)} for value: {value}")

    if not value:
        # An empty operand list would produce an empty snippet and therefore
        # an invalid query once it is spliced after WHERE.
        raise ValueError(
            f"Invalid filter condition. Expected a non-empty list for {key} "
            "but got an empty list"
        )

    return combine_queries(
        [construct_metadata_filter(el) for el in value], cypher_keyword
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Neo4jVector(VectorStore):
|
|
|
|
|
"""`Neo4j` vector index.
|
|
|
|
|
|
|
|
|
@ -243,6 +506,7 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Verify if the version support vector index
|
|
|
|
|
self._is_enterprise = False
|
|
|
|
|
self.verify_version()
|
|
|
|
|
|
|
|
|
|
# Verify that required values are not null
|
|
|
|
@ -318,7 +582,8 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
indexing. Raises a ValueError if the connected Neo4j version is
|
|
|
|
|
not supported.
|
|
|
|
|
"""
|
|
|
|
|
version = self.query("CALL dbms.components()")[0]["versions"][0]
|
|
|
|
|
db_data = self.query("CALL dbms.components()")
|
|
|
|
|
version = db_data[0]["versions"][0]
|
|
|
|
|
if "aura" in version:
|
|
|
|
|
version_tuple = tuple(map(int, version.split("-")[0].split("."))) + (0,)
|
|
|
|
|
else:
|
|
|
|
@ -331,6 +596,15 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
"Version index is only supported in Neo4j version 5.11 or greater"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Flag for metadata filtering
|
|
|
|
|
metadata_target_version = (5, 18, 0)
|
|
|
|
|
if version_tuple < metadata_target_version:
|
|
|
|
|
self.support_metadata_filter = False
|
|
|
|
|
else:
|
|
|
|
|
self.support_metadata_filter = True
|
|
|
|
|
# Flag for enterprise
|
|
|
|
|
self._is_enterprise = True if db_data[0]["edition"] == "enterprise" else False
|
|
|
|
|
|
|
|
|
|
def retrieve_existing_index(self) -> Optional[int]:
|
|
|
|
|
"""
|
|
|
|
|
Check if the vector index exists in the Neo4j database
|
|
|
|
@ -583,6 +857,7 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
query: str,
|
|
|
|
|
k: int = 4,
|
|
|
|
|
params: Dict[str, Any] = {},
|
|
|
|
|
filter: Optional[Dict[str, Any]] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> List[Document]:
|
|
|
|
|
"""Run similarity search with Neo4jVector.
|
|
|
|
@ -596,7 +871,12 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
"""
|
|
|
|
|
embedding = self.embedding.embed_query(text=query)
|
|
|
|
|
return self.similarity_search_by_vector(
|
|
|
|
|
embedding=embedding, k=k, query=query, params=params, **kwargs
|
|
|
|
|
embedding=embedding,
|
|
|
|
|
k=k,
|
|
|
|
|
query=query,
|
|
|
|
|
params=params,
|
|
|
|
|
filter=filter,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def similarity_search_with_score(
|
|
|
|
@ -604,6 +884,7 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
query: str,
|
|
|
|
|
k: int = 4,
|
|
|
|
|
params: Dict[str, Any] = {},
|
|
|
|
|
filter: Optional[Dict[str, Any]] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> List[Tuple[Document, float]]:
|
|
|
|
|
"""Return docs most similar to query.
|
|
|
|
@ -617,7 +898,12 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
"""
|
|
|
|
|
embedding = self.embedding.embed_query(query)
|
|
|
|
|
docs = self.similarity_search_with_score_by_vector(
|
|
|
|
|
embedding=embedding, k=k, query=query, params=params, **kwargs
|
|
|
|
|
embedding=embedding,
|
|
|
|
|
k=k,
|
|
|
|
|
query=query,
|
|
|
|
|
params=params,
|
|
|
|
|
filter=filter,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
return docs
|
|
|
|
|
|
|
|
|
@ -625,6 +911,7 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
self,
|
|
|
|
|
embedding: List[float],
|
|
|
|
|
k: int = 4,
|
|
|
|
|
filter: Optional[Dict[str, Any]] = None,
|
|
|
|
|
params: Dict[str, Any] = {},
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> List[Tuple[Document, float]]:
|
|
|
|
@ -646,6 +933,33 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
List[Tuple[Document, float]]: A list of tuples, each containing
|
|
|
|
|
a Document object and its similarity score.
|
|
|
|
|
"""
|
|
|
|
|
if filter:
|
|
|
|
|
# Verify that 5.18 or later is used
|
|
|
|
|
if not self.support_metadata_filter:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Metadata filtering is only supported in "
|
|
|
|
|
"Neo4j version 5.18 or greater"
|
|
|
|
|
)
|
|
|
|
|
# Metadata filtering and hybrid doesn't work
|
|
|
|
|
if self.search_type == SearchType.HYBRID:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Metadata filtering can't be use in combination with "
|
|
|
|
|
"a hybrid search approach"
|
|
|
|
|
)
|
|
|
|
|
parallel_query = "CYPHER runtime = parallel " if self._is_enterprise else ""
|
|
|
|
|
base_index_query = parallel_query + f"MATCH (n:`{self.node_label}`) WHERE "
|
|
|
|
|
base_cosine_query = (
|
|
|
|
|
" WITH n as node, vector.similarity.cosine("
|
|
|
|
|
f"n.`{self.embedding_node_property}`, "
|
|
|
|
|
"$embedding) AS score ORDER BY score DESC LIMIT toInteger($k) "
|
|
|
|
|
)
|
|
|
|
|
filter_snippets, filter_params = construct_metadata_filter(filter)
|
|
|
|
|
index_query = base_index_query + filter_snippets + base_cosine_query
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
index_query = _get_search_index_query(self.search_type)
|
|
|
|
|
filter_params = {}
|
|
|
|
|
|
|
|
|
|
default_retrieval = (
|
|
|
|
|
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
|
|
|
|
f"node {{.*, `{self.text_node_property}`: Null, "
|
|
|
|
@ -656,7 +970,7 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
self.retrieval_query if self.retrieval_query else default_retrieval
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
read_query = _get_search_index_query(self.search_type) + retrieval_query
|
|
|
|
|
read_query = index_query + retrieval_query
|
|
|
|
|
parameters = {
|
|
|
|
|
"index": self.index_name,
|
|
|
|
|
"k": k,
|
|
|
|
@ -664,6 +978,7 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
"keyword_index": self.keyword_index_name,
|
|
|
|
|
"query": remove_lucene_chars(kwargs["query"]),
|
|
|
|
|
**params,
|
|
|
|
|
**filter_params,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
results = self.query(read_query, params=parameters)
|
|
|
|
@ -688,6 +1003,7 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
self,
|
|
|
|
|
embedding: List[float],
|
|
|
|
|
k: int = 4,
|
|
|
|
|
filter: Optional[Dict[str, Any]] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> List[Document]:
|
|
|
|
|
"""Return docs most similar to embedding vector.
|
|
|
|
@ -700,7 +1016,7 @@ class Neo4jVector(VectorStore):
|
|
|
|
|
List of Documents most similar to the query vector.
|
|
|
|
|
"""
|
|
|
|
|
docs_and_scores = self.similarity_search_with_score_by_vector(
|
|
|
|
|
embedding=embedding, k=k, **kwargs
|
|
|
|
|
embedding=embedding, k=k, filter=filter, **kwargs
|
|
|
|
|
)
|
|
|
|
|
return [doc for doc, _ in docs_and_scores]
|
|
|
|
|
|
|
|
|
|