From 61e4a1adf96064c2096c0396b64be5b1adb65e75 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 17 Jun 2023 11:00:47 -0700 Subject: [PATCH] Harrison/faiss score (#6341) Co-authored-by: Frank Stein <16441059+simonfromla@users.noreply.github.com> Co-authored-by: Sims Juju --- langchain/retrievers/contextual_compression.py | 16 ++++++++++++---- langchain/vectorstores/base.py | 4 ++-- langchain/vectorstores/faiss.py | 12 ++++++++++++ 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/langchain/retrievers/contextual_compression.py b/langchain/retrievers/contextual_compression.py index 788a391981..8850991d2a 100644 --- a/langchain/retrievers/contextual_compression.py +++ b/langchain/retrievers/contextual_compression.py @@ -34,8 +34,11 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel): Sequence of relevant documents """ docs = self.base_retriever.get_relevant_documents(query) - compressed_docs = self.base_compressor.compress_documents(docs, query) - return list(compressed_docs) + if docs: + compressed_docs = self.base_compressor.compress_documents(docs, query) + return list(compressed_docs) + else: + return [] async def aget_relevant_documents(self, query: str) -> List[Document]: """Get documents relevant for a query. @@ -47,5 +50,10 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel): List of relevant documents """ docs = await self.base_retriever.aget_relevant_documents(query) - compressed_docs = await self.base_compressor.acompress_documents(docs, query) - return list(compressed_docs) + if docs: + compressed_docs = await self.base_compressor.acompress_documents( + docs, query + ) + return list(compressed_docs) + else: + return [] diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 7a708596a7..6bc5daa03a 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -159,8 +159,8 @@ class VectorStore(ABC): ] if len(docs_and_similarities) == 0: warnings.warn( - f"No relevant docs were retrieved using the relevance score\ - threshold {score_threshold}" + "No relevant docs were retrieved using the relevance score" + f" threshold {score_threshold}" ) return docs_and_similarities diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index dc83326cf4..6dd7fd544e 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -185,6 +185,7 @@ class FAISS(VectorStore): k: int = 4, filter: Optional[Dict[str, Any]] = None, fetch_k: int = 20, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -194,6 +195,9 @@ class FAISS(VectorStore): filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. fetch_k: (Optional[int]) Number of Documents to fetch before filtering. Defaults to 20. + **kwargs: kwargs to be passed to similarity search. Can include: + score_threshold: Optional, a floating point value between 0 to 1 to + filter the resulting set of retrieved docs Returns: List of documents most similar to the query text and L2 distance @@ -218,6 +222,14 @@ class FAISS(VectorStore): docs.append((doc, scores[0][j])) else: docs.append((doc, scores[0][j])) + + score_threshold = kwargs.get("score_threshold") + if score_threshold is not None: + docs = [ + (doc, similarity) + for doc, similarity in docs + if similarity >= score_threshold + ] return docs[:k] def similarity_search_with_score(