diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index 7679b22e..7a0b5c77 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -3,12 +3,16 @@ from __future__ import annotations import logging import uuid -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores.base import VectorStore +if TYPE_CHECKING: + import chromadb + import chromadb.config + logger = logging.getLogger() @@ -34,6 +38,7 @@ class Chroma(VectorStore): collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, embedding_function: Optional[Embeddings] = None, persist_directory: Optional[str] = None, + client_settings: Optional[chromadb.config.Settings] = None, ) -> None: """Initialize with Chroma client.""" try: @@ -45,12 +50,14 @@ class Chroma(VectorStore): "Please install it with `pip install chromadb`." ) - # TODO: Add support for custom client. For now this is in-memory only. - self._client_settings = chromadb.config.Settings() - if persist_directory is not None: - self._client_settings = chromadb.config.Settings( - chroma_db_impl="duckdb+parquet", persist_directory=persist_directory - ) + if client_settings: + self._client_settings = client_settings + else: + self._client_settings = chromadb.config.Settings() + if persist_directory is not None: + self._client_settings = chromadb.config.Settings( + chroma_db_impl="duckdb+parquet", persist_directory=persist_directory + ) self._client = chromadb.Client(self._client_settings) self._embedding_function = embedding_function self._persist_directory = persist_directory @@ -185,6 +192,7 @@ class Chroma(VectorStore): ids: Optional[List[str]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, persist_directory: Optional[str] = None, + client_settings: Optional[chromadb.config.Settings] = None, **kwargs: Any, ) -> Chroma: """Create a Chroma vectorstore from a raw documents. @@ -207,6 +215,7 @@ class Chroma(VectorStore): collection_name=collection_name, embedding_function=embedding, persist_directory=persist_directory, + client_settings=client_settings, ) chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) return chroma_collection @@ -219,6 +228,7 @@ class Chroma(VectorStore): ids: Optional[List[str]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, persist_directory: Optional[str] = None, + client_settings: Optional[chromadb.config.Settings] = None, **kwargs: Any, ) -> Chroma: """Create a Chroma vectorstore from a list of documents. @@ -244,4 +254,5 @@ class Chroma(VectorStore): ids=ids, collection_name=collection_name, persist_directory=persist_directory, + client_settings=client_settings, )