|
|
|
@ -5,9 +5,12 @@ import logging
|
|
|
|
|
import uuid
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from langchain.docstore.document import Document
|
|
|
|
|
from langchain.embeddings.base import Embeddings
|
|
|
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
import chromadb
|
|
|
|
@ -182,6 +185,69 @@ class Chroma(VectorStore):
|
|
|
|
|
|
|
|
|
|
return _results_to_docs_and_scores(results)
|
|
|
|
|
|
|
|
|
|
def max_marginal_relevance_search_by_vector(
|
|
|
|
|
self,
|
|
|
|
|
embedding: List[float],
|
|
|
|
|
k: int = 4,
|
|
|
|
|
fetch_k: int = 20,
|
|
|
|
|
filter: Optional[Dict[str, str]] = None,
|
|
|
|
|
) -> List[Document]:
|
|
|
|
|
"""Return docs selected using the maximal marginal relevance.
|
|
|
|
|
Maximal marginal relevance optimizes for similarity to query AND diversity
|
|
|
|
|
among selected documents.
|
|
|
|
|
Args:
|
|
|
|
|
embedding: Embedding to look up documents similar to.
|
|
|
|
|
k: Number of Documents to return. Defaults to 4.
|
|
|
|
|
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
|
|
|
|
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
|
|
|
|
Returns:
|
|
|
|
|
List of Documents selected by maximal marginal relevance.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
results = self._collection.query(
|
|
|
|
|
query_embeddings=embedding,
|
|
|
|
|
n_results=fetch_k,
|
|
|
|
|
where=filter,
|
|
|
|
|
include=["metadatas", "documents", "distances", "embeddings"],
|
|
|
|
|
)
|
|
|
|
|
mmr_selected = maximal_marginal_relevance(
|
|
|
|
|
np.array(embedding, dtype=np.float32), results["embeddings"][0], k=k
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
candidates = _results_to_docs(results)
|
|
|
|
|
|
|
|
|
|
selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
|
|
|
|
|
return selected_results
|
|
|
|
|
|
|
|
|
|
def max_marginal_relevance_search(
|
|
|
|
|
self,
|
|
|
|
|
query: str,
|
|
|
|
|
k: int = 4,
|
|
|
|
|
fetch_k: int = 20,
|
|
|
|
|
filter: Optional[Dict[str, str]] = None,
|
|
|
|
|
) -> List[Document]:
|
|
|
|
|
"""Return docs selected using the maximal marginal relevance.
|
|
|
|
|
Maximal marginal relevance optimizes for similarity to query AND diversity
|
|
|
|
|
among selected documents.
|
|
|
|
|
Args:
|
|
|
|
|
query: Text to look up documents similar to.
|
|
|
|
|
k: Number of Documents to return. Defaults to 4.
|
|
|
|
|
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
|
|
|
|
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
|
|
|
|
Returns:
|
|
|
|
|
List of Documents selected by maximal marginal relevance.
|
|
|
|
|
"""
|
|
|
|
|
if self._embedding_function is None:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"For MMR search, you must specify an embedding function on" "creation."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
embedding = self._embedding_function.embed_query(query)
|
|
|
|
|
docs = self.max_marginal_relevance_search_by_vector(
|
|
|
|
|
embedding, k, fetch_k, filter
|
|
|
|
|
)
|
|
|
|
|
return docs
|
|
|
|
|
|
|
|
|
|
def delete_collection(self) -> None:
|
|
|
|
|
"""Delete the collection."""
|
|
|
|
|
self._client.delete_collection(self._collection.name)
|
|
|
|
|