langchain: Pass through batch_size on index()/aindex() calls (#19443)

**Description:** This change passes `batch_size` through to
`add_documents()`/`aadd_documents()` on calls to `index()` and
`aindex()`, so that documents are written to the vector store in the
expected batch size.
**Issue:** #19415
**Dependencies:** N/A
**Twitter handle:** N/A
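
For orientation, here is a minimal usage sketch of the forwarded parameter. This is not part of the PR: the namespace, `db_url`, and `vector_store` are placeholders, and it assumes a vector store whose `add_documents()` accepts a `batch_size` kwarg.

```python
from langchain.indexes import SQLRecordManager, index
from langchain_core.documents import Document

# Placeholder record manager backed by a local SQLite file.
record_manager = SQLRecordManager(
    "example_namespace", db_url="sqlite:///record_manager_cache.sql"
)
record_manager.create_schema()

docs = [Document(page_content="hello world", metadata={"source": "1"})]

# `vector_store` stands in for any VectorStore whose add_documents()
# accepts a batch_size kwarg.
vector_store = ...

# With this change, batch_size is forwarded to
# vector_store.add_documents() alongside the computed ids.
index(docs, record_manager, vector_store, batch_size=1)
```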
Commit e1a6341940 (parent 82de8fd6c9), authored by Zachary Wilkins and committed via GitHub.

```diff
@@ -330,7 +330,7 @@ def index(
         # Be pessimistic and assume that all vector store write will fail.
         # First write to vector store
         if docs_to_index:
-            vector_store.add_documents(docs_to_index, ids=uids)
+            vector_store.add_documents(docs_to_index, ids=uids, batch_size=batch_size)
             num_added += len(docs_to_index) - len(seen_docs)
             num_updated += len(seen_docs)
```
```diff
@@ -544,7 +544,9 @@ async def aindex(
         # Be pessimistic and assume that all vector store write will fail.
         # First write to vector store
         if docs_to_index:
-            await vector_store.aadd_documents(docs_to_index, ids=uids)
+            await vector_store.aadd_documents(
+                docs_to_index, ids=uids, batch_size=batch_size
+            )
             num_added += len(docs_to_index) - len(seen_docs)
             num_updated += len(seen_docs)
```

```diff
@@ -20,7 +20,7 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VST, VectorStore
 from langchain.indexes import aindex, index
-from langchain.indexes._api import _abatch
+from langchain.indexes._api import _abatch, _HashedDocument
 from langchain.indexes._sql_record_manager import SQLRecordManager
```
```diff
@@ -1304,3 +1304,44 @@ async def test_aindexing_force_update(
         "num_skipped": 0,
         "num_updated": 2,
     }
+
+
+def test_indexing_custom_batch_size(
+    record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
+) -> None:
+    """Test indexing with a custom batch size."""
+    docs = [
+        Document(
+            page_content="This is a test document.",
+            metadata={"source": "1"},
+        ),
+    ]
+    ids = [_HashedDocument.from_document(doc).uid for doc in docs]
+
+    batch_size = 1
+    with patch.object(vector_store, "add_documents") as mock_add_documents:
+        index(docs, record_manager, vector_store, batch_size=batch_size)
+        args, kwargs = mock_add_documents.call_args
+        assert args == (docs,)
+        assert kwargs == {"ids": ids, "batch_size": batch_size}
+
+
+@pytest.mark.requires("aiosqlite")
+async def test_aindexing_custom_batch_size(
+    arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
+) -> None:
+    """Test indexing with a custom batch size."""
+    docs = [
+        Document(
+            page_content="This is a test document.",
+            metadata={"source": "1"},
+        ),
+    ]
+    ids = [_HashedDocument.from_document(doc).uid for doc in docs]
+
+    batch_size = 1
+    with patch.object(vector_store, "aadd_documents") as mock_add_documents:
+        await aindex(docs, arecord_manager, vector_store, batch_size=batch_size)
+        args, kwargs = mock_add_documents.call_args
+        assert args == (docs,)
+        assert kwargs == {"ids": ids, "batch_size": batch_size}
```
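
Note on the test design: both tests patch the vector store's write method and assert on `call_args`, so they verify only that `batch_size` is forwarded verbatim to `add_documents()`/`aadd_documents()`. They deliberately do not exercise the store's own batching behavior, which remains the responsibility of each vector store implementation.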
