Add MMR functionality to elasticsearch retriever (#11633)

Adds MMR (maximal marginal relevance) search, available only when the
store has access to an embedding function. Also allows users to request
additional fields from the Elasticsearch store; these fields are added to
the document metadata.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/11718/head
sudranga 12 months ago committed by GitHub
parent ead9d5b55c
commit 361f8e1bc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,10 +14,12 @@ from typing import (
Union,
)
import numpy as np
from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from langchain.vectorstores.utils import DistanceStrategy
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
@ -603,6 +605,67 @@ class ElasticsearchStore(VectorStore):
results = self._search(query=query, k=k, filter=filter, **kwargs)
return [doc for doc, _ in results]
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    fields: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query (str): Text to look up documents similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult (float): Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        fields: Other fields to get from elasticsearch source. These fields
            will be added to the document metadata. The list passed in is
            not modified.
        **kwargs: Forwarded unchanged to ``self._search``.

    Returns:
        List[Document]: A list of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If the store has no embedding function.
    """
    if self.embedding is None:
        raise ValueError("You must provide an embedding function to perform MMR")

    vector_field = self.vector_query_field

    # The stored vector has to be fetched so MMR can compare documents, but it
    # is only left in the returned metadata when the caller asked for it.
    caller_requested_vector = fields is not None and vector_field in fields

    # Work on a copy so the caller's list is never mutated.
    search_fields = list(fields) if fields is not None else []
    if vector_field not in search_fields:
        search_fields.append(vector_field)

    # Embed the query
    query_embedding = self.embedding.embed_query(query)

    # Fetch the initial documents
    got_docs = self._search(
        query_vector=query_embedding, k=fetch_k, fields=search_fields, **kwargs
    )

    # Get the embeddings for the fetched documents
    got_embeddings = [doc.metadata[vector_field] for doc, _ in got_docs]

    # Select documents using maximal marginal relevance
    selected_indices = maximal_marginal_relevance(
        np.array(query_embedding), got_embeddings, lambda_mult=lambda_mult, k=k
    )
    selected_docs = [got_docs[i][0] for i in selected_indices]

    if not caller_requested_vector:
        for doc in selected_docs:
            # Remove the configured vector field rather than a hard-coded
            # "vector" key, so stores with a custom field name work too.
            doc.metadata.pop(vector_field, None)

    return selected_docs
def similarity_search_with_score(
self, query: str, k: int = 4, filter: Optional[List[dict]] = None, **kwargs: Any
) -> List[Tuple[Document, float]]:
@ -665,7 +728,10 @@ class ElasticsearchStore(VectorStore):
List of Documents most similar to the query and score for each
"""
if fields is None:
fields = ["metadata"]
fields = []
if "metadata" not in fields:
fields.append("metadata")
if self.query_field not in fields:
fields.append(self.query_field)
@ -689,7 +755,6 @@ class ElasticsearchStore(VectorStore):
if custom_query is not None:
query_body = custom_query(query_body, query)
logger.debug(f"Calling custom_query, Query body now: {query_body}")
# Perform the kNN search on the Elasticsearch index and return the results.
response = self.client.search(
index=self.index_name,
@ -698,18 +763,24 @@ class ElasticsearchStore(VectorStore):
source=fields,
)
hits = [hit for hit in response["hits"]["hits"]]
docs_and_scores = [
(
Document(
page_content=hit["_source"][self.query_field],
metadata=hit["_source"]["metadata"],
),
hit["_score"],
docs_and_scores = []
for hit in response["hits"]["hits"]:
for field in fields:
if field in hit["_source"] and field not in [
"metadata",
self.query_field,
]:
hit["_source"]["metadata"][field] = hit["_source"][field]
docs_and_scores.append(
(
Document(
page_content=hit["_source"][self.query_field],
metadata=hit["_source"]["metadata"],
),
hit["_score"],
)
)
for hit in hits
]
return docs_and_scores
def delete(

@ -385,6 +385,39 @@ class TestElasticsearch:
distance_strategy="NOT_A_STRATEGY",
)
def test_max_marginal_relevance_search(
    self, elasticsearch_connection: dict, index_name: str
) -> None:
    """Check MMR search results on a three-text exact-retrieval store."""
    texts = ["foo", "bar", "baz"]
    query = texts[0]
    store = ElasticsearchStore.from_texts(
        texts,
        FakeEmbeddings(),
        **elasticsearch_connection,
        index_name=index_name,
        strategy=ElasticsearchStore.ExactRetrievalStrategy(),
    )

    # When every fetched document is kept, MMR must match plain similarity search.
    all_by_mmr = store.max_marginal_relevance_search(query, k=3, fetch_k=3)
    all_by_similarity = store.similarity_search(query, k=3)
    assert all_by_mmr == all_by_similarity

    # Default lambda_mult: the query text first, then the second text.
    default_pick = store.max_marginal_relevance_search(query, k=2, fetch_k=3)
    assert len(default_pick) == 2
    assert [doc.page_content for doc in default_pick] == [texts[0], texts[1]]

    diverse_pick = store.max_marginal_relevance_search(
        query, k=2, fetch_k=3, lambda_mult=0.1  # more diversity
    )
    assert len(diverse_pick) == 2
    assert [doc.page_content for doc in diverse_pick] == [texts[0], texts[2]]

    # if fetch_k < k, then the output will be less than k
    capped_pick = store.max_marginal_relevance_search(query, k=3, fetch_k=2)
    assert len(capped_pick) == 2
def test_similarity_search_approx_with_hybrid_search(
self, elasticsearch_connection: dict, index_name: str
) -> None:

Loading…
Cancel
Save