mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
585 lines
18 KiB
Python
585 lines
18 KiB
Python
from datetime import datetime
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from sqlalchemy import select
|
|
|
|
from langchain_community.indexes._sql_record_manager import (
|
|
SQLRecordManager,
|
|
UpsertionRecord,
|
|
)
|
|
|
|
|
|
@pytest.fixture()
def manager() -> SQLRecordManager:
    """Provide a synchronous SQLRecordManager for the "kittens" namespace.

    The manager is bound to an in-memory SQLite database and its schema is
    created before the instance is handed to the test.
    """
    sync_manager = SQLRecordManager("kittens", db_url="sqlite:///:memory:")
    sync_manager.create_schema()
    return sync_manager
|
|
|
|
|
|
@pytest_asyncio.fixture  # type: ignore
async def amanager() -> SQLRecordManager:
    """Provide an async-mode SQLRecordManager for the "kittens" namespace.

    The manager uses the ``sqlite+aiosqlite`` driver against an in-memory
    database, and its schema is created before the instance is returned.

    No ``pytest.mark.requires("aiosqlite")`` is applied here: pytest ignores
    marks placed on fixture functions (and recent releases warn about them),
    so the dependency guard has to live on the tests that request this
    fixture.
    """
    async_manager = SQLRecordManager(
        "kittens",
        db_url="sqlite+aiosqlite:///:memory:",
        async_mode=True,
    )
    await async_manager.acreate_schema()
    return async_manager
|
|
|
|
|
|
def test_update(manager: SQLRecordManager) -> None:
    """Check that ``update`` makes new keys visible through ``list_keys``."""
    # A freshly created manager has nothing recorded yet.
    assert manager.list_keys() == []

    # Record three keys in a single call ...
    manager.update(["key1", "key2", "key3"])

    # ... and expect exactly those keys to be listed afterwards.
    assert manager.list_keys() == ["key1", "key2", "key3"]
|
|
|
|
|
|
@pytest.mark.requires("aiosqlite")
async def test_aupdate(amanager: SQLRecordManager) -> None:
    """Check that ``aupdate`` makes new keys visible through ``alist_keys``."""
    # A freshly created manager has nothing recorded yet.
    assert await amanager.alist_keys() == []

    # Record three keys in a single call ...
    await amanager.aupdate(["key1", "key2", "key3"])

    # ... and expect exactly those keys to be listed afterwards.
    assert await amanager.alist_keys() == ["key1", "key2", "key3"]
|
|
|
|
|
|
def test_update_timestamp(manager: SQLRecordManager) -> None:
    """Check the ``updated_at`` and ``group_id`` stored by ``update``.

    ``get_time`` is patched to fixed timestamps so that the single "key1"
    record can be compared exactly after the initial insert, after a plain
    re-update, and after a re-update that also supplies a group id.
    """

    def _stored_rows() -> list:
        """Return every record in the manager's namespace as a plain dict."""
        with manager._make_session() as session:
            rows = (
                session.query(UpsertionRecord)
                .filter(UpsertionRecord.namespace == manager.namespace)
                .all()  # type: ignore[attr-defined]
            )
            return [
                {
                    "key": row.key,
                    "namespace": row.namespace,
                    "updated_at": row.updated_at,
                    "group_id": row.group_id,
                }
                for row in rows
            ]

    # First write: the record is created with the patched time, no group.
    inserted_at = datetime(2021, 1, 2).timestamp()
    with patch.object(manager, "get_time", return_value=inserted_at):
        manager.update(["key1"])
    assert _stored_rows() == [
        {
            "group_id": None,
            "key": "key1",
            "namespace": "kittens",
            "updated_at": inserted_at,
        }
    ]

    # Second write of the same key: still one row, timestamp moves forward.
    refreshed_at = datetime(2023, 1, 2).timestamp()
    with patch.object(manager, "get_time", return_value=refreshed_at):
        manager.update(["key1"])
    assert _stored_rows() == [
        {
            "group_id": None,
            "key": "key1",
            "namespace": "kittens",
            "updated_at": refreshed_at,
        }
    ]

    # Third write supplies a group id: both timestamp and group are updated.
    regrouped_at = datetime(2023, 2, 2).timestamp()
    with patch.object(manager, "get_time", return_value=regrouped_at):
        manager.update(["key1"], group_ids=["group1"])
    assert _stored_rows() == [
        {
            "group_id": "group1",
            "key": "key1",
            "namespace": "kittens",
            "updated_at": regrouped_at,
        }
    ]
