Add kwargs to VectorStore.maximum_marginal_relevance (#2921)

Same as similarity_search, allows child classes to add vector
store-specific args (this was technically already happening in couple
places but now typing is correct).
This commit is contained in:
dev2049 2023-04-15 10:49:49 -07:00 committed by GitHub
parent b3a5b51728
commit 7c73e9df5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 10 deletions

View File

@ -118,7 +118,7 @@ class VectorStore(ABC):
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20
self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -136,18 +136,18 @@ class VectorStore(ABC):
raise NotImplementedError
async def amax_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20
self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance."""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.max_marginal_relevance_search, query, k, fetch_k)
func = partial(self.max_marginal_relevance_search, query, k, fetch_k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20
self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -165,7 +165,7 @@ class VectorStore(ABC):
raise NotImplementedError
async def amax_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20
self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance."""
raise NotImplementedError

View File

@ -193,6 +193,7 @@ class Chroma(VectorStore):
k: int = 4,
fetch_k: int = 20,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@ -227,6 +228,7 @@ class Chroma(VectorStore):
k: int = 4,
fetch_k: int = 20,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity

View File

@ -391,7 +391,7 @@ class DeepLake(VectorStore):
)
def max_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20
self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@ -411,7 +411,7 @@ class DeepLake(VectorStore):
)
def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20
self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity

View File

@ -208,7 +208,7 @@ class FAISS(VectorStore):
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20
self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -243,7 +243,11 @@ class FAISS(VectorStore):
return docs
def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20
self,
query: str,
k: int = 4,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.

View File

@ -147,7 +147,11 @@ class Qdrant(VectorStore):
]
def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20
self,
query: str,
k: int = 4,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.