Add async sql record manager and async indexing API (#10726)

- **Description:** Add support for a `SQLRecordManager` in async
environments. It adds async counterparts to the `RecordManager` abstract
class and a new `aindex` function for async indexing.
- **Issue:** None
- **Dependencies:** Optional `aiosqlite`.
- **Tag maintainer:** @nfcampos 
- **Twitter handle:** @jvelezmagic
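
A minimal usage sketch (hypothetical `docs` and `vector_store`; the vector store must implement `aadd_documents` and `adelete`, and the `sqlite+aiosqlite` URL needs the optional `aiosqlite` dependency):

```python
from langchain.indexes import SQLRecordManager, aindex

# Inside an async function / running event loop.
record_manager = SQLRecordManager(
    "my_docs",  # namespace
    db_url="sqlite+aiosqlite:///:memory:",
    async_mode=True,
)
await record_manager.acreate_schema()

result = await aindex(
    docs,  # iterable or async iterator of Documents (assumed)
    record_manager,
    vector_store,  # assumed async-capable vector store
    cleanup="full",
)
# result -> {"num_added": ..., "num_updated": ..., "num_skipped": ..., "num_deleted": ...}
```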

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>

@@ -13,13 +13,14 @@ Importantly, this keeps on working even if the content being written is derived
via a set of transformations from some source content (e.g., indexing children
documents that were derived from parent documents by chunking.)
"""
from langchain.indexes._api import IndexingResult, index
from langchain.indexes._api import IndexingResult, aindex, index
from langchain.indexes._sql_record_manager import SQLRecordManager
from langchain.indexes.graph import GraphIndexCreator
from langchain.indexes.vectorstore import VectorstoreIndexCreator
__all__ = [
# Keep sorted
"aindex",
"GraphIndexCreator",
"index",
"IndexingResult",

@@ -7,6 +7,8 @@ import uuid
from itertools import islice
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Callable,
Dict,
Iterable,
@@ -15,6 +17,7 @@ from typing import (
Literal,
Optional,
Sequence,
Set,
TypedDict,
TypeVar,
Union,
@@ -36,7 +39,7 @@ def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
return uuid.uuid5(NAMESPACE_UUID, hash_value)
def _hash_nested_dict_to_uuid(data: dict) -> uuid.UUID:
def _hash_nested_dict_to_uuid(data: dict[Any, Any]) -> uuid.UUID:
"""Hashes a nested dictionary and returns the corresponding UUID."""
serialized_data = json.dumps(data, sort_keys=True)
hash_value = hashlib.sha1(serialized_data.encode("utf-8")).hexdigest()
@@ -118,6 +121,21 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
yield chunk
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]:
"""Utility batching function."""
batch: List[T] = []
async for element in iterable:
if len(batch) < size:
batch.append(element)
if len(batch) >= size:
yield batch
batch = []
if batch:
yield batch
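# Usage sketch (as exercised in the unit tests; `_to_async_iterator` below
# can wrap a plain iterable for illustration):
#
#     async for batch in _abatch(5, _to_async_iterator(range(12))):
#         ...  # yields [0, 1, 2, 3, 4], then [5, 6, 7, 8, 9], then [10, 11]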
def _get_source_id_assigner(
source_id_key: Union[str, Callable[[Document], str], None],
) -> Callable[[Document], Union[str, None]]:
@@ -139,7 +157,7 @@ def _deduplicate_in_order(
hashed_documents: Iterable[_HashedDocument],
) -> Iterator[_HashedDocument]:
"""Deduplicate a list of hashed documents while preserving order."""
seen = set()
seen: Set[str] = set()
for hashed_doc in hashed_documents:
if hashed_doc.hash_ not in seen:
@@ -346,3 +364,203 @@ def index(
"num_skipped": num_skipped,
"num_deleted": num_deleted,
}
# Define an asynchronous generator function
async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
"""Convert an iterable to an async iterator."""
for item in iterator:
yield item
async def aindex(
docs_source: Union[Iterable[Document], AsyncIterator[Document]],
record_manager: RecordManager,
vector_store: VectorStore,
*,
batch_size: int = 100,
cleanup: Literal["incremental", "full", None] = None,
source_id_key: Union[str, Callable[[Document], str], None] = None,
cleanup_batch_size: int = 1_000,
) -> IndexingResult:
"""Index data from the loader into the vector store.
Indexing functionality uses a manager to keep track of which documents
are in the vector store.
This allows us to keep track of which documents were updated, which
documents were deleted, and which documents should be skipped.
For the time being, documents are indexed using their hashes, and users
are not able to specify the uid of the document.
IMPORTANT:
if auto_cleanup is set to True, the loader should be returning
the entire dataset, and not just a subset of the dataset.
Otherwise, the auto_cleanup will remove documents that it is not
supposed to.
Args:
docs_source: Data loader or iterable of documents to index.
record_manager: Timestamped set to keep track of which documents were
updated.
vector_store: Vector store to index the documents into.
batch_size: Batch size to use when indexing.
cleanup: How to handle clean up of documents.
- Incremental: Cleans up all documents that haven't been updated AND
that are associated with source ids that were seen
during indexing.
Clean up is done continuously during indexing, helping
to minimize the probability of users seeing duplicated
content.
- Full: Delete all documents that haven't been returned by the loader.
Clean up runs after all documents have been indexed.
This means that users may see duplicated content during indexing.
- None: Do not delete any documents.
source_id_key: Optional key that helps identify the original source
of the document.
cleanup_batch_size: Batch size to use when cleaning up documents.
Returns:
Indexing result which contains information about how many documents
were added, updated, deleted, or skipped.
"""
if cleanup not in {"incremental", "full", None}:
raise ValueError(
f"cleanup should be one of 'incremental', 'full' or None. "
f"Got {cleanup}."
)
if cleanup == "incremental" and source_id_key is None:
raise ValueError("Source id key is required when cleanup mode is incremental.")
# Check that the Vectorstore has required methods implemented
methods = ["adelete", "aadd_documents"]
for method in methods:
if not hasattr(vector_store, method):
raise ValueError(
f"Vectorstore {vector_store} does not have required method {method}"
)
if type(vector_store).adelete == VectorStore.adelete:
# Checking if the vectorstore has overridden the default adelete method
# implementation, which just raises a NotImplementedError.
raise ValueError("Vectorstore has not implemented the adelete method")
if isinstance(docs_source, BaseLoader):
raise NotImplementedError(
"Not supported yet. Please pass an async iterator of documents."
)
async_doc_iterator: AsyncIterator[Document]
if hasattr(docs_source, "__aiter__"):
async_doc_iterator = docs_source # type: ignore[assignment]
else:
async_doc_iterator = _to_async_iterator(docs_source)
source_id_assigner = _get_source_id_assigner(source_id_key)
# Mark when the update started.
index_start_dt = await record_manager.aget_time()
num_added = 0
num_skipped = 0
num_updated = 0
num_deleted = 0
async for doc_batch in _abatch(batch_size, async_doc_iterator):
hashed_docs = list(
_deduplicate_in_order(
[_HashedDocument.from_document(doc) for doc in doc_batch]
)
)
source_ids: Sequence[Optional[str]] = [
source_id_assigner(doc) for doc in hashed_docs
]
if cleanup == "incremental":
# If the cleanup mode is incremental, source ids are required.
for source_id, hashed_doc in zip(source_ids, hashed_docs):
if source_id is None:
raise ValueError(
"Source ids are required when cleanup mode is incremental. "
f"Document that starts with "
f"content: {hashed_doc.page_content[:100]} was not assigned "
f"as source id."
)
# Source ids cannot be None after the for loop above.
source_ids = cast(Sequence[str], source_ids)
exists_batch = await record_manager.aexists([doc.uid for doc in hashed_docs])
# Filter out documents that already exist in the record store.
uids: list[str] = []
docs_to_index: list[Document] = []
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
if doc_exists:
# Must be updated to refresh timestamp.
await record_manager.aupdate(
[hashed_doc.uid], time_at_least=index_start_dt
)
num_skipped += 1
continue
uids.append(hashed_doc.uid)
docs_to_index.append(hashed_doc.to_document())
# Be pessimistic and assume that all vector store writes will fail.
# First write to vector store
if docs_to_index:
await vector_store.aadd_documents(docs_to_index, ids=uids)
num_added += len(docs_to_index)
# And only then update the record store.
# Update ALL records, even if they already exist since we want to refresh
# their timestamp.
await record_manager.aupdate(
[doc.uid for doc in hashed_docs],
group_ids=source_ids,
time_at_least=index_start_dt,
)
# If source IDs are provided, we can do the deletion incrementally!
if cleanup == "incremental":
# Get the uids of the documents that were not returned by the loader.
# mypy isn't good enough to determine that source ids cannot be None
# here due to a check that's happening above, so we check again.
for source_id in source_ids:
if source_id is None:
raise AssertionError("Source ids cannot be None here.")
_source_ids = cast(Sequence[str], source_ids)
uids_to_delete = await record_manager.alist_keys(
group_ids=_source_ids, before=index_start_dt
)
if uids_to_delete:
# First delete from vector store.
await vector_store.adelete(uids_to_delete)
# Then delete from record store.
await record_manager.adelete_keys(uids_to_delete)
num_deleted += len(uids_to_delete)
if cleanup == "full":
while uids_to_delete := await record_manager.alist_keys(
before=index_start_dt, limit=cleanup_batch_size
):
# First delete from vector store.
await vector_store.adelete(uids_to_delete)
# Then delete from record manager.
await record_manager.adelete_keys(uids_to_delete)
num_deleted += len(uids_to_delete)
return {
"num_added": num_added,
"num_updated": num_updated,
"num_skipped": num_skipped,
"num_deleted": num_deleted,
}
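# Example call (a sketch; `loader_docs`, `record_manager`, and `vector_store`
# are assumed to be set up as in the unit tests):
#
#     result = await aindex(
#         loader_docs,
#         record_manager,
#         vector_store,
#         cleanup="incremental",
#         source_id_key="source",
#     )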

@@ -16,7 +16,7 @@ allow it to work with a variety of SQL as a backend.
import contextlib
import decimal
import uuid
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence, Union
from sqlalchemy import (
URL,
@@ -28,8 +28,16 @@ from sqlalchemy import (
UniqueConstraint,
and_,
create_engine,
delete,
select,
text,
)
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
@@ -77,9 +85,10 @@ class SQLRecordManager(RecordManager):
self,
namespace: str,
*,
engine: Optional[Engine] = None,
engine: Optional[Union[Engine, AsyncEngine]] = None,
db_url: Union[None, str, URL] = None,
engine_kwargs: Optional[Dict[str, Any]] = None,
async_mode: bool = False,
) -> None:
"""Initialize the SQLRecordManager.
@@ -95,6 +104,10 @@ class SQLRecordManager(RecordManager):
an SQL Alchemy engine. Default is None.
engine_kwargs: Additional keyword arguments
to be passed when creating the engine. Default is an empty dictionary.
async_mode: Whether to create an async engine.
Driver should support async operations.
It only applies if db_url is provided.
Default is False.
Raises:
ValueError: If both db_url and engine are provided or neither.
@@ -103,34 +116,71 @@ class SQLRecordManager(RecordManager):
super().__init__(namespace=namespace)
if db_url is None and engine is None:
raise ValueError("Must specify either db_url or engine")
if db_url is not None and engine is not None:
raise ValueError("Must specify either db_url or engine, not both")
_engine: Union[Engine, AsyncEngine]
if db_url:
_kwargs = engine_kwargs or {}
_engine = create_engine(db_url, **_kwargs)
if async_mode:
_engine = create_async_engine(db_url, **(engine_kwargs or {}))
else:
_engine = create_engine(db_url, **(engine_kwargs or {}))
elif engine:
_engine = engine
else:
raise AssertionError("Something went wrong with configuration of engine.")
_session_factory: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
if isinstance(_engine, AsyncEngine):
_session_factory = async_sessionmaker(bind=_engine)
else:
_session_factory = sessionmaker(bind=_engine)
self.engine = _engine
self.dialect = _engine.dialect.name
self.session_factory = sessionmaker(bind=self.engine)
self.session_factory = _session_factory
def create_schema(self) -> None:
"""Create the database schema."""
if isinstance(self.engine, AsyncEngine):
raise AssertionError("This method is not supported for async engines.")
Base.metadata.create_all(self.engine)
async def acreate_schema(self) -> None:
"""Create the database schema."""
if not isinstance(self.engine, AsyncEngine):
raise AssertionError("This method is not supported for sync engines.")
async with self.engine.begin() as session:
await session.run_sync(Base.metadata.create_all)
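# Example (mirrors the test fixtures; requires the optional aiosqlite driver):
#
#     manager = SQLRecordManager(
#         "kittens", db_url="sqlite+aiosqlite:///:memory:", async_mode=True
#     )
#     await manager.acreate_schema()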
@contextlib.contextmanager
def _make_session(self) -> Generator[Session, None, None]:
"""Create a session and close it after use."""
if isinstance(self.session_factory, async_sessionmaker):
raise AssertionError("This method is not supported for async engines.")
session = self.session_factory()
try:
yield session
finally:
session.close()
@contextlib.asynccontextmanager
async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Create a session and close it after use."""
if not isinstance(self.session_factory, async_sessionmaker):
raise AssertionError("This method is not supported for sync engines.")
async with self.session_factory() as session:
yield session
def get_time(self) -> float:
"""Get the current server time as a timestamp.
@@ -161,6 +211,37 @@ class SQLRecordManager(RecordManager):
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
return dt
async def aget_time(self) -> float:
"""Get the current server time as a timestamp.
Please note it's critical that time is obtained from the server since
we want a monotonic clock.
"""
async with self._amake_session() as session:
# * SQLite specific implementation, can be changed based on dialect.
# * For SQLite, julianday (unlike unixepoch) works with older versions of SQLite.
# ----
# julianday('now'): Julian day number for the current date and time.
# The Julian day is a continuous count of days, starting from a
# reference date (Julian day number 0).
# 2440587.5 - constant represents the Julian day number for January 1, 1970
# 86400.0 - constant represents the number of seconds
# in a day (24 hours * 60 minutes * 60 seconds)
if self.dialect == "sqlite":
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
elif self.dialect == "postgresql":
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
else:
raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
dt = (await session.execute(query)).scalar_one_or_none()
if isinstance(dt, decimal.Decimal):
dt = float(dt)
if not isinstance(dt, float):
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
return dt
def update(
self,
keys: Sequence[str],
@@ -236,6 +317,81 @@ class SQLRecordManager(RecordManager):
session.execute(stmt)
session.commit()
async def aupdate(
self,
keys: Sequence[str],
*,
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
"""Upsert records into the SQLite database."""
if group_ids is None:
group_ids = [None] * len(keys)
if len(keys) != len(group_ids):
raise ValueError(
f"Number of keys ({len(keys)}) does not match number of "
f"group_ids ({len(group_ids)})"
)
# Get the current time from the server.
# This makes an extra round trip to the server, should not be a big deal
# if the batch size is large enough.
# Getting the time here helps us compare it against the time_at_least
# and raise an error if there is a time sync issue.
# Here, we're just being extra careful to minimize the chance of
# data loss due to incorrectly deleting records.
update_time = await self.aget_time()
if time_at_least and update_time < time_at_least:
# Safeguard against time sync issues
raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")
records_to_upsert = [
{
"key": key,
"namespace": self.namespace,
"updated_at": update_time,
"group_id": group_id,
}
for key, group_id in zip(keys, group_ids)
]
async with self._amake_session() as session:
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses PostgreSQL insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
else:
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
await session.execute(stmt)
await session.commit()
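# Example (a sketch, as exercised in the unit tests):
#
#     await manager.aupdate(["key1"], group_ids=["group1"])
#     assert await manager.aexists(["key1"]) == [True]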
def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the given keys exist in the SQLite database."""
with self._make_session() as session:
@@ -253,6 +409,26 @@ class SQLRecordManager(RecordManager):
found_keys = set(r.key for r in records)
return [k in found_keys for k in keys]
async def aexists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the given keys exist in the SQLite database."""
async with self._amake_session() as session:
records = (
(
await session.execute(
select(UpsertionRecord.key).where(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
)
)
)
.scalars()
.all()
)
found_keys = set(records)
return [k in found_keys for k in keys]
def list_keys(
self,
*,
@@ -286,6 +462,39 @@ class SQLRecordManager(RecordManager):
records = query.all() # type: ignore[attr-defined]
return [r.key for r in records]
async def alist_keys(
self,
*,
before: Optional[float] = None,
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
"""List records in the SQLite database based on the provided date range."""
async with self._amake_session() as session:
query = select(UpsertionRecord.key).filter(
UpsertionRecord.namespace == self.namespace
)
# mypy does not recognize .all() or .filter()
if after:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at > after
)
if before:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at < before
)
if group_ids:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.group_id.in_(group_ids)
)
if limit:
query = query.limit(limit) # type: ignore[attr-defined]
records = (await session.execute(query)).scalars().all()
return list(records)
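# Example (a sketch; `cutoff_ts` is a hypothetical float timestamp as
# returned by aget_time or datetime.timestamp()):
#
#     stale = await manager.alist_keys(group_ids=["group1"], before=cutoff_ts)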
def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete records from the SQLite database."""
with self._make_session() as session:
@@ -297,3 +506,17 @@ class SQLRecordManager(RecordManager):
)
).delete() # type: ignore[attr-defined]
session.commit()
async def adelete_keys(self, keys: Sequence[str]) -> None:
"""Delete records from the SQLite database."""
async with self._amake_session() as session:
await session.execute(
delete(UpsertionRecord).where(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
)
)
await session.commit()

@@ -25,6 +25,10 @@ class RecordManager(ABC):
def create_schema(self) -> None:
"""Create the database schema for the record manager."""
@abstractmethod
async def acreate_schema(self) -> None:
"""Create the database schema for the record manager."""
@abstractmethod
def get_time(self) -> float:
"""Get the current server time as a high resolution timestamp!
@@ -36,6 +40,17 @@ class RecordManager(ABC):
The current server time as a float timestamp.
"""
@abstractmethod
async def aget_time(self) -> float:
"""Get the current server time as a high resolution timestamp!
It's important to get this from the server to ensure a monotonic clock,
otherwise there may be data loss when cleaning up old documents!
Returns:
The current server time as a float timestamp.
"""
@abstractmethod
def update(
self,
@@ -56,6 +71,26 @@ class RecordManager(ABC):
ValueError: If the length of keys doesn't match the length of group_ids.
"""
@abstractmethod
async def aupdate(
self,
keys: Sequence[str],
*,
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
"""Upsert records into the database.
Args:
keys: A list of record keys to upsert.
group_ids: A list of group IDs corresponding to the keys.
time_at_least: if provided, updates should only happen if the
updated_at field is at least this time.
Raises:
ValueError: If the length of keys doesn't match the length of group_ids.
"""
@abstractmethod
def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the provided keys exist in the database.
@@ -67,6 +102,17 @@ class RecordManager(ABC):
A list of boolean values indicating the existence of each key.
"""
@abstractmethod
async def aexists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the provided keys exist in the database.
Args:
keys: A list of keys to check.
Returns:
A list of boolean values indicating the existence of each key.
"""
@abstractmethod
def list_keys(
self,
@@ -88,6 +134,27 @@ class RecordManager(ABC):
A list of keys for the matching records.
"""
@abstractmethod
async def alist_keys(
self,
*,
before: Optional[float] = None,
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
"""List records in the database based on the provided filters.
Args:
before: Filter to list records updated before this time.
after: Filter to list records updated after this time.
group_ids: Filter to list records with specific group IDs.
limit: optional limit on the number of records to return.
Returns:
A list of keys for the matching records.
"""
@abstractmethod
def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete specified records from the database.
@@ -95,3 +162,11 @@ class RecordManager(ABC):
Args:
keys: A list of keys to delete.
"""
@abstractmethod
async def adelete_keys(self, keys: Sequence[str]) -> None:
"""Delete specified records from the database.
Args:
keys: A list of keys to delete.
"""

@@ -80,6 +80,22 @@ class VectorStore(ABC):
raise NotImplementedError("delete method must be implemented by subclass.")
async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
) -> Optional[bool]:
"""Delete by vector ID or other criteria.
Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments that subclasses might use.
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
raise NotImplementedError("delete method must be implemented by subclass.")
async def aadd_texts(
self,
texts: Iterable[str],

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -226,6 +226,21 @@ files = [
[package.dependencies]
frozenlist = ">=1.1.0"
[[package]]
name = "aiosqlite"
version = "0.19.0"
description = "asyncio bridge to the standard sqlite3 module"
optional = true
python-versions = ">=3.7"
files = [
{file = "aiosqlite-0.19.0-py3-none-any.whl", hash = "sha256:edba222e03453e094a3ce605db1b970c4b3376264e56f32e2a4959f948d66a96"},
{file = "aiosqlite-0.19.0.tar.gz", hash = "sha256:95ee77b91c8d2808bd08a59fbebf66270e9090c3d92ffbf260dc0db0b979577d"},
]
[package.extras]
dev = ["aiounittest (==1.4.1)", "attribution (==1.6.2)", "black (==23.3.0)", "coverage[toml] (==7.2.3)", "flake8 (==5.0.4)", "flake8-bugbear (==23.3.12)", "flit (==3.7.1)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"]
docs = ["sphinx (==6.1.3)", "sphinx-mdinclude (==0.5.3)"]
[[package]]
name = "aleph-alpha-client"
version = "2.17.0"
@@ -5768,12 +5783,11 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
{version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
{version = ">=1.17.0", markers = "python_version >= \"3.7\""},
{version = ">=1.17.3", markers = "python_version >= \"3.8\""},
{version = ">=1.21.2", markers = "python_version >= \"3.10\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
{version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\" and python_version >= \"3.8\""},
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.17.3", markers = "(platform_system != \"Darwin\" and platform_system != \"Linux\") and python_version >= \"3.8\" and python_version < \"3.9\" or platform_system != \"Darwin\" and python_version >= \"3.8\" and python_version < \"3.9\" and platform_machine != \"aarch64\" or platform_machine != \"arm64\" and python_version >= \"3.8\" and python_version < \"3.9\" and platform_system != \"Linux\" or (platform_machine != \"arm64\" and platform_machine != \"aarch64\") and python_version >= \"3.8\" and python_version < \"3.9\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
]
@@ -5961,7 +5975,7 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
]
python-dateutil = ">=2.8.2"
@@ -8577,11 +8591,6 @@ files = [
{file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"},
{file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"},
{file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"},
{file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"},
{file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"},
{file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"},
{file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"},
{file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"},
{file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"},
{file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"},
{file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"},
@@ -9029,7 +9038,7 @@ files = [
]
[package.dependencies]
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""}
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
typing-extensions = ">=4.2.0"
[package.extras]
@@ -10854,7 +10863,7 @@ cli = ["typer"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "xata", "xmltodict"]
extended-testing = ["aiosqlite", "amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "xata", "xmltodict"]
javascript = ["esprima"]
llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
openai = ["openai", "tiktoken"]
@@ -10864,4 +10873,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "7fbe9a5144717db54413735663870168b00e34deb4f37559e38d62843488adae"
content-hash = "8458ce2704b1fcba33f5b6bb8cb3d6fdf4d6b9a2006563f461a3edf7b8dd0d17"

@@ -137,6 +137,7 @@ jsonpatch = "^1.33"
timescale-vector = {version = "^0.0.1", optional = true}
typer = {version= "^0.9.0", optional = true}
anthropic = {version = "^0.3.11", optional = true}
aiosqlite = {version = "^0.19.0", optional = true}
[tool.poetry.group.test.dependencies]
@@ -314,6 +315,7 @@ cli = [
# merge-conflicts
extended_testing = [
"amazon-textract-caller",
"aiosqlite",
"assemblyai",
"beautifulsoup4",
"bibtexparser",

@@ -5,6 +5,7 @@ def test_all() -> None:
"""Use to catch obvious breaking changes."""
assert __all__ == sorted(__all__, key=str.lower)
assert __all__ == [
"aindex",
"GraphIndexCreator",
"index",
"IndexingResult",

@@ -1,14 +1,26 @@
from datetime import datetime
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Type
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Type,
)
from unittest.mock import patch
import pytest
import pytest_asyncio
from langchain.document_loaders.base import BaseLoader
from langchain.indexes import index
from langchain.embeddings.base import Embeddings
from langchain.indexes import aindex, index
from langchain.indexes._api import _abatch
from langchain.indexes._sql_record_manager import SQLRecordManager
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VST, VectorStore
@@ -28,6 +40,19 @@ class ToyLoader(BaseLoader):
"""Load the documents from the source."""
return list(self.lazy_load())
async def alazy_load(
self,
) -> AsyncIterator[Document]:
async def async_generator() -> AsyncIterator[Document]:
for document in self.documents:
yield document
return async_generator()
async def aload(self) -> List[Document]:
"""Load the documents from the source."""
return [doc async for doc in await self.alazy_load()]
class InMemoryVectorStore(VectorStore):
"""In-memory implementation of VectorStore using a dictionary."""
@@ -42,6 +67,12 @@ class InMemoryVectorStore(VectorStore):
for _id in ids:
self.store.pop(_id, None)
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
"""Delete the given documents from the store using their IDs."""
if ids:
for _id in ids:
self.store.pop(_id, None)
def add_documents( # type: ignore
self,
documents: Sequence[Document],
@@ -65,10 +96,33 @@ class InMemoryVectorStore(VectorStore):
)
self.store[_id] = document
async def aadd_documents(
self,
documents: Sequence[Document],
*,
ids: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> List[str]:
if ids and len(ids) != len(documents):
raise ValueError(
f"Expected {len(ids)} ids, got {len(documents)} documents."
)
if not ids:
raise NotImplementedError("This is not implemented yet.")
for _id, document in zip(ids, documents):
if _id in self.store:
raise ValueError(
f"Document with uid {_id} already exists in the store."
)
self.store[_id] = document
return list(ids)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
metadatas: Optional[List[Dict[Any, Any]]] = None,
**kwargs: Any,
) -> List[str]:
"""Add the given texts to the store (insert behavior)."""
@@ -79,7 +133,7 @@ class InMemoryVectorStore(VectorStore):
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
metadatas: Optional[List[Dict[Any, Any]]] = None,
**kwargs: Any,
) -> VST:
"""Create a vector store from a list of texts."""
@@ -100,6 +154,19 @@ def record_manager() -> SQLRecordManager:
return record_manager
@pytest_asyncio.fixture # type: ignore
@pytest.mark.requires("aiosqlite")
async def arecord_manager() -> SQLRecordManager:
"""Timestamped set fixture."""
record_manager = SQLRecordManager(
"kittens",
db_url="sqlite+aiosqlite:///:memory:",
async_mode=True,
)
await record_manager.acreate_schema()
return record_manager
@pytest.fixture
def vector_store() -> InMemoryVectorStore:
"""Vector store fixture."""
@@ -140,6 +207,44 @@ def test_indexing_same_content(
}
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aindexing_same_content(
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Indexing some content to confirm it gets added only once."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
),
Document(
page_content="This is another document.",
),
]
)
assert await aindex(await loader.alazy_load(), arecord_manager, vector_store) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
assert len(list(vector_store.store)) == 2
for _ in range(2):
# Run the indexing again
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
def test_index_simple_delete_full(
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
@@ -215,6 +320,91 @@ def test_index_simple_delete_full(
}
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aindex_simple_delete_full(
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Indexing some content to confirm it gets added only once."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
),
Document(
page_content="This is another document.",
),
]
)
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
loader = ToyLoader(
documents=[
Document(
page_content="mutated document 1",
),
Document(
page_content="This is another document.", # <-- Same as original
),
]
)
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 1,
"num_deleted": 1,
"num_skipped": 1,
"num_updated": 0,
}
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"mutated document 1", "This is another document."}
# Attempt to index again; verify that nothing changes
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
def test_incremental_fails_with_bad_source_ids(
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
@@ -251,6 +441,49 @@ def test_incremental_fails_with_bad_source_ids(
)
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aincremental_fails_with_bad_source_ids(
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Test indexing with incremental deletion strategy."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
Document(
page_content="This is yet another document.",
metadata={"source": None},
),
]
)
with pytest.raises(ValueError):
# Should raise an error because no source id function was specified
await aindex(
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
)
with pytest.raises(ValueError):
# Should raise an error because one of the documents has a None source id
await aindex(
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
source_id_key="source",
)
def test_no_delete(
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
@@ -332,6 +565,89 @@ def test_no_delete(
}
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_ano_delete(
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Test indexing without a deletion strategy."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
]
)
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup=None,
source_id_key="source",
) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
# If we add the same content twice it should be skipped
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup=None,
source_id_key="source",
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
loader = ToyLoader(
documents=[
Document(
page_content="mutated content",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
]
)
# Should result in no updates or deletions!
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup=None,
source_id_key="source",
) == {
"num_added": 1,
"num_deleted": 0,
"num_skipped": 1,
"num_updated": 0,
}
def test_incremental_delete(
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
@@ -436,6 +752,112 @@ def test_incremental_delete(
}
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aincremental_delete(
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Test indexing with incremental deletion strategy."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
]
)
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
source_id_key="source",
) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"This is another document.", "This is a test document."}
# Attempt to index again; verify that nothing changes
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
source_id_key="source",
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
# Create 2 mutated documents from the same source, plus one unchanged document
loader = ToyLoader(
documents=[
Document(
page_content="mutated document 1",
metadata={"source": "1"},
),
Document(
page_content="mutated document 2",
metadata={"source": "1"},
),
Document(
page_content="This is another document.", # <-- Same as original
metadata={"source": "2"},
),
]
)
# Index the mutated documents; one stale document should be deleted
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp()
):
assert await aindex(
await loader.alazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
source_id_key="source",
) == {
"num_added": 2,
"num_deleted": 1,
"num_skipped": 1,
"num_updated": 0,
}
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {
"mutated document 1",
"mutated document 2",
"This is another document.",
}
def test_indexing_with_no_docs(
record_manager: SQLRecordManager, vector_store: VectorStore
) -> None:
@@ -450,6 +872,24 @@ def test_indexing_with_no_docs(
}
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aindexing_with_no_docs(
arecord_manager: SQLRecordManager, vector_store: VectorStore
) -> None:
"""Check edge case when loader returns no new docs."""
loader = ToyLoader(documents=[])
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
def test_deduplication(
record_manager: SQLRecordManager, vector_store: VectorStore
) -> None:
@@ -474,6 +914,32 @@ def test_deduplication(
}
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_adeduplication(
arecord_manager: SQLRecordManager, vector_store: VectorStore
) -> None:
"""Check edge case when loader returns no new docs."""
docs = [
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
]
# Should result in only a single document being added
assert await aindex(docs, arecord_manager, vector_store, cleanup="full") == {
"num_added": 1,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
def test_cleanup_with_different_batchsize(
record_manager: SQLRecordManager, vector_store: VectorStore
) -> None:
@@ -511,6 +977,45 @@ def test_cleanup_with_different_batchsize(
}
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_async_cleanup_with_different_batchsize(
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Check that we can clean up with different batch size."""
docs = [
Document(
page_content="This is a test document.",
metadata={"source": str(d)},
)
for d in range(1000)
]
assert await aindex(docs, arecord_manager, vector_store, cleanup="full") == {
"num_added": 1000,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
docs = [
Document(
page_content="Different doc",
metadata={"source": str(d)},
)
for d in range(1001)
]
assert await aindex(
docs, arecord_manager, vector_store, cleanup="full", cleanup_batch_size=17
) == {
"num_added": 1001,
"num_deleted": 1000,
"num_skipped": 0,
"num_updated": 0,
}
def test_deduplication_v2(
record_manager: SQLRecordManager, vector_store: VectorStore
) -> None:
@@ -547,3 +1052,29 @@ def test_deduplication_v2(
[document.page_content for document in vector_store.store.values()]
)
assert contents == ["1", "2", "3"]
async def _to_async_iter(it: Iterable[Any]) -> AsyncIterator[Any]:
"""Convert an iterable to an async iterator."""
for i in it:
yield i
@pytest.mark.asyncio
async def test_abatch() -> None:
"""Test the abatch function."""
batches = _abatch(5, _to_async_iter(range(12)))
assert isinstance(batches, AsyncIterator)
assert [batch async for batch in batches] == [
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11],
]
batches = _abatch(1, _to_async_iter(range(3)))
assert isinstance(batches, AsyncIterator)
assert [batch async for batch in batches] == [[0], [1], [2]]
batches = _abatch(2, _to_async_iter(range(5)))
assert isinstance(batches, AsyncIterator)
assert [batch async for batch in batches] == [[0, 1], [2, 3], [4]]

@@ -2,6 +2,8 @@ from datetime import datetime
from unittest.mock import patch
import pytest
import pytest_asyncio
from sqlalchemy import select
from langchain.indexes._sql_record_manager import SQLRecordManager, UpsertionRecord
@@ -15,6 +17,20 @@ def manager() -> SQLRecordManager:
return record_manager
@pytest_asyncio.fixture # type: ignore
@pytest.mark.requires("aiosqlite")
async def amanager() -> SQLRecordManager:
"""Initialize the test database and yield the TimestampedSet instance."""
# Initialize and yield the TimestampedSet instance
record_manager = SQLRecordManager(
"kittens",
db_url="sqlite+aiosqlite:///:memory:",
async_mode=True,
)
await record_manager.acreate_schema()
return record_manager
def test_update(manager: SQLRecordManager) -> None:
"""Test updating records in the database."""
# no keys should be present in the set
@@ -28,6 +44,21 @@ def test_update(manager: SQLRecordManager) -> None:
assert read_keys == ["key1", "key2", "key3"]
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aupdate(amanager: SQLRecordManager) -> None:
"""Test updating records in the database."""
# no keys should be present in the set
read_keys = await amanager.alist_keys()
assert read_keys == []
# Insert records
keys = ["key1", "key2", "key3"]
await amanager.aupdate(keys)
# Retrieve the records
read_keys = await amanager.alist_keys()
assert read_keys == ["key1", "key2", "key3"]
def test_update_timestamp(manager: SQLRecordManager) -> None:
"""Test updating records in the database."""
# no keys should be present in the set
@@ -119,6 +150,117 @@ def test_update_timestamp(manager: SQLRecordManager) -> None:
]
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aupdate_timestamp(amanager: SQLRecordManager) -> None:
"""Test updating records in the database."""
# no keys should be present in the set
with patch.object(
amanager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
await amanager.aupdate(["key1"])
async with amanager._amake_session() as session:
records = (
(
await session.execute(
select(UpsertionRecord).filter(
UpsertionRecord.namespace == amanager.namespace
)
)
)
.scalars()
.all()
)
assert [
{
"key": record.key,
"namespace": record.namespace,
"updated_at": record.updated_at,
"group_id": record.group_id,
}
for record in records
] == [
{
"group_id": None,
"key": "key1",
"namespace": "kittens",
"updated_at": datetime(2021, 1, 2, 0, 0).timestamp(),
}
]
with patch.object(
amanager, "aget_time", return_value=datetime(2023, 1, 2).timestamp()
):
await amanager.aupdate(["key1"])
async with amanager._amake_session() as session:
records = (
(
await session.execute(
select(UpsertionRecord).filter(
UpsertionRecord.namespace == amanager.namespace
)
)
)
.scalars()
.all()
)
assert [
{
"key": record.key,
"namespace": record.namespace,
"updated_at": record.updated_at,
"group_id": record.group_id,
}
for record in records
] == [
{
"group_id": None,
"key": "key1",
"namespace": "kittens",
"updated_at": datetime(2023, 1, 2, 0, 0).timestamp(),
}
]
with patch.object(
amanager, "aget_time", return_value=datetime(2023, 2, 2).timestamp()
):
await amanager.aupdate(["key1"], group_ids=["group1"])
async with amanager._amake_session() as session:
records = (
(
await session.execute(
select(UpsertionRecord).filter(
UpsertionRecord.namespace == amanager.namespace
)
)
)
.scalars()
.all()
)
assert [
{
"key": record.key,
"namespace": record.namespace,
"updated_at": record.updated_at,
"group_id": record.group_id,
}
for record in records
] == [
{
"group_id": "group1",
"key": "key1",
"namespace": "kittens",
"updated_at": datetime(2023, 2, 2, 0, 0).timestamp(),
}
]
def test_update_with_group_ids(manager: SQLRecordManager) -> None:
"""Test updating records in the database."""
# no keys should be present in the set
@@ -132,6 +274,21 @@ def test_update_with_group_ids(manager: SQLRecordManager) -> None:
assert read_keys == ["key1", "key2", "key3"]
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aupdate_with_group_ids(amanager: SQLRecordManager) -> None:
"""Test updating records in the database."""
# no keys should be present in the set
read_keys = await amanager.alist_keys()
assert read_keys == []
# Insert records
keys = ["key1", "key2", "key3"]
await amanager.aupdate(keys)
# Retrieve the records
read_keys = await amanager.alist_keys()
assert read_keys == ["key1", "key2", "key3"]
def test_exists(manager: SQLRecordManager) -> None:
"""Test checking if keys exist in the database."""
# Insert records
@@ -147,6 +304,23 @@ def test_exists(manager: SQLRecordManager) -> None:
assert exists == [True, False]
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aexists(amanager: SQLRecordManager) -> None:
"""Test checking if keys exist in the database."""
# Insert records
keys = ["key1", "key2", "key3"]
await amanager.aupdate(keys)
# Check if the keys exist in the database
exists = await amanager.aexists(keys)
assert len(exists) == len(keys)
assert exists == [True, True, True]
exists = await amanager.aexists(["key1", "key4"])
assert len(exists) == 2
assert exists == [True, False]
def test_list_keys(manager: SQLRecordManager) -> None:
"""Test listing keys based on the provided date range."""
# Insert records
@@ -234,6 +408,98 @@ def test_list_keys(manager: SQLRecordManager) -> None:
) == ["key4"]
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_alist_keys(amanager: SQLRecordManager) -> None:
"""Test listing keys based on the provided date range."""
# Insert records
assert await amanager.alist_keys() == []
async with amanager._amake_session() as session:
# Add some keys with explicit updated_ats
session.add(
UpsertionRecord(
key="key1",
updated_at=datetime(2021, 1, 1).timestamp(),
namespace="kittens",
)
)
session.add(
UpsertionRecord(
key="key2",
updated_at=datetime(2022, 1, 1).timestamp(),
namespace="kittens",
)
)
session.add(
UpsertionRecord(
key="key3",
updated_at=datetime(2023, 1, 1).timestamp(),
namespace="kittens",
)
)
session.add(
UpsertionRecord(
key="key4",
group_id="group1",
updated_at=datetime(2024, 1, 1).timestamp(),
namespace="kittens",
)
)
# Insert keys from a different namespace; these should not be visible!
session.add(
UpsertionRecord(
key="key1",
updated_at=datetime(2021, 1, 1).timestamp(),
namespace="puppies",
)
)
session.add(
UpsertionRecord(
key="key5",
updated_at=datetime(2021, 1, 1).timestamp(),
namespace="puppies",
)
)
await session.commit()
# Retrieve all keys
assert await amanager.alist_keys() == ["key1", "key2", "key3", "key4"]
# Retrieve keys updated after a certain date
assert await amanager.alist_keys(after=datetime(2022, 2, 1).timestamp()) == [
"key3",
"key4",
]
# Retrieve keys updated before a certain date
assert await amanager.alist_keys(before=datetime(2022, 2, 1).timestamp()) == [
"key1",
"key2",
]
# Retrieve keys updated before a certain date
assert await amanager.alist_keys(before=datetime(2019, 2, 1).timestamp()) == []
# Retrieve keys in a time range
assert await amanager.alist_keys(
before=datetime(2022, 2, 1).timestamp(),
after=datetime(2021, 11, 1).timestamp(),
) == ["key2"]
assert await amanager.alist_keys(group_ids=["group1", "group2"]) == ["key4"]
# Test multiple filters
assert (
await amanager.alist_keys(
group_ids=["group1", "group2"], before=datetime(2019, 1, 1).timestamp()
)
== []
)
assert await amanager.alist_keys(
group_ids=["group1", "group2"], after=datetime(2019, 1, 1).timestamp()
) == ["key4"]
def test_namespace_is_used(manager: SQLRecordManager) -> None:
"""Verify that namespace is taken into account for all operations."""
assert manager.namespace == "kittens"
@@ -261,6 +527,35 @@ def test_namespace_is_used(manager: SQLRecordManager) -> None:
]
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_anamespace_is_used(amanager: SQLRecordManager) -> None:
"""Verify that namespace is taken into account for all operations."""
assert amanager.namespace == "kittens"
async with amanager._amake_session() as session:
# Add some keys with explicit updated_ats
session.add(UpsertionRecord(key="key1", namespace="kittens"))
session.add(UpsertionRecord(key="key2", namespace="kittens"))
session.add(UpsertionRecord(key="key1", namespace="puppies"))
session.add(UpsertionRecord(key="key3", namespace="puppies"))
await session.commit()
assert await amanager.alist_keys() == ["key1", "key2"]
await amanager.adelete_keys(["key1"])
assert await amanager.alist_keys() == ["key2"]
await amanager.aupdate(["key3"], group_ids=["group3"])
async with amanager._amake_session() as session:
results = (await session.execute(select(UpsertionRecord))).scalars().all()
assert sorted([(r.namespace, r.key, r.group_id) for r in results]) == [
("kittens", "key2", None),
("kittens", "key3", "group3"),
("puppies", "key1", None),
("puppies", "key3", None),
]
def test_delete_keys(manager: SQLRecordManager) -> None:
"""Test deleting keys from the database."""
# Insert records
@@ -274,3 +569,20 @@ def test_delete_keys(manager: SQLRecordManager) -> None:
# Check if the deleted keys are no longer in the database
remaining_keys = manager.list_keys()
assert remaining_keys == ["key3"]
@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_adelete_keys(amanager: SQLRecordManager) -> None:
"""Test deleting keys from the database."""
# Insert records
keys = ["key1", "key2", "key3"]
await amanager.aupdate(keys)
# Delete some keys
keys_to_delete = ["key1", "key2"]
await amanager.adelete_keys(keys_to_delete)
# Check if the deleted keys are no longer in the database
remaining_keys = await amanager.alist_keys()
assert remaining_keys == ["key3"]
