From 6ed16e13b15a370a8717c9b3e0f8dcffdc94ed46 Mon Sep 17 00:00:00 2001 From: Claus Thomasen Date: Fri, 10 Mar 2023 21:40:14 +0100 Subject: [PATCH] Readded similarity_search_by_vector (#1568) I am redoing this PR, as I made a mistake by merging the latest changes into my fork's branch, sorry. This added a bunch of commits to my previous PR. This fixes #1451. --- langchain/vectorstores/chroma.py | 45 ++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index 7a0b5c77..df729467 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -16,6 +16,23 @@ if TYPE_CHECKING: logger = logging.getLogger() +def _results_to_docs(results: Any) -> List[Document]: + return [doc for doc, _ in _results_to_docs_and_scores(results)] + + +def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]: + return [ + # TODO: Chroma can do batch querying, + # we shouldn't hard code to the 1st result + (Document(page_content=result[0], metadata=result[1]), result[2]) + for result in zip( + results["documents"][0], + results["metadatas"][0], + results["distances"][0], + ) + ] + + class Chroma(VectorStore): """Wrapper around ChromaDB embeddings platform. @@ -126,6 +143,22 @@ class Chroma(VectorStore): docs_and_scores = self.similarity_search_with_score(query, k) return [doc for doc, _ in docs_and_scores] + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + Returns: + List of Documents most similar to the query vector. + """ + results = self._collection.query(query_embeddings=embedding, n_results=k) + return _results_to_docs(results) + def similarity_search_with_score( self, query: str, @@ -154,17 +187,7 @@ class Chroma(VectorStore): query_embeddings=[query_embedding], n_results=k, where=filter ) - docs = [ - # TODO: Chroma can do batch querying, - # we shouldn't hard code to the 1st result - (Document(page_content=result[0], metadata=result[1]), result[2]) - for result in zip( - results["documents"][0], - results["metadatas"][0], - results["distances"][0], - ) - ] - return docs + return _results_to_docs_and_scores(results) def delete_collection(self) -> None: """Delete the collection."""