From 245131097557b73774197b01e326206fa2a1b83a Mon Sep 17 00:00:00 2001
From: Davis Chase <130488702+dev2049@users.noreply.github.com>
Date: Mon, 1 May 2023 10:47:15 -0700
Subject: [PATCH] Chroma fix mmr (#3897)

Fixes #3628, thanks @derekmoeller for the issue!

The Chroma query helper now accepts extra keyword arguments and forwards
them to the underlying collection's `query` call. It also imports
`chromadb` locally, since its error handling refers to
`chromadb.errors.NotEnoughElementsException`. Integration tests are added
for `max_marginal_relevance_search` and
`max_marginal_relevance_search_by_vector`.
---
 langchain/vectorstores/chroma.py              | 10 +++++++++
 .../vectorstores/test_chroma.py               | 22 +++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py
index d3b4cfbc..b5d77a96 100644
--- a/langchain/vectorstores/chroma.py
+++ b/langchain/vectorstores/chroma.py
@@ -104,8 +104,17 @@ class Chroma(VectorStore):
         query_embeddings: Optional[List[List[float]]] = None,
         n_results: int = 4,
         where: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
     ) -> List[Document]:
         """Query the chroma collection."""
+        try:
+            import chromadb
+        except ImportError as e:
+            raise ValueError(
+                "Could not import chromadb python package. "
+                "Please install it with `pip install chromadb`."
+            ) from e
+
         for i in range(n_results, 0, -1):
             try:
                 return self._collection.query(
@@ -113,6 +122,7 @@ class Chroma(VectorStore):
                     query_embeddings=query_embeddings,
                     n_results=i,
                     where=where,
+                    **kwargs,
                 )
             except chromadb.errors.NotEnoughElementsException:
                 logger.error(
diff --git a/tests/integration_tests/vectorstores/test_chroma.py b/tests/integration_tests/vectorstores/test_chroma.py
index 0075a163..9f51f253 100644
--- a/tests/integration_tests/vectorstores/test_chroma.py
+++ b/tests/integration_tests/vectorstores/test_chroma.py
@@ -126,3 +126,25 @@ def test_chroma_with_persistence() -> None:
     # Persist doesn't need to be called again
     # Data will be automatically persisted on object deletion
     # Or on program exit
+
+
+def test_chroma_mmr() -> None:
+    """Test max marginal relevance search by query text."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = Chroma.from_texts(
+        collection_name="test_collection", texts=texts, embedding=FakeEmbeddings()
+    )
+    output = docsearch.max_marginal_relevance_search("foo", k=1)
+    assert output == [Document(page_content="foo")]
+
+
+def test_chroma_mmr_by_vector() -> None:
+    """Test max marginal relevance search by query embedding vector."""
+    texts = ["foo", "bar", "baz"]
+    embeddings = FakeEmbeddings()
+    docsearch = Chroma.from_texts(
+        collection_name="test_collection", texts=texts, embedding=embeddings
+    )
+    embedded_query = embeddings.embed_query("foo")
+    output = docsearch.max_marginal_relevance_search_by_vector(embedded_query, k=1)
+    assert output == [Document(page_content="foo")]