Add MMR methods to chroma (#2148)

Hi, I added MMR similar to faais and milvus to chroma. Please let me
know what you think.
This commit is contained in:
Arttii 2023-03-31 05:51:16 +02:00 committed by GitHub
parent fc009f61c8
commit 4e9ee566ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,9 +5,12 @@ import logging
import uuid
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
import chromadb
@ -182,6 +185,69 @@ class Chroma(VectorStore):
return _results_to_docs_and_scores(results)
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
filter: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
Returns:
List of Documents selected by maximal marginal relevance.
"""
results = self._collection.query(
query_embeddings=embedding,
n_results=fetch_k,
where=filter,
include=["metadatas", "documents", "distances", "embeddings"],
)
mmr_selected = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32), results["embeddings"][0], k=k
)
candidates = _results_to_docs(results)
selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
return selected_results
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
filter: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
Returns:
List of Documents selected by maximal marginal relevance.
"""
if self._embedding_function is None:
raise ValueError(
"For MMR search, you must specify an embedding function on" "creation."
)
embedding = self._embedding_function.embed_query(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, filter
)
return docs
def delete_collection(self) -> None:
"""Delete the collection."""
self._client.delete_collection(self._collection.name)