core[minor]: add **kwargs to index and aindex functions for custom vector_field support (#26998)

Added `**kwargs` parameters to the `index` and `aindex` functions in
`libs/core/langchain_core/indexing/api.py`. This allows users to pass
additional arguments to the `add_documents` and `aadd_documents`
methods, enabling the specification of a custom `vector_field`. For
example, users can now use `vector_field="embedding"` when indexing
documents in `OpenSearchVectorStore`

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
João Carlos Ferra de Almeida 2024-10-07 19:52:50 +01:00 committed by GitHub
parent 14de81b140
commit 780ce00dea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 208 additions and 4 deletions

View File

@ -198,6 +198,7 @@ def index(
source_id_key: Union[str, Callable[[Document], str], None] = None,
cleanup_batch_size: int = 1_000,
force_update: bool = False,
upsert_kwargs: Optional[dict[str, Any]] = None,
) -> IndexingResult:
"""Index data from the loader into the vector store.
@ -249,6 +250,12 @@ def index(
force_update: Force update documents even if they are present in the
record manager. Useful if you are re-indexing with updated embeddings.
Default is False.
upsert_kwargs: Additional keyword arguments to pass to the add_documents
method of the VectorStore or the upsert method of the
DocumentIndex. For example, you can use this to
specify a custom vector_field:
upsert_kwargs={"vector_field": "embedding"}
.. versionadded:: 0.3.10
Returns:
Indexing result which contains information about how many documents
@ -363,10 +370,16 @@ def index(
if docs_to_index:
if isinstance(destination, VectorStore):
destination.add_documents(
docs_to_index, ids=uids, batch_size=batch_size
docs_to_index,
ids=uids,
batch_size=batch_size,
**(upsert_kwargs or {}),
)
elif isinstance(destination, DocumentIndex):
destination.upsert(docs_to_index)
destination.upsert(
docs_to_index,
**(upsert_kwargs or {}),
)
num_added += len(docs_to_index) - len(seen_docs)
num_updated += len(seen_docs)
@ -438,6 +451,7 @@ async def aindex(
source_id_key: Union[str, Callable[[Document], str], None] = None,
cleanup_batch_size: int = 1_000,
force_update: bool = False,
upsert_kwargs: Optional[dict[str, Any]] = None,
) -> IndexingResult:
"""Async index data from the loader into the vector store.
@ -480,6 +494,12 @@ async def aindex(
force_update: Force update documents even if they are present in the
record manager. Useful if you are re-indexing with updated embeddings.
Default is False.
upsert_kwargs: Additional keyword arguments to pass to the aadd_documents
method of the VectorStore or the aupsert method of the
DocumentIndex. For example, you can use this to
specify a custom vector_field:
upsert_kwargs={"vector_field": "embedding"}
.. versionadded:: 0.3.10
Returns:
Indexing result which contains information about how many documents
@ -604,10 +624,16 @@ async def aindex(
if docs_to_index:
if isinstance(destination, VectorStore):
await destination.aadd_documents(
docs_to_index, ids=uids, batch_size=batch_size
docs_to_index,
ids=uids,
batch_size=batch_size,
**(upsert_kwargs or {}),
)
elif isinstance(destination, DocumentIndex):
await destination.aupsert(docs_to_index)
await destination.aupsert(
docs_to_index,
**(upsert_kwargs or {}),
)
num_added += len(docs_to_index) - len(seen_docs)
num_updated += len(seen_docs)

View File

@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from pytest_mock import MockerFixture
from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
@ -1728,3 +1729,180 @@ async def test_incremental_aindexing_with_batch_size_with_optimization(
for uid in vector_store.store
}
assert doc_texts == {"updated 1", "2", "3", "updated 4"}
def test_index_with_upsert_kwargs(
record_manager: InMemoryRecordManager, upserting_vector_store: InMemoryVectorStore
) -> None:
"""Test indexing with upsert_kwargs parameter."""
mock_add_documents = MagicMock()
with patch.object(upserting_vector_store, "add_documents", mock_add_documents):
docs = [
Document(
page_content="Test document 1",
metadata={"source": "1"},
),
Document(
page_content="Test document 2",
metadata={"source": "2"},
),
]
upsert_kwargs = {"vector_field": "embedding"}
index(docs, record_manager, upserting_vector_store, upsert_kwargs=upsert_kwargs)
# Assert that add_documents was called with the correct arguments
mock_add_documents.assert_called_once()
call_args = mock_add_documents.call_args
assert call_args is not None
args, kwargs = call_args
# Check that the documents are correct (ignoring ids)
assert len(args[0]) == 2
assert all(isinstance(doc, Document) for doc in args[0])
assert [doc.page_content for doc in args[0]] == [
"Test document 1",
"Test document 2",
]
assert [doc.metadata for doc in args[0]] == [{"source": "1"}, {"source": "2"}]
# Check that ids are present
assert "ids" in kwargs
assert isinstance(kwargs["ids"], list)
assert len(kwargs["ids"]) == 2
# Check other arguments
assert kwargs["batch_size"] == 100
assert kwargs["vector_field"] == "embedding"
def test_index_with_upsert_kwargs_for_document_indexer(
record_manager: InMemoryRecordManager,
mocker: MockerFixture,
) -> None:
"""Test that kwargs are passed to the upsert method of the document indexer."""
document_index = InMemoryDocumentIndex()
upsert_spy = mocker.spy(document_index.__class__, "upsert")
docs = [
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
]
upsert_kwargs = {"vector_field": "embedding"}
assert index(
docs,
record_manager,
document_index,
cleanup="full",
upsert_kwargs=upsert_kwargs,
) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
assert upsert_spy.call_count == 1
# assert call kwargs were passed as kwargs
assert upsert_spy.call_args.kwargs == upsert_kwargs
async def test_aindex_with_upsert_kwargs_for_document_indexer(
arecord_manager: InMemoryRecordManager,
mocker: MockerFixture,
) -> None:
"""Test that kwargs are passed to the upsert method of the document indexer."""
document_index = InMemoryDocumentIndex()
upsert_spy = mocker.spy(document_index.__class__, "aupsert")
docs = [
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
]
upsert_kwargs = {"vector_field": "embedding"}
assert await aindex(
docs,
arecord_manager,
document_index,
cleanup="full",
upsert_kwargs=upsert_kwargs,
) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
assert upsert_spy.call_count == 1
# assert call kwargs were passed as kwargs
assert upsert_spy.call_args.kwargs == upsert_kwargs
async def test_aindex_with_upsert_kwargs(
arecord_manager: InMemoryRecordManager, upserting_vector_store: InMemoryVectorStore
) -> None:
"""Test async indexing with upsert_kwargs parameter."""
mock_aadd_documents = AsyncMock()
with patch.object(upserting_vector_store, "aadd_documents", mock_aadd_documents):
docs = [
Document(
page_content="Async test document 1",
metadata={"source": "1"},
),
Document(
page_content="Async test document 2",
metadata={"source": "2"},
),
]
upsert_kwargs = {"vector_field": "embedding"}
await aindex(
docs,
arecord_manager,
upserting_vector_store,
upsert_kwargs=upsert_kwargs,
)
# Assert that aadd_documents was called with the correct arguments
mock_aadd_documents.assert_called_once()
call_args = mock_aadd_documents.call_args
assert call_args is not None
args, kwargs = call_args
# Check that the documents are correct (ignoring ids)
assert len(args[0]) == 2
assert all(isinstance(doc, Document) for doc in args[0])
assert [doc.page_content for doc in args[0]] == [
"Async test document 1",
"Async test document 2",
]
assert [doc.metadata for doc in args[0]] == [{"source": "1"}, {"source": "2"}]
# Check that ids are present
assert "ids" in kwargs
assert isinstance(kwargs["ids"], list)
assert len(kwargs["ids"]) == 2
# Check other arguments
assert kwargs["batch_size"] == 100
assert kwargs["vector_field"] == "embedding"