From b89c258bc599e2528c11b0aba62c3ac2616e6155 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Mon, 24 Apr 2023 21:48:29 -0700 Subject: [PATCH] Add retry logic for ChromaDB (#3372) Rewrite of #3368 Mainly an issue for when people are just getting started, but still nice to not throw an error if the number of docs is < k. Add a little decorator utility to block mutually exclusive keyword arguments --- langchain/utils.py | 27 +++++++++++++++++++++++- langchain/vectorstores/chroma.py | 35 ++++++++++++++++++++++++++++---- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/langchain/utils.py b/langchain/utils.py index 08fa4327..0daf9c52 100644 --- a/langchain/utils.py +++ b/langchain/utils.py @@ -1,6 +1,6 @@ """Generic utility functions.""" import os -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple def get_from_dict_or_env( @@ -19,3 +19,28 @@ def get_from_dict_or_env( f" `{env_key}` which contains it, or pass" f" `{key}` as a named parameter." ) + + +def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: + """Validate specified keyword args are mutually exclusive.""" + + def decorator(func: Callable) -> Callable: + def wrapper(*args: Any, **kwargs: Any) -> Callable: + """Validate exactly one arg in each group is not None.""" + counts = [ + sum(1 for arg in arg_group if kwargs.get(arg) is not None) + for arg_group in arg_groups + ] + invalid_groups = [i for i, count in enumerate(counts) if count != 1] + if invalid_groups: + invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups] + raise ValueError( + "Exactly one argument in each of the following" + " groups must be defined:" + f" {', '.join(invalid_group_names)}" + ) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index 1068963c..c3d977aa 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -9,6 +9,7 @@ import numpy as np from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings +from langchain.utils import xor_args from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance @@ -96,6 +97,32 @@ class Chroma(VectorStore): metadata=collection_metadata, ) + @xor_args(("query_texts", "query_embeddings")) + def __query_collection( + self, + query_texts: Optional[List[str]] = None, + query_embeddings: Optional[List[List[float]]] = None, + n_results: int = 4, + where: Optional[Dict[str, str]] = None, + ) -> List[Document]: + """Query the chroma collection.""" + for i in range(n_results, 0, -1): + try: + return self._collection.query( + query_texts=query_texts, + query_embeddings=query_embeddings, + n_results=n_results, + where=where, + ) + except chromadb.errors.NotEnoughElementsException: + logger.error( + f"Chroma collection {self._collection.name} " + f"contains fewer than {i} elements." + ) + raise chromadb.errors.NotEnoughElementsException( + f"No documents found for Chroma collection {self._collection.name}" + ) + def add_texts( self, texts: Iterable[str], @@ -158,7 +185,7 @@ class Chroma(VectorStore): Returns: List of Documents most similar to the query vector. """ - results = self._collection.query( + results = self.__query_collection( query_embeddings=embedding, n_results=k, where=filter ) return _results_to_docs(results) @@ -182,12 +209,12 @@ class Chroma(VectorStore): text with distance in float. """ if self._embedding_function is None: - results = self._collection.query( + results = self.__query_collection( query_texts=[query], n_results=k, where=filter ) else: query_embedding = self._embedding_function.embed_query(query) - results = self._collection.query( + results = self.__query_collection( query_embeddings=[query_embedding], n_results=k, where=filter ) @@ -218,7 +245,7 @@ class Chroma(VectorStore): List of Documents selected by maximal marginal relevance. """ - results = self._collection.query( + results = self.__query_collection( query_embeddings=embedding, n_results=fetch_k, where=filter,