From b809c243afb182efc5868ca495074a41e2f7c40c Mon Sep 17 00:00:00 2001
From: Richard Wang
Date: Sat, 23 Sep 2023 10:41:07 +0800
Subject: [PATCH] Fix bug in `index` api (#10614)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- **Description:** a fix for `index`.
- **Issue:** Not applicable.
- **Dependencies:** None
- **Tag maintainer:**
- **Twitter handle:** richarddwang

# Problem

Replication code:

```python
from pprint import pprint

from langchain.embeddings import OpenAIEmbeddings
from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document
from langchain.vectorstores import Qdrant
from langchain_setup.qdrant import pprint_qdrant_documents, create_inmemory_empty_qdrant

# Documents
metadata1 = {"source": "fullhell.alchemist"}
doc1_1 = Document(page_content="1-1 I have a dog~", metadata=metadata1)
doc1_2 = Document(page_content="1-2 I have a daughter~", metadata=metadata1)
doc1_3 = Document(page_content="1-3 Ahh! O..Oniichan", metadata=metadata1)
doc2 = Document(page_content="2 Lancer died again.", metadata={"source": "fate.docx"})

# Create an empty vector store
collection_name = "secret_of_D_disk"
vectorstore: Qdrant = create_inmemory_empty_qdrant()

# Create the record manager
import tempfile
from pathlib import Path

record_manager = SQLRecordManager(
    namespace=f"qdrant/{collection_name}",
    db_url=f"sqlite:///{Path(tempfile.gettempdir())/collection_name}.sql",
)
record_manager.create_schema()  # Required

sync_result = index(
    [doc1_1, doc1_2, doc1_2, doc2],
    record_manager,
    vectorstore,
    cleanup="full",
    source_id_key="source",
)
print(sync_result, end="\n\n")
pprint_qdrant_documents(vectorstore)
```
Code of the helper functions `pprint_qdrant_documents` and `create_inmemory_empty_qdrant`:

```python
from pprint import pformat


def create_inmemory_empty_qdrant(**from_texts_kwargs):
    # Qdrant requires the vector size, which can only be known after applying the embedder
    vectorstore = Qdrant.from_texts(
        ["dummy"],
        location=":memory:",
        embedding=OpenAIEmbeddings(),
        **from_texts_kwargs,
    )
    dummy_document_id = vectorstore.client.scroll(vectorstore.collection_name)[0][0].id
    vectorstore.delete([dummy_document_id])
    return vectorstore


def pprint_qdrant_documents(vectorstore, limit: int = 100, **scroll_kwargs):
    document_ids, documents = [], []
    for record in vectorstore.client.scroll(
        vectorstore.collection_name, limit=limit, **scroll_kwargs
    )[0]:
        document_ids.append(record.id)
        documents.append(
            Document(
                page_content=record.payload["page_content"],
                metadata=record.payload["metadata"] or {},
            )
        )
    pprint_documents(documents, document_ids=document_ids)


def pprint_document(document: Document = None, document_id=None, return_string=False):
    displayed_text = ""
    if document_id:
        displayed_text += f"Document {document_id}:\n\n"
    displayed_text += f"{document.page_content}\n\n"
    metadata_text = pformat(document.metadata, indent=1)
    if "\n" in metadata_text:
        displayed_text += f"Metadata:\n{metadata_text}"
    else:
        displayed_text += f"Metadata:{metadata_text}"

    if return_string:
        return displayed_text
    else:
        print(displayed_text)


def pprint_documents(documents, document_ids=None):
    if not document_ids:
        document_ids = [i + 1 for i in range(len(documents))]

    displayed_texts = []
    for document_id, document in zip(document_ids, documents):
        displayed_text = pprint_document(
            document_id=document_id, document=document, return_string=True
        )
        displayed_texts.append(displayed_text)
    print(f"\n{'-' * 100}\n".join(displayed_texts))
```
You will get:

```
{'num_added': 3, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}

Document 1b19816e-b802-53c0-ad60-5ff9d9b9b911:

1-2 I have a daughter~

Metadata:{'source': 'fullhell.alchemist'}
----------------------------------------------------------------------------------------------------
Document 3362f9bc-991a-5dd5-b465-c564786ce19c:

1-1 I have a dog~

Metadata:{'source': 'fullhell.alchemist'}
----------------------------------------------------------------------------------------------------
Document a4d50169-2fda-5339-a196-249b5f54a0de:

1-2 I have a daughter~

Metadata:{'source': 'fullhell.alchemist'}
```

This is not correct. We should expect the vectorstore to now contain doc1_1, doc1_2, and doc2, not doc1_1, doc1_2, and doc1_2.

# Reason

In `index`, the original code is

```python
uids = []
docs_to_index = []
for doc, hashed_doc, doc_exists in zip(doc_batch, hashed_docs, exists_batch):
    if doc_exists:
        # Must be updated to refresh timestamp.
        record_manager.update([hashed_doc.uid], time_at_least=index_start_dt)
        num_skipped += 1
        continue
    uids.append(hashed_doc.uid)
    docs_to_index.append(doc)
```

In the example above, `len(doc_batch) == 4`, but `len(hashed_docs) == len(exists_batch) == 3`, because `hashed_docs` is the deduplicated form of the input documents: [doc1_1, doc1_2, doc1_2, doc2] deduplicates to [doc1_1, doc1_2, doc2]. Since `zip` stops at the shortest iterable, the loop pairs the first three entries of `doc_batch` with the three deduplicated hashes, so `index` inserts doc1_1, doc1_2, doc1_2 under the uids of doc1_1, doc1_2, doc2, and doc2 never reaches the vector store.
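The misalignment can be reproduced in isolation. Below is a minimal sketch; the plain strings and hand-built lists are stand-ins for the real `doc_batch`/`hashed_docs` values, not the library's actual objects:

```python
# Minimal sketch of the misalignment; plain strings stand in for Documents.
doc_batch = ["doc1_1", "doc1_2", "doc1_2", "doc2"]  # raw input batch, length 4
hashed_docs = ["doc1_1", "doc1_2", "doc2"]          # deduplicated batch, length 3
exists_batch = [False, False, False]                # nothing is indexed yet

# zip() silently stops at the shortest iterable, so after the duplicate the
# raw batch and the deduplicated batch are paired out of step:
for doc, hashed_doc, doc_exists in zip(doc_batch, hashed_docs, exists_batch):
    print(f"content {doc!r} stored under uid of {hashed_doc!r}")
# content 'doc1_1' stored under uid of 'doc1_1'
# content 'doc1_2' stored under uid of 'doc1_2'
# content 'doc1_2' stored under uid of 'doc2'   <- doc2's content is lost

# On Python 3.10+, zip(doc_batch, hashed_docs, exists_batch, strict=True)
# would raise ValueError here instead of silently dropping the fourth element.
```

Iterating over the deduplicated `hashed_docs` alone, as the patch below does, removes the mismatch entirely: both the uid and the stored document are derived from the same `hashed_doc`.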
---------

Co-authored-by: Eugene Yurtsev
---
 libs/langchain/langchain/indexes/_api.py       |  4 +-
 .../tests/unit_tests/indexes/test_indexing.py | 39 +++++++++++++++++++
 2 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py
index 6f2ebf9828..c471d07d9c 100644
--- a/libs/langchain/langchain/indexes/_api.py
+++ b/libs/langchain/langchain/indexes/_api.py
@@ -282,14 +282,14 @@ def index(
         # Filter out documents that already exist in the record store.
         uids = []
         docs_to_index = []
-        for doc, hashed_doc, doc_exists in zip(doc_batch, hashed_docs, exists_batch):
+        for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
             if doc_exists:
                 # Must be updated to refresh timestamp.
                 record_manager.update([hashed_doc.uid], time_at_least=index_start_dt)
                 num_skipped += 1
                 continue
             uids.append(hashed_doc.uid)
-            docs_to_index.append(doc)
+            docs_to_index.append(hashed_doc.to_document())
 
         # Be pessimistic and assume that all vector store write will fail.
         # First write to vector store
diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
index 7783f41f07..3567dd648c 100644
--- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py
+++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
@@ -472,3 +472,42 @@ def test_deduplication(
         "num_skipped": 0,
         "num_updated": 0,
     }
+
+
+def test_deduplication_v2(
+    record_manager: SQLRecordManager, vector_store: VectorStore
+) -> None:
+    """Check that a document duplicated within a batch is only indexed once."""
+    docs = [
+        Document(
+            page_content="1",
+            metadata={"source": "1"},
+        ),
+        Document(
+            page_content="1",
+            metadata={"source": "1"},
+        ),
+        Document(
+            page_content="2",
+            metadata={"source": "2"},
+        ),
+        Document(
+            page_content="3",
+            metadata={"source": "3"},
+        ),
+    ]
+
+    # The duplicated first document should be added once, so 3 docs in total
+    assert index(docs, record_manager, vector_store, cleanup="full") == {
+        "num_added": 3,
+        "num_deleted": 0,
+        "num_skipped": 0,
+        "num_updated": 0,
+    }
+
+    # using in memory implementation here
+    assert isinstance(vector_store, InMemoryVectorStore)
+    contents = sorted(
+        [document.page_content for document in vector_store.store.values()]
+    )
+    assert contents == ["1", "2", "3"]