diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index b7399bf6..7d29dbe5 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -60,6 +60,7 @@ class Chroma(VectorStore): persist_directory: Optional[str] = None, client_settings: Optional[chromadb.config.Settings] = None, collection_metadata: Optional[Dict] = None, + client: Optional[chromadb.Client] = None, ) -> None: """Initialize with Chroma client.""" try: @@ -71,15 +72,20 @@ class Chroma(VectorStore): "Please install it with `pip install chromadb`." ) - if client_settings: - self._client_settings = client_settings + if client is not None: + self._client = client else: - self._client_settings = chromadb.config.Settings() - if persist_directory is not None: - self._client_settings = chromadb.config.Settings( - chroma_db_impl="duckdb+parquet", persist_directory=persist_directory - ) - self._client = chromadb.Client(self._client_settings) + if client_settings: + self._client_settings = client_settings + else: + self._client_settings = chromadb.config.Settings() + if persist_directory is not None: + self._client_settings = chromadb.config.Settings( + chroma_db_impl="duckdb+parquet", + persist_directory=persist_directory, + ) + self._client = chromadb.Client(self._client_settings) + self._embedding_function = embedding_function self._persist_directory = persist_directory self._collection = self._client.get_or_create_collection( @@ -279,6 +285,7 @@ class Chroma(VectorStore): collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, persist_directory: Optional[str] = None, client_settings: Optional[chromadb.config.Settings] = None, + client: Optional[chromadb.Client] = None, **kwargs: Any, ) -> Chroma: """Create a Chroma vectorstore from a raw documents. @@ -303,6 +310,7 @@ class Chroma(VectorStore): embedding_function=embedding, persist_directory=persist_directory, client_settings=client_settings, + client=client, ) chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) return chroma_collection @@ -316,6 +324,7 @@ class Chroma(VectorStore): collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, persist_directory: Optional[str] = None, client_settings: Optional[chromadb.config.Settings] = None, + client: Optional[chromadb.Client] = None, # Add this line **kwargs: Any, ) -> Chroma: """Create a Chroma vectorstore from a list of documents. @@ -343,4 +352,5 @@ class Chroma(VectorStore): collection_name=collection_name, persist_directory=persist_directory, client_settings=client_settings, + client=client, )