|
|
|
|
|
|
@pytest.mark.requires("aiosqlite")
async def test_aupdate_timestamp(amanager: SQLRecordManager) -> None:
    """Check the ``updated_at`` and ``group_id`` stored by ``aupdate``.

    ``aget_time`` is patched to fixed timestamps so that the single "key1"
    record can be compared exactly after the initial insert, after a plain
    re-update, and after a re-update that also supplies a group id.
    """

    async def _stored_rows() -> list:
        """Return every record in the manager's namespace as a plain dict."""
        async with amanager._amake_session() as session:
            result = await session.execute(
                select(UpsertionRecord).filter(
                    UpsertionRecord.namespace == amanager.namespace
                )
            )
            return [
                {
                    "key": row.key,
                    "namespace": row.namespace,
                    "updated_at": row.updated_at,
                    "group_id": row.group_id,
                }
                for row in result.scalars().all()
            ]

    # First write: the record is created with the patched time, no group.
    inserted_at = datetime(2021, 1, 2).timestamp()
    with patch.object(amanager, "aget_time", return_value=inserted_at):
        await amanager.aupdate(["key1"])
    assert await _stored_rows() == [
        {
            "group_id": None,
            "key": "key1",
            "namespace": "kittens",
            "updated_at": inserted_at,
        }
    ]

    # Second write of the same key: still one row, timestamp moves forward.
    refreshed_at = datetime(2023, 1, 2).timestamp()
    with patch.object(amanager, "aget_time", return_value=refreshed_at):
        await amanager.aupdate(["key1"])
    assert await _stored_rows() == [
        {
            "group_id": None,
            "key": "key1",
            "namespace": "kittens",
            "updated_at": refreshed_at,
        }
    ]

    # Third write supplies a group id: both timestamp and group are updated.
    regrouped_at = datetime(2023, 2, 2).timestamp()
    with patch.object(amanager, "aget_time", return_value=regrouped_at):
        await amanager.aupdate(["key1"], group_ids=["group1"])
    assert await _stored_rows() == [
        {
            "group_id": "group1",
            "key": "key1",
            "namespace": "kittens",
            "updated_at": regrouped_at,
        }
    ]
|
|
|
|
|
|
def test_update_with_group_ids(manager: SQLRecordManager) -> None:
    """Check that group ids passed to ``update`` can be used to filter keys.

    Three keys are written with one group id each; ``list_keys`` must then
    return all of them when unfiltered and only the matching ones when
    restricted with ``group_ids``.
    """
    # Nothing has been recorded yet.
    assert manager.list_keys() == []

    # Insert the keys together with their group ids (one group id per key).
    keys = ["key1", "key2", "key3"]
    manager.update(keys, group_ids=["group1", "group1", "group2"])

    # Results are sorted before comparing so the checks do not depend on the
    # order in which the backend happens to return rows.
    assert sorted(manager.list_keys()) == ["key1", "key2", "key3"]
    assert sorted(manager.list_keys(group_ids=["group1"])) == ["key1", "key2"]
    assert manager.list_keys(group_ids=["group2"]) == ["key3"]

    # A group id that was never written matches nothing.
    assert manager.list_keys(group_ids=["group3"]) == []
|
|
|
|
|
|
@pytest.mark.requires("aiosqlite")
async def test_aupdate_with_group_ids(amanager: SQLRecordManager) -> None:
    """Check that group ids passed to ``aupdate`` can be used to filter keys.

    Three keys are written with one group id each; ``alist_keys`` must then
    return all of them when unfiltered and only the matching ones when
    restricted with ``group_ids``.
    """
    # Nothing has been recorded yet.
    assert await amanager.alist_keys() == []

    # Insert the keys together with their group ids (one group id per key).
    keys = ["key1", "key2", "key3"]
    await amanager.aupdate(keys, group_ids=["group1", "group1", "group2"])

    # Results are sorted before comparing so the checks do not depend on the
    # order in which the backend happens to return rows.
    assert sorted(await amanager.alist_keys()) == ["key1", "key2", "key3"]
    assert sorted(await amanager.alist_keys(group_ids=["group1"])) == [
        "key1",
        "key2",
    ]
    assert await amanager.alist_keys(group_ids=["group2"]) == ["key3"]

    # A group id that was never written matches nothing.
    assert await amanager.alist_keys(group_ids=["group3"]) == []
|
|
|
|
|
|
def test_exists(manager: SQLRecordManager) -> None:
    """Check that ``exists`` reports one boolean per queried key."""
    inserted = ["key1", "key2", "key3"]
    manager.update(inserted)

    # Every inserted key is reported as present, one flag per key.
    flags = manager.exists(inserted)
    assert len(flags) == len(inserted)
    assert flags == [True, True, True]

    # "key4" was never inserted, so its flag must be False.
    flags = manager.exists(["key1", "key4"])
    assert len(flags) == 2
    assert flags == [True, False]
