diff --git a/libs/langchain/langchain/vectorstores/chroma.py b/libs/langchain/langchain/vectorstores/chroma.py index 63c63068dd..87e71995bc 100644 --- a/libs/langchain/langchain/vectorstores/chroma.py +++ b/libs/langchain/langchain/vectorstores/chroma.py @@ -171,38 +171,52 @@ class Chroma(VectorStore): if ids is None: ids = [str(uuid.uuid1()) for _ in texts] embeddings = None + texts = list(texts) if self._embedding_function is not None: - embeddings = self._embedding_function.embed_documents(list(texts)) - + embeddings = self._embedding_function.embed_documents(texts) if metadatas: - texts = list(texts) - empty = [] - non_empty = [] - for i, m in enumerate(metadatas): + # fill metadatas with empty dicts if somebody + # did not specify metadata for all texts + length_diff = len(texts) - len(metadatas) + if length_diff: + metadatas = metadatas + [{}] * length_diff + empty_ids = [] + non_empty_ids = [] + for idx, m in enumerate(metadatas): if m: - non_empty.append(i) + non_empty_ids.append(idx) else: - empty.append(i) - if non_empty: - metadatas = [metadatas[i] for i in non_empty] - texts_with_metadatas = [texts[i] for i in non_empty] + empty_ids.append(idx) + if non_empty_ids: + metadatas = [metadatas[idx] for idx in non_empty_ids] + texts_with_metadatas = [texts[idx] for idx in non_empty_ids] embeddings_with_metadatas = ( - [embeddings[i] for i in non_empty] if embeddings else None + [embeddings[idx] for idx in non_empty_ids] if embeddings else None ) - ids_with_metadata = [ids[i] for i in non_empty] + ids_with_metadata = [ids[idx] for idx in non_empty_ids] self._collection.upsert( metadatas=metadatas, embeddings=embeddings_with_metadatas, documents=texts_with_metadatas, ids=ids_with_metadata, ) - - texts = [texts[j] for j in empty] - embeddings = [embeddings[j] for j in empty] if embeddings else None - ids = [ids[j] for j in empty] - - if texts: - self._collection.upsert(embeddings=embeddings, documents=texts, ids=ids) + if empty_ids: + texts_without_metadatas = [texts[j] for j in empty_ids] + embeddings_without_metadatas = ( + [embeddings[j] for j in empty_ids] if embeddings else None + ) + ids_without_metadatas = [ids[j] for j in empty_ids] + self._collection.upsert( + embeddings=embeddings_without_metadatas, + documents=texts_without_metadatas, + ids=ids_without_metadatas, + ) + else: + self._collection.upsert( + embeddings=embeddings, + documents=texts, + ids=ids, + ) return ids def similarity_search( diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_chroma.py b/libs/langchain/tests/integration_tests/vectorstores/test_chroma.py index 2a8ac0eeef..99f7360537 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_chroma.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_chroma.py @@ -294,7 +294,9 @@ def test_chroma_add_documents_mixed_metadata() -> None: Document(page_content="foo"), Document(page_content="bar", metadata={"baz": 1}), ] - db.add_documents(docs) + ids = ["0", "1"] + actual_ids = db.add_documents(docs, ids=ids) + assert actual_ids == ids search = db.similarity_search("foo bar") assert sorted(search, key=lambda d: d.page_content) == sorted( docs, key=lambda d: d.page_content