diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py
index 02e67479a6..46c057077c 100644
--- a/langchain/vectorstores/chroma.py
+++ b/langchain/vectorstores/chroma.py
@@ -164,9 +164,46 @@ class Chroma(VectorStore):
         embeddings = None
         if self._embedding_function is not None:
             embeddings = self._embedding_function.embed_documents(list(texts))
-        self._collection.upsert(
-            metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
-        )
+
+        if metadatas:
+            # Split the batch by whether each document has metadata and upsert
+            # the two groups separately (the metadata-free upsert passes no
+            # `metadatas` argument at all). `ids` itself is never rebound, so
+            # the caller gets back the id of every text, not just one group.
+            texts = list(texts)
+            empty = []
+            non_empty = []
+            for i, m in enumerate(metadatas):
+                if m:
+                    non_empty.append(i)
+                else:
+                    empty.append(i)
+            if non_empty:
+                metadatas = [metadatas[i] for i in non_empty]
+                texts_with_metadatas = [texts[i] for i in non_empty]
+                embeddings_with_metadatas = (
+                    [embeddings[i] for i in non_empty] if embeddings else None
+                )
+                ids_with_metadata = [ids[i] for i in non_empty]
+                self._collection.upsert(
+                    metadatas=metadatas,
+                    embeddings=embeddings_with_metadatas,
+                    documents=texts_with_metadatas,
+                    ids=ids_with_metadata,
+                )
+            if empty:
+                texts_without_metadatas = [texts[j] for j in empty]
+                embeddings_without_metadatas = (
+                    [embeddings[j] for j in empty] if embeddings else None
+                )
+                ids_without_metadatas = [ids[j] for j in empty]
+                self._collection.upsert(
+                    embeddings=embeddings_without_metadatas,
+                    documents=texts_without_metadatas,
+                    ids=ids_without_metadatas,
+                )
+        elif texts:
+            self._collection.upsert(embeddings=embeddings, documents=texts, ids=ids)
         return ids
 
     def similarity_search(
diff --git a/tests/integration_tests/vectorstores/test_chroma.py b/tests/integration_tests/vectorstores/test_chroma.py
index 1e2e0dfb2b..2a8ac0eeef 100644
--- a/tests/integration_tests/vectorstores/test_chroma.py
+++ b/tests/integration_tests/vectorstores/test_chroma.py
@@ -281,3 +281,30 @@ def test_init_from_client_settings() -> None:
 
     client_settings = chromadb.config.Settings()
     Chroma(client_settings=client_settings)
+
+
+def test_chroma_add_documents_no_metadata() -> None:
+    """Test that documents without any metadata can be added."""
+    db = Chroma(embedding_function=FakeEmbeddings())
+    db.add_documents([Document(page_content="foo")])
+
+
+def test_chroma_add_documents_mixed_metadata() -> None:
+    """Test adding documents when only some of them carry metadata."""
+    db = Chroma(embedding_function=FakeEmbeddings())
+    docs = [
+        Document(page_content="foo"),
+        Document(page_content="bar", metadata={"baz": 1}),
+    ]
+    db.add_documents(docs)
+    search = db.similarity_search("foo bar")
+    assert sorted(search, key=lambda d: d.page_content) == sorted(
+        docs, key=lambda d: d.page_content
+    )
+
+
+def test_chroma_add_texts_mixed_metadata_returns_all_ids() -> None:
+    """Test that add_texts returns an id for every text, metadata or not."""
+    db = Chroma(embedding_function=FakeEmbeddings())
+    ids = db.add_texts(["foo", "bar"], metadatas=[{}, {"baz": 1}])
+    assert len(ids) == 2