|
|
|
|
|
|
@pytest.mark.requires("aiosqlite")
async def test_aexists(amanager: SQLRecordManager) -> None:
    """Check that ``aexists`` reports one boolean per queried key."""
    inserted = ["key1", "key2", "key3"]
    await amanager.aupdate(inserted)

    # Every inserted key is reported as present, one flag per key.
    flags = await amanager.aexists(inserted)
    assert len(flags) == len(inserted)
    assert flags == [True, True, True]

    # "key4" was never inserted, so its flag must be False.
    flags = await amanager.aexists(["key1", "key4"])
    assert len(flags) == 2
    assert flags == [True, False]
|
|
|
|
|
|
def test_list_keys(manager: SQLRecordManager) -> None:
    """Check ``list_keys`` filtering by time bounds, group id and namespace."""

    def _record(
        key: str, year: int, namespace: str = "kittens", **extra: str
    ) -> UpsertionRecord:
        """Build a record whose ``updated_at`` is January 1st of ``year``."""
        return UpsertionRecord(
            key=key,
            updated_at=datetime(year, 1, 1).timestamp(),
            namespace=namespace,
            **extra,
        )

    # Nothing is stored before seeding.
    assert manager.list_keys() == []

    with manager._make_session() as session:
        # Rows are written directly so each gets an explicit updated_at.
        session.add(_record("key1", 2021))
        session.add(_record("key2", 2022))
        session.add(_record("key3", 2023))
        session.add(_record("key4", 2024, group_id="group1"))
        # Rows in a different namespace; these should never be listed.
        session.add(_record("key1", 2021, namespace="puppies"))
        session.add(_record("key5", 2021, namespace="puppies"))
        session.commit()

    # Without filters, every key of the manager's own namespace is returned.
    assert manager.list_keys() == ["key1", "key2", "key3", "key4"]

    # Only keys updated after the cutoff.
    cutoff = datetime(2022, 2, 1).timestamp()
    assert manager.list_keys(after=cutoff) == ["key3", "key4"]

    # Only keys updated before the cutoff.
    assert manager.list_keys(before=cutoff) == ["key1", "key2"]

    # A cutoff earlier than every record yields nothing.
    assert manager.list_keys(before=datetime(2019, 2, 1).timestamp()) == []

    # Both bounds together select a time range.
    window_start = datetime(2021, 11, 1).timestamp()
    assert manager.list_keys(before=cutoff, after=window_start) == ["key2"]

    # Filtering by group id: only key4 carries one of the requested groups.
    assert manager.list_keys(group_ids=["group1", "group2"]) == ["key4"]

    # Group filter combined with a time bound.
    early = datetime(2019, 1, 1).timestamp()
    assert manager.list_keys(group_ids=["group1", "group2"], before=early) == []
    assert manager.list_keys(group_ids=["group1", "group2"], after=early) == [
        "key4"
    ]
|
|
|
|
|
|
@pytest.mark.requires("aiosqlite")
async def test_alist_keys(amanager: SQLRecordManager) -> None:
    """Check ``alist_keys`` filtering by time bounds, group id and namespace."""

    def _record(
        key: str, year: int, namespace: str = "kittens", **extra: str
    ) -> UpsertionRecord:
        """Build a record whose ``updated_at`` is January 1st of ``year``."""
        return UpsertionRecord(
            key=key,
            updated_at=datetime(year, 1, 1).timestamp(),
            namespace=namespace,
            **extra,
        )

    # Nothing is stored before seeding.
    assert await amanager.alist_keys() == []

    async with amanager._amake_session() as session:
        # Rows are written directly so each gets an explicit updated_at.
        session.add(_record("key1", 2021))
        session.add(_record("key2", 2022))
        session.add(_record("key3", 2023))
        session.add(_record("key4", 2024, group_id="group1"))
        # Rows in a different namespace; these should never be listed.
        session.add(_record("key1", 2021, namespace="puppies"))
        session.add(_record("key5", 2021, namespace="puppies"))
        await session.commit()

    # Without filters, every key of the manager's own namespace is returned.
    assert await amanager.alist_keys() == ["key1", "key2", "key3", "key4"]

    # Only keys updated after the cutoff.
    cutoff = datetime(2022, 2, 1).timestamp()
    assert await amanager.alist_keys(after=cutoff) == ["key3", "key4"]

    # Only keys updated before the cutoff.
    assert await amanager.alist_keys(before=cutoff) == ["key1", "key2"]

    # A cutoff earlier than every record yields nothing.
    too_early = datetime(2019, 2, 1).timestamp()
    assert await amanager.alist_keys(before=too_early) == []

    # Both bounds together select a time range.
    window_start = datetime(2021, 11, 1).timestamp()
    assert await amanager.alist_keys(before=cutoff, after=window_start) == [
        "key2"
    ]

    # Filtering by group id: only key4 carries one of the requested groups.
    assert await amanager.alist_keys(group_ids=["group1", "group2"]) == ["key4"]

    # Group filter combined with a time bound.
    early = datetime(2019, 1, 1).timestamp()
    assert (
        await amanager.alist_keys(group_ids=["group1", "group2"], before=early)
        == []
    )
    assert await amanager.alist_keys(
        group_ids=["group1", "group2"], after=early
    ) == ["key4"]
