diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index b4387fb1..e102abe4 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -313,9 +313,17 @@ class Chroma(VectorStore): """Delete the collection.""" self._client.delete_collection(self._collection.name) - def get(self) -> Chroma: - """Gets the collection""" - return self._collection.get() + def get(self, include: Optional[List[str]] = None) -> Dict[str, Any]: + """Gets the collection. + + Args: + include (Optional[List[str]]): List of fields to include from db. + Defaults to None. + """ + if include is not None: + return self._collection.get(include=include) + else: + return self._collection.get() def persist(self) -> None: """Persist the collection. diff --git a/tests/integration_tests/vectorstores/test_chroma.py b/tests/integration_tests/vectorstores/test_chroma.py index 9f51f253..17fb781e 100644 --- a/tests/integration_tests/vectorstores/test_chroma.py +++ b/tests/integration_tests/vectorstores/test_chroma.py @@ -148,3 +148,15 @@ def test_chroma_mmr_by_vector() -> None: embedded_query = embeddings.embed_query("foo") output = docsearch.max_marginal_relevance_search_by_vector(embedded_query, k=1) assert output == [Document(page_content="foo")] + + +def test_chroma_with_include_parameter() -> None: + """Test end to end construction and include parameter.""" + texts = ["foo", "bar", "baz"] + docsearch = Chroma.from_texts( + collection_name="test_collection", texts=texts, embedding=FakeEmbeddings() + ) + output = docsearch.get(include=["embeddings"]) + assert output["embeddings"] is not None + output = docsearch.get() + assert output["embeddings"] is None