add with score option for max marginal relevance (#6867)

### Adding the functionality to return the scores with retrieved
documents when using the max marginal relevance
- Description: Add the method
`max_marginal_relevance_search_with_score_by_vector` to the FAISS
wrapper. Functionality operates the same as
`similarity_search_with_score_by_vector` except for using the max
marginal relevance retrieval framework like is used in the
`max_marginal_relevance_search_by_vector` method.
  - Dependencies: None
  - Tag maintainer: @rlancemartin @eyurtsev 
  - Twitter handle: @RianDolphin

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
pull/6030/head^2
Rian Dolphin 1 year ago committed by GitHub
parent 398e4cd2dc
commit 2e39ede848
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -37,7 +37,7 @@ def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
else:
import faiss
except ImportError:
raise ValueError(
raise ImportError(
"Could not import faiss python package. "
"Please install it with `pip install faiss` "
"or `pip install faiss-cpu` (depending on Python version)."
@ -321,16 +321,17 @@ class FAISS(VectorStore):
)
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search_by_vector(
def max_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
*,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
) -> List[Tuple[Document, float]]:
"""Return docs and their similarity scores selected using the maximal marginal
relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
@ -345,9 +346,10 @@ class FAISS(VectorStore):
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
List of Documents and similarity scores selected by maximal marginal
relevance and score for each.
"""
_, indices = self.index.search(
scores, indices = self.index.search(
np.array([embedding], dtype=np.float32),
fetch_k if filter is None else fetch_k * 2,
)
@ -373,8 +375,9 @@ class FAISS(VectorStore):
lambda_mult=lambda_mult,
)
selected_indices = [indices[0][i] for i in mmr_selected]
docs = []
for i in selected_indices:
selected_scores = [scores[0][i] for i in mmr_selected]
docs_and_scores = []
for i, score in zip(selected_indices, selected_scores):
if i == -1:
# This happens when not enough docs are returned.
continue
@ -382,8 +385,39 @@ class FAISS(VectorStore):
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append(doc)
return docs
docs_and_scores.append((doc, score))
return docs_and_scores
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch before filtering to
pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
)
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search(
self,
@ -414,8 +448,8 @@ class FAISS(VectorStore):
embedding = self.embedding_function(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding,
k,
fetch_k,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
**kwargs,

@ -46,9 +46,19 @@ def test_faiss_vector_sim() -> None:
output = docsearch.similarity_search_by_vector(query_vec, k=1)
assert output == [Document(page_content="foo")]
def test_faiss_mmr() -> None:
texts = ["foo", "foo", "fou", "foy"]
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
query_vec = FakeEmbeddings().embed_query(text="foo")
# make sure we can have k > docstore size
output = docsearch.max_marginal_relevance_search_by_vector(query_vec, k=10)
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
query_vec, k=10, lambda_mult=0.1
)
assert len(output) == len(texts)
assert output[0][0] == Document(page_content="foo")
assert output[0][1] == 0.0
assert output[1][0] != Document(page_content="foo")
def test_faiss_with_metadatas() -> None:

Loading…
Cancel
Save