|
|
|
|
|
|
def test_namespace_is_used(manager: SQLRecordManager) -> None:
    """Check that list, delete and update only touch the manager's namespace."""
    assert manager.namespace == "kittens"

    # Seed two namespaces directly; "key1" deliberately exists in both.
    seed = [
        ("kittens", "key1"),
        ("kittens", "key2"),
        ("puppies", "key1"),
        ("puppies", "key3"),
    ]
    with manager._make_session() as session:
        for namespace, key in seed:
            session.add(UpsertionRecord(key=key, namespace=namespace))
        session.commit()

    # Listing and deleting are restricted to "kittens".
    assert manager.list_keys() == ["key1", "key2"]
    manager.delete_keys(["key1"])
    assert manager.list_keys() == ["key2"]

    # Updating "key3" adds a "kittens" row rather than altering the
    # "puppies" row that has the same key.
    manager.update(["key3"], group_ids=["group3"])

    with manager._make_session() as session:
        stored = sorted(
            (row.namespace, row.key, row.group_id)
            for row in session.query(UpsertionRecord).all()
        )

    assert stored == [
        ("kittens", "key2", None),
        ("kittens", "key3", "group3"),
        ("puppies", "key1", None),
        ("puppies", "key3", None),
    ]
|
|
|
|
|
|
@pytest.mark.requires("aiosqlite")
async def test_anamespace_is_used(amanager: SQLRecordManager) -> None:
    """Check that async list, delete and update only touch the namespace."""
    assert amanager.namespace == "kittens"

    # Seed two namespaces directly; "key1" deliberately exists in both.
    seed = [
        ("kittens", "key1"),
        ("kittens", "key2"),
        ("puppies", "key1"),
        ("puppies", "key3"),
    ]
    async with amanager._amake_session() as session:
        for namespace, key in seed:
            session.add(UpsertionRecord(key=key, namespace=namespace))
        await session.commit()

    # Listing and deleting are restricted to "kittens".
    assert await amanager.alist_keys() == ["key1", "key2"]
    await amanager.adelete_keys(["key1"])
    assert await amanager.alist_keys() == ["key2"]

    # Updating "key3" adds a "kittens" row rather than altering the
    # "puppies" row that has the same key.
    await amanager.aupdate(["key3"], group_ids=["group3"])

    async with amanager._amake_session() as session:
        result = await session.execute(select(UpsertionRecord))
        stored = sorted(
            (row.namespace, row.key, row.group_id)
            for row in result.scalars().all()
        )

    assert stored == [
        ("kittens", "key2", None),
        ("kittens", "key3", "group3"),
        ("puppies", "key1", None),
        ("puppies", "key3", None),
    ]
|
|
|
|
|
|
def test_delete_keys(manager: SQLRecordManager) -> None:
    """Check that ``delete_keys`` removes only the requested keys."""
    # Start from three recorded keys.
    manager.update(["key1", "key2", "key3"])

    # Remove two of them.
    manager.delete_keys(["key1", "key2"])

    # Only the key that was not deleted is still listed.
    assert manager.list_keys() == ["key3"]
|
|
|
|
|
|
@pytest.mark.requires("aiosqlite")
async def test_adelete_keys(amanager: SQLRecordManager) -> None:
    """Check that ``adelete_keys`` removes only the requested keys."""
    # Start from three recorded keys.
    await amanager.aupdate(["key1", "key2", "key3"])

    # Remove two of them.
    await amanager.adelete_keys(["key1", "key2"])

    # Only the key that was not deleted is still listed.
    assert await amanager.alist_keys() == ["key3"]
|