Add kwargs to VectorStore.maximum_marginal_relevance (#2921)

Same as similarity_search, allows child classes to add vector
store-specific args (this was technically already happening in couple
places but now typing is correct).
fix_agent_callbacks
dev2049 1 year ago committed by GitHub
parent b3a5b51728
commit 7c73e9df5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -118,7 +118,7 @@ class VectorStore(ABC):
return await asyncio.get_event_loop().run_in_executor(None, func) return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search( def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20 self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
@ -136,18 +136,18 @@ class VectorStore(ABC):
raise NotImplementedError raise NotImplementedError
async def amax_marginal_relevance_search( async def amax_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20 self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.""" """Return docs selected using the maximal marginal relevance."""
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
func = partial(self.max_marginal_relevance_search, query, k, fetch_k) func = partial(self.max_marginal_relevance_search, query, k, fetch_k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func) return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20 self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
@ -165,7 +165,7 @@ class VectorStore(ABC):
raise NotImplementedError raise NotImplementedError
async def amax_marginal_relevance_search_by_vector( async def amax_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20 self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.""" """Return docs selected using the maximal marginal relevance."""
raise NotImplementedError raise NotImplementedError

@ -193,6 +193,7 @@ class Chroma(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
filter: Optional[Dict[str, str]] = None, filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
@ -227,6 +228,7 @@ class Chroma(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
filter: Optional[Dict[str, str]] = None, filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity

@ -391,7 +391,7 @@ class DeepLake(VectorStore):
) )
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20 self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
@ -411,7 +411,7 @@ class DeepLake(VectorStore):
) )
def max_marginal_relevance_search( def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20 self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity

@ -208,7 +208,7 @@ class FAISS(VectorStore):
return [doc for doc, _ in docs_and_scores] return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20 self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
@ -243,7 +243,11 @@ class FAISS(VectorStore):
return docs return docs
def max_marginal_relevance_search( def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20 self,
query: str,
k: int = 4,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.

@ -147,7 +147,11 @@ class Qdrant(VectorStore):
] ]
def max_marginal_relevance_search( def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20 self,
query: str,
k: int = 4,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.

Loading…
Cancel
Save