From 7c73e9df5d1e24f5b7d34c4de7bebdaa4bfa962a Mon Sep 17 00:00:00 2001 From: dev2049 <130488702+dev2049@users.noreply.github.com> Date: Sat, 15 Apr 2023 10:49:49 -0700 Subject: [PATCH] Add kwargs to VectorStore.maximum_marginal_relevance (#2921) Same as similarity_search, allows child classes to add vector store-specific args (this was technically already happening in couple places but now typing is correct). --- langchain/vectorstores/base.py | 10 +++++----- langchain/vectorstores/chroma.py | 2 ++ langchain/vectorstores/deeplake.py | 4 ++-- langchain/vectorstores/faiss.py | 8 ++++++-- langchain/vectorstores/qdrant.py | 6 +++++- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index f995d3dd..f6a37286 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -118,7 +118,7 @@ class VectorStore(ABC): return await asyncio.get_event_loop().run_in_executor(None, func) def max_marginal_relevance_search( - self, query: str, k: int = 4, fetch_k: int = 20 + self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -136,18 +136,18 @@ class VectorStore(ABC): raise NotImplementedError async def amax_marginal_relevance_search( - self, query: str, k: int = 4, fetch_k: int = 20 + self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any ) -> List[Document]: """Return docs selected using the maximal marginal relevance.""" # This is a temporary workaround to make the similarity search # asynchronous. The proper solution is to make the similarity search # asynchronous in the vector store implementations. - func = partial(self.max_marginal_relevance_search, query, k, fetch_k) + func = partial(self.max_marginal_relevance_search, query, k, fetch_k, **kwargs) return await asyncio.get_event_loop().run_in_executor(None, func) def max_marginal_relevance_search_by_vector( - self, embedding: List[float], k: int = 4, fetch_k: int = 20 + self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -165,7 +165,7 @@ class VectorStore(ABC): raise NotImplementedError async def amax_marginal_relevance_search_by_vector( - self, embedding: List[float], k: int = 4, fetch_k: int = 20 + self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any ) -> List[Document]: """Return docs selected using the maximal marginal relevance.""" raise NotImplementedError diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index 60ecf3de..64d34efb 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -193,6 +193,7 @@ class Chroma(VectorStore): k: int = 4, fetch_k: int = 20, filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity @@ -227,6 +228,7 @@ class Chroma(VectorStore): k: int = 4, fetch_k: int = 20, filter: Optional[Dict[str, str]] = None, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity diff --git a/langchain/vectorstores/deeplake.py b/langchain/vectorstores/deeplake.py index 8e2e3e5e..71d51cf4 100644 --- a/langchain/vectorstores/deeplake.py +++ b/langchain/vectorstores/deeplake.py @@ -391,7 +391,7 @@ class DeepLake(VectorStore): ) def max_marginal_relevance_search_by_vector( - self, embedding: List[float], k: int = 4, fetch_k: int = 20 + self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any ) -> List[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity @@ -411,7 +411,7 @@ class DeepLake(VectorStore): ) def max_marginal_relevance_search( - self, query: str, k: int = 4, fetch_k: int = 20 + self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any ) -> List[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index 4157fa9d..4727a555 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -208,7 +208,7 @@ class FAISS(VectorStore): return [doc for doc, _ in docs_and_scores] def max_marginal_relevance_search_by_vector( - self, embedding: List[float], k: int = 4, fetch_k: int = 20 + self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -243,7 +243,11 @@ class FAISS(VectorStore): return docs def max_marginal_relevance_search( - self, query: str, k: int = 4, fetch_k: int = 20 + self, + query: str, + k: int = 4, + fetch_k: int = 20, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index 1526379e..b46f38cd 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -147,7 +147,11 @@ class Qdrant(VectorStore): ] def max_marginal_relevance_search( - self, query: str, k: int = 4, fetch_k: int = 20 + self, + query: str, + k: int = 4, + fetch_k: int = 20, + **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance.