From af5390d416bae17a4773fc5ac3da12e6311a6098 Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Mon, 25 Sep 2023 14:52:32 -0400
Subject: [PATCH] Add a batch size for cleanup (#10948)

Add pagination to indexing cleanup to deal with large numbers of
documents that need to be deleted.
---
 libs/langchain/langchain/indexes/_api.py       | 10 +++--
 .../langchain/indexes/_sql_record_manager.py   |  4 ++
 libs/langchain/langchain/indexes/base.py       |  2 +
 .../tests/unit_tests/indexes/test_indexing.py  | 38 ++++++++++++++++++-
 4 files changed, 49 insertions(+), 5 deletions(-)

diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py
index c471d07d9c..c62f0e1ed7 100644
--- a/libs/langchain/langchain/indexes/_api.py
+++ b/libs/langchain/langchain/indexes/_api.py
@@ -171,6 +171,7 @@ def index(
     batch_size: int = 100,
     cleanup: Literal["incremental", "full", None] = None,
     source_id_key: Union[str, Callable[[Document], str], None] = None,
+    cleanup_batch_size: int = 1_000,
 ) -> IndexingResult:
     """Index data from the loader into the vector store.
 
@@ -208,6 +209,7 @@ def index(
             - None: Do not delete any documents.
         source_id_key: Optional key that helps identify the original source
             of the document.
+        cleanup_batch_size: Batch size to use when cleaning up documents.
 
     Returns:
         Indexing result which contains information about how many documents
@@ -329,14 +331,14 @@ def index(
             num_deleted += len(uids_to_delete)
 
     if cleanup == "full":
-        uids_to_delete = record_manager.list_keys(before=index_start_dt)
-
-        if uids_to_delete:
+        while uids_to_delete := record_manager.list_keys(
+            before=index_start_dt, limit=cleanup_batch_size
+        ):
             # First delete from record store.
             vector_store.delete(uids_to_delete)
             # Then delete from record manager.
             record_manager.delete_keys(uids_to_delete)
-            num_deleted = len(uids_to_delete)
+            num_deleted += len(uids_to_delete)
 
     return {
         "num_added": num_added,
diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py
index f47f4e9239..d9e579aa5a 100644
--- a/libs/langchain/langchain/indexes/_sql_record_manager.py
+++ b/libs/langchain/langchain/indexes/_sql_record_manager.py
@@ -259,6 +259,7 @@ class SQLRecordManager(RecordManager):
         before: Optional[float] = None,
         after: Optional[float] = None,
         group_ids: Optional[Sequence[str]] = None,
+        limit: Optional[int] = None,
     ) -> List[str]:
         """List records in the SQLite database based on the provided date range."""
         with self._make_session() as session:
@@ -279,6 +280,9 @@ class SQLRecordManager(RecordManager):
                 query = query.filter(  # type: ignore[attr-defined]
                     UpsertionRecord.group_id.in_(group_ids)
                 )
+
+            if limit:
+                query = query.limit(limit)  # type: ignore[attr-defined]
             records = query.all()  # type: ignore[attr-defined]
             return [r.key for r in records]
 
diff --git a/libs/langchain/langchain/indexes/base.py b/libs/langchain/langchain/indexes/base.py
index 128455253a..69b6e6b5bf 100644
--- a/libs/langchain/langchain/indexes/base.py
+++ b/libs/langchain/langchain/indexes/base.py
@@ -74,13 +74,15 @@ class RecordManager(ABC):
         before: Optional[float] = None,
         after: Optional[float] = None,
         group_ids: Optional[Sequence[str]] = None,
+        limit: Optional[int] = None,
     ) -> List[str]:
         """List records in the database based on the provided filters.
 
         Args:
             before: Filter to list records updated before this time.
             after: Filter to list records updated after this time.
             group_ids: Filter to list records with specific group IDs.
+            limit: optional limit on the number of records to return.
 
         Returns:
             A list of keys for the matching records.
diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
index 3567dd648c..9e4a59e1b6 100644
--- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py
+++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
@@ -474,6 +474,43 @@ def test_deduplication(
     }
 
 
+def test_cleanup_with_different_batchsize(
+    record_manager: SQLRecordManager, vector_store: VectorStore
+) -> None:
+    """Check that we can clean up with different batch size."""
+    docs = [
+        Document(
+            page_content="This is a test document.",
+            metadata={"source": str(d)},
+        )
+        for d in range(1000)
+    ]
+
+    assert index(docs, record_manager, vector_store, cleanup="full") == {
+        "num_added": 1000,
+        "num_deleted": 0,
+        "num_skipped": 0,
+        "num_updated": 0,
+    }
+
+    docs = [
+        Document(
+            page_content="Different doc",
+            metadata={"source": str(d)},
+        )
+        for d in range(1001)
+    ]
+
+    assert index(
+        docs, record_manager, vector_store, cleanup="full", cleanup_batch_size=17
+    ) == {
+        "num_added": 1001,
+        "num_deleted": 1000,
+        "num_skipped": 0,
+        "num_updated": 0,
+    }
+
+
 def test_deduplication_v2(
     record_manager: SQLRecordManager, vector_store: VectorStore
 ) -> None:
@@ -497,7 +534,6 @@ def test_deduplication_v2(
         ),
     ]
 
-    # Should result in only a single document being added
     assert index(docs, record_manager, vector_store, cleanup="full") == {
         "num_added": 3,
         "num_deleted": 0,
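
Usage note (editorial addition, not part of the patch): a minimal sketch of
how the new cleanup_batch_size argument is exercised end to end. It assumes
the SQLRecordManager and index() APIs touched above; Chroma and FakeEmbeddings
are illustrative stand-ins (any vector store that implements delete() by ids
works), and the namespace, collection name, and database URL are made up.

    from langchain.docstore.document import Document
    from langchain.embeddings import FakeEmbeddings
    from langchain.indexes import SQLRecordManager, index
    from langchain.vectorstores import Chroma

    # The record manager tracks keys and update times for indexed documents.
    record_manager = SQLRecordManager(
        "chroma/demo", db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    vector_store = Chroma(
        collection_name="demo", embedding_function=FakeEmbeddings(size=16)
    )

    docs = [
        Document(page_content="doc", metadata={"source": str(i)})
        for i in range(5_000)
    ]

    # With cleanup="full", every previously indexed document that is absent
    # from this run is deleted. After this patch, stale keys are listed and
    # deleted in batches of cleanup_batch_size rather than in one unbounded
    # list_keys() call.
    result = index(
        docs,
        record_manager,
        vector_store,
        cleanup="full",
        source_id_key="source",
        cleanup_batch_size=500,
    )
    print(result)  # e.g. {'num_added': 5000, 'num_deleted': 0, ...}

Deleting in bounded batches keeps memory use flat for large cleanups; the
while loop in _api.py simply re-queries until list_keys() returns nothing.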