@ -20,7 +20,7 @@ from langchain_core.embeddings import Embeddings
from langchain_core . vectorstores import VST , VectorStore
from langchain . indexes import aindex , index
from langchain . indexes . _api import _abatch
from langchain . indexes . _api import _abatch , _HashedDocument
from langchain . indexes . _sql_record_manager import SQLRecordManager
@ -1304,3 +1304,44 @@ async def test_aindexing_force_update(
" num_skipped " : 0 ,
" num_updated " : 2 ,
}
def test_indexing_custom_batch_size(
    record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
    """Test indexing with a custom batch size.

    Checks that ``index`` passes the caller-supplied ``batch_size`` through
    to ``vector_store.add_documents`` together with the document ids.
    """
    docs = [
        Document(
            page_content="This is a test document.",
            metadata={"source": "1"},
        ),
    ]
    # Expected ids are the uids of the hashed form of each document.
    ids = [_HashedDocument.from_document(doc).uid for doc in docs]

    batch_size = 1
    # Replace add_documents with a mock so the arguments of the call made by
    # index() can be inspected.
    with patch.object(vector_store, "add_documents") as mock_add_documents:
        index(docs, record_manager, vector_store, batch_size=batch_size)
        args, kwargs = mock_add_documents.call_args
        assert args == (docs,)
        assert kwargs == {"ids": ids, "batch_size": batch_size}
@pytest.mark.requires("aiosqlite")
async def test_aindexing_custom_batch_size(
    arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
    """Test async indexing with a custom batch size.

    Checks that ``aindex`` passes the caller-supplied ``batch_size`` through
    to ``vector_store.aadd_documents`` together with the document ids.
    """
    docs = [
        Document(
            page_content="This is a test document.",
            metadata={"source": "1"},
        ),
    ]
    # Expected ids are the uids of the hashed form of each document.
    ids = [_HashedDocument.from_document(doc).uid for doc in docs]

    batch_size = 1
    # Replace aadd_documents with a mock so the arguments of the call made by
    # aindex() can be inspected.
    with patch.object(vector_store, "aadd_documents") as mock_aadd_documents:
        await aindex(docs, arecord_manager, vector_store, batch_size=batch_size)
        args, kwargs = mock_aadd_documents.call_args
        assert args == (docs,)
        assert kwargs == {"ids": ids, "batch_size": batch_size}