From f9edf76e7c5d5a21e7b73517f324708fd62b84c1 Mon Sep 17 00:00:00 2001 From: Wenchen Li <9028430+neo@users.noreply.github.com> Date: Wed, 14 Jun 2023 01:46:45 +0800 Subject: [PATCH] Implement `max_marginal_relevance_search` in `VectorStore` of Pinecone (#6056) This adds implementation of MMR search in pinecone; and I have two semi-related observations about this vector store class: - Maybe we should also have a `similarity_search_by_vector_returning_embeddings` like in supabase, but it's not in the base `VectorStore` class so I didn't implement - Talking about the base class, there's `similarity_search_with_relevance_scores`, but in pinecone it is called `similarity_search_with_score`; maybe we should consider renaming it to align with other `VectorStore` base and sub classes (or add that as an alias for backward compatibility) #### Who can review? Tag maintainers/contributors who might be interested: - VectorStores / Retrievers / Memory - @dev2049 --- .../vectorstores/examples/pinecone.ipynb | 46 ++++++++++- langchain/vectorstores/pinecone.py | 82 +++++++++++++++++++ 2 files changed, 124 insertions(+), 4 deletions(-) diff --git a/docs/modules/indexes/vectorstores/examples/pinecone.ipynb b/docs/modules/indexes/vectorstores/examples/pinecone.ipynb index 104b23c6..c77edf23 100644 --- a/docs/modules/indexes/vectorstores/examples/pinecone.ipynb +++ b/docs/modules/indexes/vectorstores/examples/pinecone.ipynb @@ -24,7 +24,7 @@ }, "outputs": [], "source": [ - "!pip install pinecone-client" + "!pip install pinecone-client openai tiktoken" ] }, { @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "aac9563e", "metadata": { "tags": [] @@ -85,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "a3c3999a", "metadata": {}, "outputs": [], @@ -135,13 +135,51 @@ "print(docs[0].page_content)" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "d46d1452", + "metadata": {}, + "source": [ + "### Maximal Marginal Relevance Searches\n", + "\n", + "In addition to using similarity search in the retriever object, you can also use `mmr` as retriever.\n" + ] + }, { "cell_type": "code", "execution_count": null, "id": "a359ed74", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "retriever = docsearch.as_retriever(search_type=\"mmr\")\n", + "matched_docs = retriever.get_relevant_documents(query)\n", + "for i, d in enumerate(matched_docs):\n", + " print(f\"\\n## Document {i}\\n\")\n", + " print(d.page_content)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "7c477287", + "metadata": {}, + "source": [ + "Or use `max_marginal_relevance_search` directly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ca82740", + "metadata": {}, + "outputs": [], + "source": [ + "found_docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10)\n", + "for i, doc in enumerate(found_docs):\n", + " print(f\"{i + 1}.\", doc.page_content, \"\\n\")" + ] } ], "metadata": { diff --git a/langchain/vectorstores/pinecone.py b/langchain/vectorstores/pinecone.py index f9a6fe9b..48ec6128 100644 --- a/langchain/vectorstores/pinecone.py +++ b/langchain/vectorstores/pinecone.py @@ -5,9 +5,12 @@ import logging import uuid from typing import Any, Callable, Iterable, List, Optional, Tuple +import numpy as np + from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.utils import maximal_marginal_relevance logger = logging.getLogger(__name__) @@ -157,6 +160,85 @@ class Pinecone(VectorStore): ) return [doc for doc, _ in docs_and_scores] + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + namespace: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. + """ + if namespace is None: + namespace = self._namespace + results = self._index.query( + [embedding], + top_k=fetch_k, + include_values=True, + include_metadata=True, + namespace=namespace, + filter=filter, + ) + mmr_selected = maximal_marginal_relevance( + np.array([embedding], dtype=np.float32), + [item["values"] for item in results["matches"]], + k=k, + lambda_mult=lambda_mult, + ) + selected = [results["matches"][i]["metadata"] for i in mmr_selected] + return [ + Document(page_content=metadata.pop((self._text_key)), metadata=metadata) + for metadata in selected + ] + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + namespace: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. + """ + embedding = self._embedding_function(query) + return self.max_marginal_relevance_search_by_vector( + embedding, k, fetch_k, lambda_mult, filter, namespace + ) + @classmethod def from_texts( cls,