From d1262766939fa54e1cbe8d35509964d3741a9260 Mon Sep 17 00:00:00 2001 From: Magnus Friberg Date: Tue, 16 May 2023 23:43:09 +0200 Subject: [PATCH] Specify which data to return from chromadb (#4393) # Improve the Chroma get() method by adding the optional "include" parameter. The Chroma get() method excludes embeddings by default. You can customize the response by specifying the "include" parameter to selectively retrieve the desired data from the collection. --------- Co-authored-by: Dev 2049 --- langchain/vectorstores/chroma.py | 14 +++++++++++--- .../integration_tests/vectorstores/test_chroma.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index b4387fb1..e102abe4 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -313,9 +313,17 @@ class Chroma(VectorStore): """Delete the collection.""" self._client.delete_collection(self._collection.name) - def get(self) -> Chroma: - """Gets the collection""" - return self._collection.get() + def get(self, include: Optional[List[str]] = None) -> Dict[str, Any]: + """Gets the collection. + + Args: + include (Optional[List[str]]): List of fields to include from db. + Defaults to None. + """ + if include is not None: + return self._collection.get(include=include) + else: + return self._collection.get() def persist(self) -> None: """Persist the collection. diff --git a/tests/integration_tests/vectorstores/test_chroma.py b/tests/integration_tests/vectorstores/test_chroma.py index 9f51f253..17fb781e 100644 --- a/tests/integration_tests/vectorstores/test_chroma.py +++ b/tests/integration_tests/vectorstores/test_chroma.py @@ -148,3 +148,15 @@ def test_chroma_mmr_by_vector() -> None: embedded_query = embeddings.embed_query("foo") output = docsearch.max_marginal_relevance_search_by_vector(embedded_query, k=1) assert output == [Document(page_content="foo")] + + +def test_chroma_with_include_parameter() -> None: + """Test end to end construction and include parameter.""" + texts = ["foo", "bar", "baz"] + docsearch = Chroma.from_texts( + collection_name="test_collection", texts=texts, embedding=FakeEmbeddings() + ) + output = docsearch.get(include=["embeddings"]) + assert output["embeddings"] is not None + output = docsearch.get() + assert output["embeddings"] is None