@ -14,10 +14,12 @@ from typing import (
Union ,
)
import numpy as np
from langchain . docstore . document import Document
from langchain . schema . embeddings import Embeddings
from langchain . schema . vectorstore import VectorStore
from langchain . vectorstores . utils import DistanceStrategy
from langchain . vectorstores . utils import DistanceStrategy , maximal_marginal_relevance
if TYPE_CHECKING :
from elasticsearch import Elasticsearch
@ -603,6 +605,67 @@ class ElasticsearchStore(VectorStore):
results = self . _search ( query = query , k = k , filter = filter , * * kwargs )
return [ doc for doc , _ in results ]
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    fields: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query (str): Text to look up documents similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult (float): Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        fields: Other fields to get from elasticsearch source. These fields
            will be added to the document metadata. The list passed in is
            never modified. If it does not already contain the store's
            vector field, that field is fetched internally and stripped
            from the returned documents' metadata.
        **kwargs: Forwarded unchanged to the underlying ``_search`` call.

    Returns:
        List[Document]: A list of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If the store was created without an embedding function.
    """
    if self.embedding is None:
        raise ValueError("You must provide an embedding function to perform MMR")

    vector_field = self.vector_query_field

    # The stored vectors are needed to run MMR. If the caller did not ask for
    # them explicitly, they are fetched only for the computation and removed
    # from the metadata again before returning.
    caller_requested_vector = fields is not None and vector_field in fields

    # Work on a private copy so neither this method nor the downstream search
    # (which may append its own required fields) mutates the caller's list.
    search_fields = list(fields) if fields is not None else []
    if not caller_requested_vector:
        search_fields.append(vector_field)

    # Embed the query
    query_embedding = self.embedding.embed_query(query)

    # Fetch the initial candidate documents (with their stored vectors)
    got_docs = self._search(
        query_vector=query_embedding, k=fetch_k, fields=search_fields, **kwargs
    )

    # Get the embeddings for the fetched documents
    got_embeddings = [doc.metadata[vector_field] for doc, _ in got_docs]

    # Select documents using maximal marginal relevance
    selected_indices = maximal_marginal_relevance(
        np.array(query_embedding), got_embeddings, lambda_mult=lambda_mult, k=k
    )
    selected_docs = [got_docs[i][0] for i in selected_indices]

    if not caller_requested_vector:
        # Remove the vector under the store's *configured* field name; a
        # hard-coded "vector" key breaks stores using a custom field name.
        for doc in selected_docs:
            doc.metadata.pop(vector_field, None)

    return selected_docs
def similarity_search_with_score (
self , query : str , k : int = 4 , filter : Optional [ List [ dict ] ] = None , * * kwargs : Any
) - > List [ Tuple [ Document , float ] ] :
@ -665,7 +728,10 @@ class ElasticsearchStore(VectorStore):
List of Documents most similar to the query and score for each
"""
if fields is None :
fields = [ " metadata " ]
fields = [ ]
if " metadata " not in fields :
fields . append ( " metadata " )
if self . query_field not in fields :
fields . append ( self . query_field )
@ -689,7 +755,6 @@ class ElasticsearchStore(VectorStore):
if custom_query is not None :
query_body = custom_query ( query_body , query )
logger . debug ( f " Calling custom_query, Query body now: { query_body } " )
# Perform the kNN search on the Elasticsearch index and return the results.
response = self . client . search (
index = self . index_name ,
@ -698,18 +763,24 @@ class ElasticsearchStore(VectorStore):
source = fields ,
)
hits = [ hit for hit in response [ " hits " ] [ " hits " ] ]
docs_and_scores = [
(
Document (
page_content = hit [ " _source " ] [ self . query_field ] ,
metadata = hit [ " _source " ] [ " metadata " ] ,
) ,
hit [ " _score " ] ,
docs_and_scores = [ ]
for hit in response [ " hits " ] [ " hits " ] :
for field in fields :
if field in hit [ " _source " ] and field not in [
" metadata " ,
self . query_field ,
] :
hit [ " _source " ] [ " metadata " ] [ field ] = hit [ " _source " ] [ field ]
docs_and_scores . append (
(
Document (
page_content = hit [ " _source " ] [ self . query_field ] ,
metadata = hit [ " _source " ] [ " metadata " ] ,
) ,
hit [ " _score " ] ,
)
)
for hit in hits
]
return docs_and_scores
def delete (