From 4e9ee566ef0e1c073fad2f13621a08745f2ff63b Mon Sep 17 00:00:00 2001 From: Arttii Date: Fri, 31 Mar 2023 05:51:16 +0200 Subject: [PATCH] Add MMR methods to chroma (#2148) Hi, I added MMR similar to faais and milvus to chroma. Please let me know what you think. --- langchain/vectorstores/chroma.py | 66 ++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index f9048eec03..e70b497d73 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -5,9 +5,12 @@ import logging import uuid from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +import numpy as np + from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: import chromadb @@ -182,6 +185,69 @@ class Chroma(VectorStore): return _results_to_docs_and_scores(results) + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + filter: Optional[Dict[str, str]] = None, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + Returns: + List of Documents selected by maximal marginal relevance. + """ + + results = self._collection.query( + query_embeddings=embedding, + n_results=fetch_k, + where=filter, + include=["metadatas", "documents", "distances", "embeddings"], + ) + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), results["embeddings"][0], k=k + ) + + candidates = _results_to_docs(results) + + selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] + return selected_results + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + filter: Optional[Dict[str, str]] = None, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + Returns: + List of Documents selected by maximal marginal relevance. + """ + if self._embedding_function is None: + raise ValueError( + "For MMR search, you must specify an embedding function on" "creation." + ) + + embedding = self._embedding_function.embed_query(query) + docs = self.max_marginal_relevance_search_by_vector( + embedding, k, fetch_k, filter + ) + return docs + def delete_collection(self) -> None: """Delete the collection.""" self._client.delete_collection(self._collection.name)