From b08f903755804b2f0c43be8b0cb84041da49edc4 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 13 Jul 2023 03:00:33 -0400 Subject: [PATCH] fix chroma init bug (#7639) --- langchain/vectorstores/chroma.py | 25 +++++++++++-------- .../vectorstores/test_chroma.py | 14 +++++++++++ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index 246d6e2d02..02e67479a6 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -87,23 +87,26 @@ class Chroma(VectorStore): ) if client is not None: + self._client_settings = client_settings self._client = client + self._persist_directory = persist_directory else: if client_settings: - self._client_settings = client_settings + _client_settings = client_settings + elif persist_directory: + _client_settings = chromadb.config.Settings( + chroma_db_impl="duckdb+parquet", + persist_directory=persist_directory, + ) else: - self._client_settings = chromadb.config.Settings() - if persist_directory is not None: - self._client_settings = chromadb.config.Settings( - chroma_db_impl="duckdb+parquet", - persist_directory=persist_directory, - ) - self._client = chromadb.Client(self._client_settings) + _client_settings = chromadb.config.Settings() + self._client_settings = _client_settings + self._client = chromadb.Client(_client_settings) + self._persist_directory = ( + _client_settings.persist_directory or persist_directory + ) self._embedding_function = embedding_function - self._persist_directory = ( - self._client_settings.persist_directory or persist_directory - ) self._collection = self._client.get_or_create_collection( name=collection_name, embedding_function=self._embedding_function.embed_documents diff --git a/tests/integration_tests/vectorstores/test_chroma.py b/tests/integration_tests/vectorstores/test_chroma.py index 80be89596e..1e2e0dfb2b 100644 --- a/tests/integration_tests/vectorstores/test_chroma.py +++ b/tests/integration_tests/vectorstores/test_chroma.py @@ -267,3 +267,17 @@ def test_chroma_with_relevance_score_custom_normalization_fn() -> None: (Document(page_content="bar", metadata={"page": "1"}), -0.0), (Document(page_content="baz", metadata={"page": "2"}), -0.0), ] + + +def test_init_from_client() -> None: + import chromadb + + client = chromadb.Client(chromadb.config.Settings()) + Chroma(client=client) + + +def test_init_from_client_settings() -> None: + import chromadb + + client_settings = chromadb.config.Settings() + Chroma(client_settings=client_settings)