From f0a258555b0b7c128def0bb19bb74fba0fb82185 Mon Sep 17 00:00:00 2001 From: seanaedmiston Date: Thu, 16 Feb 2023 17:50:00 +1100 Subject: [PATCH] Support similarity search by vector (in FAISS) (#961) Alternate implementation to PR #960 Again - only FAISS is implemented. If accepted can add this to other vectorstores or leave as NotImplemented? Suggestions welcome... --- .../combine_docs_examples/vectorstores.ipynb | 20 ++++++ langchain/vectorstores/base.py | 32 +++++++++ langchain/vectorstores/faiss.py | 67 ++++++++++++++++--- .../vectorstores/test_faiss.py | 18 +++++ 4 files changed, 129 insertions(+), 8 deletions(-) diff --git a/docs/modules/utils/combine_docs_examples/vectorstores.ipynb b/docs/modules/utils/combine_docs_examples/vectorstores.ipynb index 1acc581a..04d8073e 100644 --- a/docs/modules/utils/combine_docs_examples/vectorstores.ipynb +++ b/docs/modules/utils/combine_docs_examples/vectorstores.ipynb @@ -297,6 +297,26 @@ "docs_and_scores[0]" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "d5170563", + "metadata": {}, + "source": [ + "It is also possible to do a search for documents similar to a given embedding vector using `similarity_search_by_vector` which accepts an embedding vector as a parameter instead of a string." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7675b0aa", + "metadata": {}, + "outputs": [], + "source": [ + "embedding_vector = embeddings.embed_query(query)\n", + "docs_and_scores = docsearch.similarity_search_by_vector(embedding_vector)" + ] + }, { "cell_type": "markdown", "id": "b386dbb8", diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 9bfddcb1..c7e1a33a 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -31,6 +31,20 @@ class VectorStore(ABC): ) -> List[Document]: """Return docs most similar to query.""" + def similarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query vector. + """ + raise NotImplementedError + def max_marginal_relevance_search( self, query: str, k: int = 4, fetch_k: int = 20 ) -> List[Document]: @@ -49,6 +63,24 @@ class VectorStore(ABC): """ raise NotImplementedError + def max_marginal_relevance_search_by_vector( + self, embedding: List[float], k: int = 4, fetch_k: int = 20 + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + + Returns: + List of Documents selected by maximal marginal relevance. 
+ """ + raise NotImplementedError + @classmethod def from_documents( cls, diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index b3d532d9..ed1eeccd 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -92,8 +92,8 @@ class FAISS(VectorStore): self.index_to_docstore_id.update(index_to_id) return [_id for _, _id, _ in full_info] - def similarity_search_with_score( - self, query: str, k: int = 4 + def similarity_search_with_score_by_vector( + self, embedding: List[float], k: int = 4 ) -> List[Tuple[Document, float]]: """Return docs most similar to query. @@ -104,7 +104,6 @@ class FAISS(VectorStore): Returns: List of Documents most similar to the query and score for each """ - embedding = self.embedding_function(query) scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k) docs = [] for j, i in enumerate(indices[0]): @@ -118,6 +117,37 @@ class FAISS(VectorStore): docs.append((doc, scores[0][j])) return docs + def similarity_search_with_score( + self, query: str, k: int = 4 + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query and score for each + """ + embedding = self.embedding_function(query) + docs = self.similarity_search_with_score_by_vector(embedding, k) + return docs + + def similarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the embedding. + """ + docs_and_scores = self.similarity_search_with_score_by_vector(embedding, k) + return [doc for doc, _ in docs_and_scores] + def similarity_search( self, query: str, k: int = 4, **kwargs: Any ) -> List[Document]: @@ -133,8 +163,8 @@ class FAISS(VectorStore): docs_and_scores = self.similarity_search_with_score(query, k) return [doc for doc, _ in docs_and_scores] - def max_marginal_relevance_search( - self, query: str, k: int = 4, fetch_k: int = 20 + def max_marginal_relevance_search_by_vector( + self, embedding: List[float], k: int = 4, fetch_k: int = 20 ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -142,18 +172,19 @@ class FAISS(VectorStore): among selected documents. Args: - query: Text to look up documents similar to. + embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. fetch_k: Number of Documents to fetch to pass to MMR algorithm. Returns: List of Documents selected by maximal marginal relevance. """ - embedding = self.embedding_function(query) _, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k) # -1 happens when not enough docs are returned. 
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1] - mmr_selected = maximal_marginal_relevance(embedding, embeddings, k=k) + mmr_selected = maximal_marginal_relevance( + np.array([embedding], dtype=np.float32), embeddings, k=k + ) selected_indices = [indices[0][i] for i in mmr_selected] docs = [] for i in selected_indices: @@ -164,6 +195,26 @@ class FAISS(VectorStore): docs.append(doc) return docs + def max_marginal_relevance_search( + self, query: str, k: int = 4, fetch_k: int = 20 + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + + Returns: + List of Documents selected by maximal marginal relevance. + """ + embedding = self.embedding_function(query) + docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k) + return docs + @classmethod def from_texts( cls, diff --git a/tests/integration_tests/vectorstores/test_faiss.py b/tests/integration_tests/vectorstores/test_faiss.py index 3ee396f3..4cfe18ed 100644 --- a/tests/integration_tests/vectorstores/test_faiss.py +++ b/tests/integration_tests/vectorstores/test_faiss.py @@ -27,6 +27,24 @@ def test_faiss() -> None: assert output == [Document(page_content="foo")] +def test_faiss_vector_sim() -> None: + """Test vector similarity.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + index_to_id = docsearch.index_to_docstore_id + expected_docstore = InMemoryDocstore( + { + index_to_id[0]: Document(page_content="foo"), + index_to_id[1]: Document(page_content="bar"), + index_to_id[2]: Document(page_content="baz"), + } + ) + assert docsearch.docstore.__dict__ == expected_docstore.__dict__ + query_vec = FakeEmbeddings().embed_query(text="foo") + output = docsearch.similarity_search_by_vector(query_vec, k=1) + assert output == [Document(page_content="foo")] + + def test_faiss_with_metadatas() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"]
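For reviewers, here is a minimal, self-contained usage sketch of the `similarity_search_by_vector` entry point added by this patch. It mirrors the new `test_faiss_vector_sim` test; the `ToyEmbeddings` class below is a hypothetical stand-in for a real `Embeddings` implementation (the test suite uses its own `FakeEmbeddings` helper), and the texts and vector size are illustrative only.

```python
# Illustrative sketch only: assumes this patch is applied and the `faiss`
# package is installed. ToyEmbeddings is a hypothetical stand-in for a real
# Embeddings implementation such as OpenAIEmbeddings.
from typing import List

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.faiss import FAISS


class ToyEmbeddings(Embeddings):
    """Deterministic toy embeddings: the i-th text maps to the i-th basis vector."""

    dim = 8

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(i == j) for j in range(self.dim)] for i in range(len(texts))]

    def embed_query(self, text: str) -> List[float]:
        # Reuse the vector of the first document ("foo" below).
        return [float(j == 0) for j in range(self.dim)]


docsearch = FAISS.from_texts(["foo", "bar", "baz"], ToyEmbeddings())

# Search with a precomputed embedding instead of a query string; no call to
# the embedding model happens inside similarity_search_by_vector.
query_vector = ToyEmbeddings().embed_query("foo")
docs = docsearch.similarity_search_by_vector(query_vector, k=1)
assert docs == [Document(page_content="foo")]
```

Because `similarity_search` now just embeds the query string and delegates to `similarity_search_with_score_by_vector`, callers that already hold an embedding (for example from a cache or a batch job) skip one round trip to the embedding model.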
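The MMR variant works the same way. A short sketch, again with a hypothetical toy embedding class and arbitrary texts rather than anything from the actual test suite:

```python
# Illustrative sketch only: SimpleEmbeddings is a hypothetical stand-in for a
# real Embeddings implementation; the vectors are arbitrary toy data.
from typing import List

from langchain.embeddings.base import Embeddings
from langchain.vectorstores.faiss import FAISS


class SimpleEmbeddings(Embeddings):
    """Toy embeddings: the i-th text maps to a constant vector of value i + 1."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(i + 1)] * 4 for i in range(len(texts))]

    def embed_query(self, text: str) -> List[float]:
        return [1.0] * 4


store = FAISS.from_texts(["alpha", "beta", "gamma", "delta"], SimpleEmbeddings())

# Fetch 4 candidates by raw vector, then keep the 2 most diverse among them.
vector = SimpleEmbeddings().embed_query("alpha")
diverse_docs = store.max_marginal_relevance_search_by_vector(vector, k=2, fetch_k=4)
print([doc.page_content for doc in diverse_docs])
```

As with the plain similarity search, `max_marginal_relevance_search` is now a thin wrapper that embeds the query string and calls this by-vector method, which should make it straightforward to port the same pattern to other vectorstores if this approach is accepted.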