Chroma fix mmr (#3897)

Fixes #3628, thanks @derekmoeller for the issue!
This commit is contained in:
Davis Chase 2023-05-01 10:47:15 -07:00 committed by GitHub
parent 3e1cb31f63
commit 2451310975
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 0 deletions

View File

@@ -104,8 +104,17 @@ class Chroma(VectorStore):
query_embeddings: Optional[List[List[float]]] = None,
n_results: int = 4,
where: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Query the chroma collection."""
try:
import chromadb
except ImportError:
raise ValueError(
"Could not import chromadb python package. "
"Please install it with `pip install chromadb`."
)
for i in range(n_results, 0, -1):
try:
return self._collection.query(
@@ -113,6 +122,7 @@ class Chroma(VectorStore):
query_embeddings=query_embeddings,
n_results=i,
where=where,
**kwargs,
)
except chromadb.errors.NotEnoughElementsException:
logger.error(

View File

@@ -126,3 +126,25 @@ def test_chroma_with_persistence() -> None:
# Persist doesn't need to be called again
# Data will be automatically persisted on object deletion
# Or on program exit
def test_chroma_mmr() -> None:
    """Check max-marginal-relevance search on a Chroma store built from texts.

    Builds a collection from three short texts using ``FakeEmbeddings``, runs
    ``max_marginal_relevance_search`` for "foo" with ``k=1``, and expects the
    result to be exactly the single "foo" document.
    """
    corpus = ["foo", "bar", "baz"]
    store = Chroma.from_texts(
        collection_name="test_collection",
        texts=corpus,
        embedding=FakeEmbeddings(),
    )

    # k=1 asks for a single hit, so the comparison is against a one-item list.
    results = store.max_marginal_relevance_search("foo", k=1)
    expected = [Document(page_content="foo")]
    assert results == expected
def test_chroma_mmr_by_vector() -> None:
    """Check max-marginal-relevance search when the query is given as a vector.

    Builds a collection from three short texts, embeds the query "foo" with
    the same ``FakeEmbeddings`` instance used for the collection, runs
    ``max_marginal_relevance_search_by_vector`` with ``k=1``, and expects the
    result to be exactly the single "foo" document.
    """
    corpus = ["foo", "bar", "baz"]
    embedder = FakeEmbeddings()
    store = Chroma.from_texts(
        collection_name="test_collection",
        texts=corpus,
        embedding=embedder,
    )

    # Embed the query up front; the by-vector API takes the embedding directly.
    query_vector = embedder.embed_query("foo")
    results = store.max_marginal_relevance_search_by_vector(query_vector, k=1)
    expected = [Document(page_content="foo")]
    assert results == expected