Add retry logic for ChromaDB (#3372)

Rewrite of #3368

Mainly an issue for when people are just getting started, but still nice
to not throw an error if the number of docs is < k.

Add a little decorator utility to block mutually exclusive keyword
arguments
This commit is contained in:
Zander Chase 2023-04-24 21:48:29 -07:00 committed by GitHub
parent 6b49be9951
commit b89c258bc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 5 deletions

View File

@ -1,6 +1,6 @@
"""Generic utility functions."""
import os
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, Tuple
def get_from_dict_or_env(
@ -19,3 +19,28 @@ def get_from_dict_or_env(
f" `{env_key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive."""
def decorator(func: Callable) -> Callable:
def wrapper(*args: Any, **kwargs: Any) -> Callable:
"""Validate exactly one arg in each group is not None."""
counts = [
sum(1 for arg in arg_group if kwargs.get(arg) is not None)
for arg_group in arg_groups
]
invalid_groups = [i for i, count in enumerate(counts) if count != 1]
if invalid_groups:
invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
raise ValueError(
"Exactly one argument in each of the following"
" groups must be defined:"
f" {', '.join(invalid_group_names)}"
)
return func(*args, **kwargs)
return wrapper
return decorator

View File

@ -9,6 +9,7 @@ import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import xor_args
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
@ -96,6 +97,32 @@ class Chroma(VectorStore):
metadata=collection_metadata,
)
@xor_args(("query_texts", "query_embeddings"))
def __query_collection(
self,
query_texts: Optional[List[str]] = None,
query_embeddings: Optional[List[List[float]]] = None,
n_results: int = 4,
where: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""Query the chroma collection."""
for i in range(n_results, 0, -1):
try:
return self._collection.query(
query_texts=query_texts,
query_embeddings=query_embeddings,
n_results=n_results,
where=where,
)
except chromadb.errors.NotEnoughElementsException:
logger.error(
f"Chroma collection {self._collection.name} "
f"contains fewer than {i} elements."
)
raise chromadb.errors.NotEnoughElementsException(
f"No documents found for Chroma collection {self._collection.name}"
)
def add_texts(
self,
texts: Iterable[str],
@ -158,7 +185,7 @@ class Chroma(VectorStore):
Returns:
List of Documents most similar to the query vector.
"""
results = self._collection.query(
results = self.__query_collection(
query_embeddings=embedding, n_results=k, where=filter
)
return _results_to_docs(results)
@ -182,12 +209,12 @@ class Chroma(VectorStore):
text with distance in float.
"""
if self._embedding_function is None:
results = self._collection.query(
results = self.__query_collection(
query_texts=[query], n_results=k, where=filter
)
else:
query_embedding = self._embedding_function.embed_query(query)
results = self._collection.query(
results = self.__query_collection(
query_embeddings=[query_embedding], n_results=k, where=filter
)
@ -218,7 +245,7 @@ class Chroma(VectorStore):
List of Documents selected by maximal marginal relevance.
"""
results = self._collection.query(
results = self.__query_collection(
query_embeddings=embedding,
n_results=fetch_k,
where=filter,