Add mmr to neo4j vector (#25765)

This commit is contained in:
Tomaz Bratanic 2024-08-27 14:55:19 +02:00 committed by GitHub
parent 995305fdd5
commit f359e6b0a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 148 additions and 12 deletions

View File

@ -15,13 +15,17 @@ from typing import (
Type,
)
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore
from langchain_community.graphs import Neo4jGraph
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.vectorstores.utils import (
DistanceStrategy,
maximal_marginal_relevance,
)
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
DISTANCE_MAPPING = {
@ -1042,17 +1046,35 @@ class Neo4jVector(VectorStore):
filter_params = {}
if self._index_type == IndexType.RELATIONSHIP:
default_retrieval = (
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
f"relationship {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
)
if kwargs.get("return_embeddings"):
default_retrieval = (
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
f"relationship {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null, "
f"_embedding_: relationship.`{self.embedding_node_property}`}} "
"AS metadata"
)
else:
default_retrieval = (
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
f"relationship {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
)
else:
default_retrieval = (
f"RETURN node.`{self.text_node_property}` AS text, score, "
f"node {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
)
if kwargs.get("return_embeddings"):
default_retrieval = (
f"RETURN node.`{self.text_node_property}` AS text, score, "
f"node {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null, "
f"_embedding_: node.`{self.embedding_node_property}`}} AS metadata"
)
else:
default_retrieval = (
f"RETURN node.`{self.text_node_property}` AS text, score, "
f"node {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
)
retrieval_query = (
self.retrieval_query if self.retrieval_query else default_retrieval
@ -1083,6 +1105,20 @@ class Neo4jVector(VectorStore):
"Inspect the `retrieval_query` and ensure it doesn't "
"return None for the `text` column"
)
if kwargs.get("return_embeddings") and any(
result["metadata"]["_embedding_"] is None for result in results
):
if not self.retrieval_query:
raise ValueError(
f"Make sure that none of the `{self.embedding_node_property}` "
f"properties on nodes with label `{self.node_label}` "
"are missing or empty"
)
else:
raise ValueError(
"Inspect the `retrieval_query` and ensure it doesn't "
"return None for the `_embedding_` metadata column"
)
docs = [
(
@ -1487,6 +1523,64 @@ class Neo4jVector(VectorStore):
break
return store
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: search query text.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filter: Filter on metadata properties, e.g.
{
"str_property": "foo",
"int_property": 123
}
Returns:
List of Documents selected by maximal marginal relevance.
"""
# Embed the query
query_embedding = self.embedding.embed_query(query)
# Fetch the initial documents
got_docs = self.similarity_search_with_score_by_vector(
embedding=query_embedding,
query=query,
k=fetch_k,
return_embeddings=True,
filter=filter,
**kwargs,
)
# Get the embeddings for the fetched documents
got_embeddings = [doc.metadata["_embedding_"] for doc, _ in got_docs]
# Select documents using maximal marginal relevance
selected_indices = maximal_marginal_relevance(
np.array(query_embedding), got_embeddings, lambda_mult=lambda_mult, k=k
)
selected_docs = [got_docs[i][0] for i in selected_indices]
# Remove embedding values from metadata
for doc in selected_docs:
del doc.metadata["_embedding_"]
return selected_docs
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
The 'correct' relevance function

View File

@ -14,7 +14,10 @@ from langchain_community.vectorstores.neo4j_vector import (
_get_search_index_query,
)
from langchain_community.vectorstores.utils import DistanceStrategy
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fake_embeddings import (
AngularTwoDimensionalEmbeddings,
FakeEmbeddings,
)
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
DOCUMENTS,
TYPE_1_FILTERING_TEST_CASES,
@ -928,6 +931,45 @@ OPTIONS {indexConfig: {
drop_vector_indexes(docsearch)
def test_neo4j_max_marginal_relevance_search() -> None:
"""
Test end to end construction and MMR search.
The embedding function used here ensures `texts` become
the following vectors on a circle (numbered v0 through v3):
______ v2
/ \
/ | v1
v3 | . | query
| / v0
|______/ (N.B. very crude drawing)
With fetch_k==3 and k==2, when query is at (1, ),
one expects that v2 and v0 are returned (in some order).
"""
texts = ["-0.124", "+0.127", "+0.25", "+1.0"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = Neo4jVector.from_texts(
texts,
metadatas=metadatas,
embedding=AngularTwoDimensionalEmbeddings(),
pre_delete_collection=True,
)
expected_set = {
("+0.25", 2),
("-0.124", 0),
}
output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
output_set = {
(mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
}
assert output_set == expected_set
drop_vector_indexes(docsearch)
def test_neo4jvector_passing_graph_object() -> None:
"""Test end to end construction and search with passing graph object."""
graph = Neo4jGraph()