mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
Add mmr to neo4j vector (#25765)
This commit is contained in:
parent
995305fdd5
commit
f359e6b0a5
@ -15,13 +15,17 @@ from typing import (
|
||||
Type,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
||||
from langchain_community.vectorstores.utils import (
|
||||
DistanceStrategy,
|
||||
maximal_marginal_relevance,
|
||||
)
|
||||
|
||||
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
|
||||
DISTANCE_MAPPING = {
|
||||
@ -1042,17 +1046,35 @@ class Neo4jVector(VectorStore):
|
||||
filter_params = {}
|
||||
|
||||
if self._index_type == IndexType.RELATIONSHIP:
|
||||
default_retrieval = (
|
||||
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
|
||||
f"relationship {{.*, `{self.text_node_property}`: Null, "
|
||||
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
|
||||
)
|
||||
if kwargs.get("return_embeddings"):
|
||||
default_retrieval = (
|
||||
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
|
||||
f"relationship {{.*, `{self.text_node_property}`: Null, "
|
||||
f"`{self.embedding_node_property}`: Null, id: Null, "
|
||||
f"_embedding_: relationship.`{self.embedding_node_property}`}} "
|
||||
"AS metadata"
|
||||
)
|
||||
else:
|
||||
default_retrieval = (
|
||||
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
|
||||
f"relationship {{.*, `{self.text_node_property}`: Null, "
|
||||
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
|
||||
)
|
||||
|
||||
else:
|
||||
default_retrieval = (
|
||||
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
||||
f"node {{.*, `{self.text_node_property}`: Null, "
|
||||
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
|
||||
)
|
||||
if kwargs.get("return_embeddings"):
|
||||
default_retrieval = (
|
||||
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
||||
f"node {{.*, `{self.text_node_property}`: Null, "
|
||||
f"`{self.embedding_node_property}`: Null, id: Null, "
|
||||
f"_embedding_: node.`{self.embedding_node_property}`}} AS metadata"
|
||||
)
|
||||
else:
|
||||
default_retrieval = (
|
||||
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
||||
f"node {{.*, `{self.text_node_property}`: Null, "
|
||||
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
|
||||
)
|
||||
|
||||
retrieval_query = (
|
||||
self.retrieval_query if self.retrieval_query else default_retrieval
|
||||
@ -1083,6 +1105,20 @@ class Neo4jVector(VectorStore):
|
||||
"Inspect the `retrieval_query` and ensure it doesn't "
|
||||
"return None for the `text` column"
|
||||
)
|
||||
if kwargs.get("return_embeddings") and any(
|
||||
result["metadata"]["_embedding_"] is None for result in results
|
||||
):
|
||||
if not self.retrieval_query:
|
||||
raise ValueError(
|
||||
f"Make sure that none of the `{self.embedding_node_property}` "
|
||||
f"properties on nodes with label `{self.node_label}` "
|
||||
"are missing or empty"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Inspect the `retrieval_query` and ensure it doesn't "
|
||||
"return None for the `_embedding_` metadata column"
|
||||
)
|
||||
|
||||
docs = [
|
||||
(
|
||||
@ -1487,6 +1523,64 @@ class Neo4jVector(VectorStore):
|
||||
break
|
||||
return store
|
||||
|
||||
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: search query text.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filter: Filter on metadata properties, e.g.
            {
                "str_property": "foo",
                "int_property": 123
            }

    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    # Embed the query once and reuse the vector for both the initial
    # fetch and the MMR re-ranking below.
    embedded_query = self.embedding.embed_query(query)

    # Over-fetch fetch_k candidates, asking the store to attach each
    # document's raw embedding under the "_embedding_" metadata key.
    candidates = self.similarity_search_with_score_by_vector(
        embedding=embedded_query,
        query=query,
        k=fetch_k,
        return_embeddings=True,
        filter=filter,
        **kwargs,
    )

    # Pull out the candidate embeddings (scores are not needed for MMR).
    candidate_embeddings = [
        document.metadata["_embedding_"] for document, _ in candidates
    ]

    # Re-rank candidates for relevance + diversity and keep the top k.
    mmr_indices = maximal_marginal_relevance(
        np.array(embedded_query),
        candidate_embeddings,
        lambda_mult=lambda_mult,
        k=k,
    )

    # Collect the chosen documents, stripping the transient embedding
    # metadata so callers never see the raw vectors.
    chosen: List[Document] = []
    for index in mmr_indices:
        document = candidates[index][0]
        document.metadata.pop("_embedding_")
        chosen.append(document)

    return chosen
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
The 'correct' relevance function
|
||||
|
@ -14,7 +14,10 @@ from langchain_community.vectorstores.neo4j_vector import (
|
||||
_get_search_index_query,
|
||||
)
|
||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
AngularTwoDimensionalEmbeddings,
|
||||
FakeEmbeddings,
|
||||
)
|
||||
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
|
||||
DOCUMENTS,
|
||||
TYPE_1_FILTERING_TEST_CASES,
|
||||
@ -928,6 +931,45 @@ OPTIONS {indexConfig: {
|
||||
drop_vector_indexes(docsearch)
|
||||
|
||||
|
||||
def test_neo4j_max_marginal_relevance_search() -> None:
    """
    Test end to end construction and MMR search.

    The embedding function used here ensures `texts` become
    the following vectors on a circle (numbered v0 through v3):

           ______ v2
          /      \
         /        |  v1
    v3  |     .   | query
         |        /  v0
          |______/                 (N.B. very crude drawing)

    With fetch_k==3 and k==2, when the query sits at (1, 0),
    one expects that v2 and v0 are returned (in some order): v0 for
    relevance, v2 for diversity.
    """
    # Index four texts whose embeddings land on a circle as drawn above.
    texts = ["-0.124", "+0.127", "+0.25", "+1.0"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = Neo4jVector.from_texts(
        texts,
        metadatas=metadatas,
        embedding=AngularTwoDimensionalEmbeddings(),
        pre_delete_collection=True,
    )

    # MMR should pick the nearest vector (v0) plus a diverse one (v2).
    mmr_docs = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
    returned = {
        (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in mmr_docs
    }
    assert returned == {("+0.25", 2), ("-0.124", 0)}

    drop_vector_indexes(docsearch)
|
||||
|
||||
|
||||
def test_neo4jvector_passing_graph_object() -> None:
|
||||
"""Test end to end construction and search with passing graph object."""
|
||||
graph = Neo4jGraph()
|
||||
|
Loading…
Reference in New Issue
Block a user