Add maximal relevance search to SKLearnVectorStore (#5430)

# Add maximal relevance search to SKLearnVectorStore

This PR implements the maximum relevance search in SKLearnVectorStore. 

Twitter handle: jtolgyesi (I submitted also the original implementation
of SKLearnVectorStore)

## Before submitting

Unit tests are included.

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
searx_updates
Janos Tolgyesi 1 year ago committed by GitHub
parent 8181f9e362
commit 1111f18eb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,6 +14,10 @@ from uuid import uuid4
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
DEFAULT_K = 4 # Number of Documents to return.
DEFAULT_FETCH_K = 20 # Number of Documents to initially fetch during MMR search.
def guard_import( def guard_import(
@ -223,39 +227,127 @@ class SKLearnVectorStore(VectorStore):
self._neighbors.fit(self._embeddings_np) self._neighbors.fit(self._embeddings_np)
self._neighbors_fitted = True self._neighbors_fitted = True
def similarity_search_with_score( def _similarity_index_search_with_score(
self, query: str, *, k: int = 4, **kwargs: Any self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
) -> List[Tuple[Document, float]]: ) -> List[Tuple[int, float]]:
"""Search k embeddings similar to the query embedding. Returns a list of
(index, distance) tuples."""
if not self._neighbors_fitted: if not self._neighbors_fitted:
raise SKLearnVectorStoreException( raise SKLearnVectorStoreException(
"No data was added to SKLearnVectorStore." "No data was added to SKLearnVectorStore."
) )
query_embedding = self._embedding_function.embed_query(query)
neigh_dists, neigh_idxs = self._neighbors.kneighbors( neigh_dists, neigh_idxs = self._neighbors.kneighbors(
[query_embedding], n_neighbors=k [query_embedding], n_neighbors=k
) )
res = [] return list(zip(neigh_idxs[0], neigh_dists[0]))
for idx, dist in zip(neigh_idxs[0], neigh_dists[0]):
_idx = int(idx) def similarity_search_with_score(
metadata = {"id": self._ids[_idx], **self._metadatas[_idx]} self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
doc = Document(page_content=self._texts[_idx], metadata=metadata) ) -> List[Tuple[Document, float]]:
res.append((doc, dist)) query_embedding = self._embedding_function.embed_query(query)
return res indices_dists = self._similarity_index_search_with_score(
query_embedding, k=k, **kwargs
)
return [
(
Document(
page_content=self._texts[idx],
metadata={"id": self._ids[idx], **self._metadatas[idx]},
),
dist,
)
for idx, dist in indices_dists
]
def similarity_search( def similarity_search(
self, query: str, k: int = 4, **kwargs: Any self, query: str, k: int = DEFAULT_K, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
docs_scores = self.similarity_search_with_score(query, k=k, **kwargs) docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
return [doc for doc, _ in docs_scores] return [doc for doc, _ in docs_scores]
def _similarity_search_with_relevance_scores( def _similarity_search_with_relevance_scores(
self, query: str, k: int = 4, **kwargs: Any self, query: str, k: int = DEFAULT_K, **kwargs: Any
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
docs_dists = self.similarity_search_with_score(query=query, k=k, **kwargs) docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
docs, dists = zip(*docs_dists) docs, dists = zip(*docs_dists)
scores = [1 / math.exp(dist) for dist in dists] scores = [1 / math.exp(dist) for dist in dists]
return list(zip(list(docs), scores)) return list(zip(list(docs), scores))
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = DEFAULT_K,
fetch_k: int = DEFAULT_FETCH_K,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
indices_dists = self._similarity_index_search_with_score(
embedding, k=fetch_k, **kwargs
)
indices, _ = zip(*indices_dists)
result_embeddings = self._embeddings_np[indices,]
mmr_selected = maximal_marginal_relevance(
self._np.array(embedding, dtype=self._np.float32),
result_embeddings,
k=k,
lambda_mult=lambda_mult,
)
mmr_indices = [indices[i] for i in mmr_selected]
return [
Document(
page_content=self._texts[idx],
metadata={"id": self._ids[idx], **self._metadatas[idx]},
)
for idx in mmr_indices
]
def max_marginal_relevance_search(
self,
query: str,
k: int = DEFAULT_K,
fetch_k: int = DEFAULT_FETCH_K,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
if self._embedding_function is None:
raise ValueError(
"For MMR search, you must specify an embedding function on creation."
)
embedding = self._embedding_function.embed_query(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, lambda_mul=lambda_mult
)
return docs
@classmethod @classmethod
def from_texts( def from_texts(
cls, cls,

@ -11,7 +11,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
def test_sklearn() -> None: def test_sklearn() -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
docsearch = SKLearnVectorStore.from_texts(texts, embedding=FakeEmbeddings()) docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
output = docsearch.similarity_search("foo", k=1) output = docsearch.similarity_search("foo", k=1)
assert len(output) == 1 assert len(output) == 1
assert output[0].page_content == "foo" assert output[0].page_content == "foo"
@ -24,7 +24,7 @@ def test_sklearn_with_metadatas() -> None:
metadatas = [{"page": str(i)} for i in range(len(texts))] metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = SKLearnVectorStore.from_texts( docsearch = SKLearnVectorStore.from_texts(
texts, texts,
embedding=FakeEmbeddings(), FakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
) )
output = docsearch.similarity_search("foo", k=1) output = docsearch.similarity_search("foo", k=1)
@ -38,7 +38,7 @@ def test_sklearn_with_metadatas_with_scores() -> None:
metadatas = [{"page": str(i)} for i in range(len(texts))] metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = SKLearnVectorStore.from_texts( docsearch = SKLearnVectorStore.from_texts(
texts, texts,
embedding=FakeEmbeddings(), FakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
) )
output = docsearch.similarity_search_with_relevance_scores("foo", k=1) output = docsearch.similarity_search_with_relevance_scores("foo", k=1)
@ -69,8 +69,32 @@ def test_sklearn_with_persistence(tmpdir: Path) -> None:
# Get a new VectorStore from the persisted directory # Get a new VectorStore from the persisted directory
docsearch = SKLearnVectorStore( docsearch = SKLearnVectorStore(
embedding=FakeEmbeddings(), persist_path=str(persist_path), serializer="json" FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
) )
output = docsearch.similarity_search("foo", k=1) output = docsearch.similarity_search("foo", k=1)
assert len(output) == 1 assert len(output) == 1
assert output[0].page_content == "foo" assert output[0].page_content == "foo"
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_mmr() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
assert len(output) == 1
assert output[0].page_content == "foo"
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_mmr_by_vector() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
embeddings = FakeEmbeddings()
docsearch = SKLearnVectorStore.from_texts(texts, embeddings)
embedded_query = embeddings.embed_query("foo")
output = docsearch.max_marginal_relevance_search_by_vector(
embedded_query, k=1, fetch_k=3
)
assert len(output) == 1
assert output[0].page_content == "foo"

Loading…
Cancel
Save