mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add retry logic for ChromaDB (#3372)
Rewrite of #3368. This is mainly an issue when people are just getting started, but it is still nice not to throw an error if the number of docs is < k. Also adds a small decorator utility that enforces mutually exclusive keyword arguments.
This commit is contained in:
parent
6b49be9951
commit
b89c258bc5
@ -1,6 +1,6 @@
|
||||
"""Generic utility functions."""
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
|
||||
def get_from_dict_or_env(
|
||||
@ -19,3 +19,28 @@ def get_from_dict_or_env(
|
||||
f" `{env_key}` which contains it, or pass"
|
||||
f" `{key}` as a named parameter."
|
||||
)
|
||||
|
||||
|
||||
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
    """Validate specified keyword args are mutually exclusive.

    Args:
        *arg_groups: One tuple of keyword-argument names per group. On every
            call of the decorated function, exactly one name in each group
            must be passed as a keyword argument with a value other than
            ``None``.

    Returns:
        A decorator that wraps a callable with the validation described
        above and otherwise forwards all arguments unchanged.

    Raises:
        ValueError: Raised by the wrapped callable (at call time) when any
            group has zero, or more than one, non-``None`` keyword argument.

    Note:
        Only ``**kwargs`` are inspected. An argument passed positionally is
        not counted, so callers must use keyword syntax for grouped
        arguments.
    """
    # Local import keeps this block self-contained with respect to the
    # module's top-level imports.
    import functools

    def decorator(func: Callable) -> Callable:
        # Preserve the wrapped callable's __name__, __doc__, etc. so that
        # introspection, logging and debugging still see the real function.
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Count, per group, how many of its names were supplied with a
            # non-None value.
            counts = [
                sum(1 for arg in arg_group if kwargs.get(arg) is not None)
                for arg_group in arg_groups
            ]
            invalid_groups = [i for i, count in enumerate(counts) if count != 1]
            if invalid_groups:
                invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(invalid_group_names)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||||
|
@ -9,6 +9,7 @@ import numpy as np
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import xor_args
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
@ -96,6 +97,32 @@ class Chroma(VectorStore):
|
||||
metadata=collection_metadata,
|
||||
)
|
||||
|
||||
@xor_args(("query_texts", "query_embeddings"))
def __query_collection(
    self,
    query_texts: Optional[List[str]] = None,
    query_embeddings: Optional[List[List[float]]] = None,
    n_results: int = 4,
    where: Optional[Dict[str, str]] = None,
) -> List[Document]:
    """Query the chroma collection, retrying with fewer results if needed.

    The collection is first asked for ``n_results`` matches. Each time it
    raises ``NotEnoughElementsException`` the request is retried with one
    fewer result, down to a single result.

    Args:
        query_texts: Raw query strings. Exactly one of ``query_texts`` and
            ``query_embeddings`` must be given, as a keyword argument
            (enforced by ``xor_args``).
        query_embeddings: Pre-computed query embeddings.
        n_results: Maximum number of results to request.
        where: Optional metadata filter forwarded to the collection query.

    Returns:
        Whatever ``self._collection.query`` returns for the first attempt
        that succeeds.
        NOTE(review): the annotation says ``List[Document]``, but callers
        pass this value to ``_results_to_docs`` -- confirm the real type.

    Raises:
        chromadb.errors.NotEnoughElementsException: If every attempt fails,
            or if ``n_results`` is less than 1 so no attempt is made.
    """
    for i in range(n_results, 0, -1):
        try:
            return self._collection.query(
                query_texts=query_texts,
                query_embeddings=query_embeddings,
                # Request the current, reduced count. Passing the original
                # n_results here would repeat the same failing query on
                # every iteration and defeat the retry.
                n_results=i,
                where=where,
            )
        except chromadb.errors.NotEnoughElementsException:
            logger.error(
                f"Chroma collection {self._collection.name} "
                f"contains fewer than {i} elements."
            )
    raise chromadb.errors.NotEnoughElementsException(
        f"No documents found for Chroma collection {self._collection.name}"
    )
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@ -158,7 +185,7 @@ class Chroma(VectorStore):
|
||||
Returns:
|
||||
List of Documents most similar to the query vector.
|
||||
"""
|
||||
results = self._collection.query(
|
||||
results = self.__query_collection(
|
||||
query_embeddings=embedding, n_results=k, where=filter
|
||||
)
|
||||
return _results_to_docs(results)
|
||||
@ -182,12 +209,12 @@ class Chroma(VectorStore):
|
||||
text with distance in float.
|
||||
"""
|
||||
if self._embedding_function is None:
|
||||
results = self._collection.query(
|
||||
results = self.__query_collection(
|
||||
query_texts=[query], n_results=k, where=filter
|
||||
)
|
||||
else:
|
||||
query_embedding = self._embedding_function.embed_query(query)
|
||||
results = self._collection.query(
|
||||
results = self.__query_collection(
|
||||
query_embeddings=[query_embedding], n_results=k, where=filter
|
||||
)
|
||||
|
||||
@ -218,7 +245,7 @@ class Chroma(VectorStore):
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
|
||||
results = self._collection.query(
|
||||
results = self.__query_collection(
|
||||
query_embeddings=embedding,
|
||||
n_results=fetch_k,
|
||||
where=filter,
|
||||
|
Loading…
Reference in New Issue
Block a user