You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/langchain_community/indexes/_sql_record_manager.py

523 lines
20 KiB
Python

"""Implementation of a record management layer in SQLAlchemy.

The management layer uses SQLAlchemy to track upserted records.
Currently, this layer supports SQLite and PostgreSQL; however, it should be
adaptable to other SQL implementations with minimal effort.

Because the implementation is built on SQLAlchemy, it should be able to work
with a variety of SQL databases as a backend.

* Each key is associated with an updated_at field.
* This field is updated whenever the key is updated.
* Keys can be listed based on the updated_at field.
* Keys can be deleted.
"""
import contextlib
import decimal
import uuid
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence, Union
from sqlalchemy import (
URL,
Column,
Engine,
Float,
Index,
String,
UniqueConstraint,
and_,
create_engine,
delete,
select,
text,
)
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
from langchain_community.indexes.base import RecordManager
# ``sqlalchemy.ext.declarative.declarative_base`` is a legacy alias since
# SQLAlchemy 2.0 and emits a ``MovedIn20Warning`` when called. This module
# already relies on 2.0-only APIs (e.g. ``async_sessionmaker``), so the
# ``sqlalchemy.orm`` location is always available.
from sqlalchemy.orm import declarative_base as _orm_declarative_base  # noqa: E402

# Declarative base for the ORM models in this module (``UpsertionRecord``);
# its metadata is what ``create_schema``/``acreate_schema`` create.
Base = _orm_declarative_base()
class UpsertionRecord(Base):  # type: ignore[valid-type,misc]
    """Table used to keep track of when a key was last updated.

    Each row stores one ``key`` within a ``namespace``, an optional
    ``group_id``, and the timestamp of the key's last upsert. The pair
    ``(key, namespace)`` is unique.
    """

    # ATTENTION:
    # Prior to modifying this table, please determine whether
    # we should create migrations for this table to make sure
    # users do not experience data loss.

    __tablename__ = "upsertion_record"

    # Surrogate primary key, generated client-side as a random UUID4 string.
    # Inside the lambda, ``uuid`` resolves to the module-level ``uuid`` import,
    # not to this column: names bound in a class body are not visible from
    # nested scopes such as lambdas.
    # NOTE(review): ``index=True`` on a primary key is likely redundant on most
    # backends; left as-is because schema changes need a migration (see above).
    uuid = Column(
        String,
        index=True,
        default=lambda: str(uuid.uuid4()),
        primary_key=True,
        nullable=False,
    )
    key = Column(String, index=True)
    # Using a non-normalized representation to handle `namespace` attribute.
    # If the need arises, this attribute can be pulled into a separate Collection
    # table at some time later.
    namespace = Column(String, index=True, nullable=False)
    # Optional grouping label for a key (e.g. used to filter in list_keys).
    group_id = Column(String, index=True, nullable=True)

    # The timestamp associated with the last record upsertion, expressed as
    # seconds since the UNIX epoch as obtained from the database server.
    updated_at = Column(Float, index=True)

    __table_args__ = (
        # Conflict target for the ``ON CONFLICT DO UPDATE`` upserts issued by
        # SQLRecordManager (referenced by name in the PostgreSQL branch).
        UniqueConstraint("key", "namespace", name="uix_key_namespace"),
        # NOTE(review): may duplicate the index backing the unique constraint
        # on some backends; left unchanged to avoid a schema migration.
        Index("ix_key_namespace", "key", "namespace"),
    )
class SQLRecordManager(RecordManager):
    """A SQL Alchemy based implementation of the record manager.

    Records are stored in the ``upsertion_record`` table and scoped to
    ``self.namespace``. The manager is bound to either a synchronous or an
    asynchronous engine: the plain methods require a sync engine and the
    ``a``-prefixed methods require an async engine; using the wrong flavor
    raises ``AssertionError``. Server time and upserts are implemented for
    the SQLite and PostgreSQL dialects only.
    """

    def __init__(
        self,
        namespace: str,
        *,
        engine: Optional[Union[Engine, AsyncEngine]] = None,
        db_url: Union[None, str, URL] = None,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        async_mode: bool = False,
    ) -> None:
        """Initialize the SQLRecordManager.

        This class serves as a manager persistence layer that uses an SQL
        backend to track upserted records. You should specify either a db_url
        to create an engine or provide an existing engine.

        Args:
            namespace: The namespace associated with this record manager.
            engine: An already existing SQL Alchemy engine.
                Default is None.
            db_url: A database connection string used to create
                an SQL Alchemy engine. Default is None.
            engine_kwargs: Additional keyword arguments
                to be passed when creating the engine. Default is an empty dictionary.
            async_mode: Whether to create an async engine.
                Driver should support async operations.
                It only applies if db_url is provided.
                Default is False.

        Raises:
            ValueError: If both db_url and engine are provided or neither.
            AssertionError: If something unexpected happens during engine configuration.
        """
        super().__init__(namespace=namespace)
        if db_url is None and engine is None:
            raise ValueError("Must specify either db_url or engine")
        if db_url is not None and engine is not None:
            raise ValueError("Must specify either db_url or engine, not both")

        _engine: Union[Engine, AsyncEngine]
        if db_url:
            if async_mode:
                _engine = create_async_engine(db_url, **(engine_kwargs or {}))
            else:
                _engine = create_engine(db_url, **(engine_kwargs or {}))
        elif engine:
            _engine = engine
        else:
            raise AssertionError("Something went wrong with configuration of engine.")

        _session_factory: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
        if isinstance(_engine, AsyncEngine):
            _session_factory = async_sessionmaker(bind=_engine)
        else:
            _session_factory = sessionmaker(bind=_engine)

        self.engine = _engine
        self.dialect = _engine.dialect.name
        self.session_factory = _session_factory

    def create_schema(self) -> None:
        """Create the database schema (requires a sync engine)."""
        if isinstance(self.engine, AsyncEngine):
            raise AssertionError("This method is not supported for async engines.")

        Base.metadata.create_all(self.engine)

    async def acreate_schema(self) -> None:
        """Create the database schema (requires an async engine)."""
        if not isinstance(self.engine, AsyncEngine):
            raise AssertionError("This method is not supported for sync engines.")

        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    @contextlib.contextmanager
    def _make_session(self) -> Generator[Session, None, None]:
        """Create a sync session and close it after use."""
        if isinstance(self.session_factory, async_sessionmaker):
            raise AssertionError("This method is not supported for async engines.")

        session = self.session_factory()
        try:
            yield session
        finally:
            session.close()

    @contextlib.asynccontextmanager
    async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Create an async session and close it after use."""
        if not isinstance(self.session_factory, async_sessionmaker):
            raise AssertionError("This method is not supported for sync engines.")

        async with self.session_factory() as session:
            yield session

    # ------------------------------------------------------------------
    # Private helpers shared by the sync and async code paths.
    # ------------------------------------------------------------------

    def _server_time_query(self) -> Any:
        """Return the dialect-specific SQL yielding the server's UNIX time."""
        # * For SQLite, unlike unixepoch it will work with older versions of SQLite.
        # ----
        # julianday('now'): Julian day number for the current date and time.
        # The Julian day is a continuous count of days, starting from a
        # reference date (Julian day number 0).
        # 2440587.5 - constant represents the Julian day number for January 1, 1970
        # 86400.0 - constant represents the number of seconds
        # in a day (24 hours * 60 minutes * 60 seconds)
        if self.dialect == "sqlite":
            return text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
        if self.dialect == "postgresql":
            return text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
        raise NotImplementedError(f"Not implemented for dialect {self.dialect}")

    @staticmethod
    def _as_float_timestamp(value: Any) -> float:
        """Normalize the scalar returned by the time query to ``float``."""
        # The driver may return a Decimal (e.g. a NUMERIC result from
        # PostgreSQL's EXTRACT); anything else non-float is unexpected.
        if isinstance(value, decimal.Decimal):
            value = float(value)
        if not isinstance(value, float):
            raise AssertionError(f"Unexpected type for datetime: {type(value)}")
        return value

    @staticmethod
    def _resolve_group_ids(
        keys: Sequence[str], group_ids: Optional[Sequence[Optional[str]]]
    ) -> Sequence[Optional[str]]:
        """Default ``group_ids`` to all-``None`` and check it matches ``keys``."""
        if group_ids is None:
            return [None] * len(keys)
        if len(keys) != len(group_ids):
            raise ValueError(
                f"Number of keys ({len(keys)}) does not match number of "
                f"group_ids ({len(group_ids)})"
            )
        return group_ids

    @staticmethod
    def _check_time_at_least(
        update_time: float, time_at_least: Optional[float]
    ) -> None:
        """Raise if the server clock is behind the caller-supplied lower bound."""
        # Here, we're just being extra careful to minimize the chance of
        # data loss due to incorrectly deleting records.
        if time_at_least and update_time < time_at_least:
            # Safeguard against time sync issues
            raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")

    def _upsert_rows(
        self,
        keys: Sequence[str],
        group_ids: Sequence[Optional[str]],
        update_time: float,
    ) -> List[Dict[str, Any]]:
        """Build one row dict per distinct key for the upsert statement."""
        # A key repeated within one batch is collapsed to its last group_id,
        # which is what applying the rows one after another leaves behind.
        # PostgreSQL rejects an ON CONFLICT DO UPDATE statement that would
        # affect the same row twice, so duplicates must not reach the INSERT.
        latest_group_ids = dict(zip(keys, group_ids))
        return [
            {
                "key": key,
                "namespace": self.namespace,
                "updated_at": update_time,
                "group_id": group_id,
            }
            for key, group_id in latest_group_ids.items()
        ]

    def _build_upsert_stmt(self, rows: List[Dict[str, Any]]) -> Any:
        """Build a dialect-specific ``INSERT ... ON CONFLICT DO UPDATE``."""
        # Generic SQLAlchemy ``insert`` has no upsert support, so each dialect's
        # own ``insert`` construct is used to get ``on_conflict_do_update``.
        if self.dialect == "sqlite":
            from sqlalchemy.dialects.sqlite import insert as sqlite_insert

            sqlite_stmt = sqlite_insert(UpsertionRecord).values(rows)
            return sqlite_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
                [UpsertionRecord.key, UpsertionRecord.namespace],
                set_=dict(
                    # attr-defined type ignore
                    updated_at=sqlite_stmt.excluded.updated_at,  # type: ignore
                    group_id=sqlite_stmt.excluded.group_id,  # type: ignore
                ),
            )
        if self.dialect == "postgresql":
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            pg_stmt = pg_insert(UpsertionRecord).values(rows)
            return pg_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
                "uix_key_namespace",  # Name of the (key, namespace) constraint
                set_=dict(
                    # attr-defined type ignore
                    updated_at=pg_stmt.excluded.updated_at,  # type: ignore
                    group_id=pg_stmt.excluded.group_id,  # type: ignore
                ),
            )
        raise NotImplementedError(f"Unsupported dialect {self.dialect}")

    def _existing_keys_query(self, keys: Sequence[str]) -> Any:
        """Select the stored keys of this namespace that appear in ``keys``."""
        return select(UpsertionRecord.key).where(
            and_(
                UpsertionRecord.key.in_(keys),
                UpsertionRecord.namespace == self.namespace,
            )
        )

    def _list_keys_query(
        self,
        *,
        before: Optional[float],
        after: Optional[float],
        group_ids: Optional[Sequence[str]],
        limit: Optional[int],
    ) -> Any:
        """Build the SELECT of keys used by ``list_keys``/``alist_keys``."""
        query = select(UpsertionRecord.key).where(
            UpsertionRecord.namespace == self.namespace
        )
        # Falsy values (None, 0, 0.0, an empty group_ids) disable the filter.
        if after:
            query = query.where(UpsertionRecord.updated_at > after)
        if before:
            query = query.where(UpsertionRecord.updated_at < before)
        if group_ids:
            query = query.where(UpsertionRecord.group_id.in_(group_ids))
        if limit:
            query = query.limit(limit)
        return query

    def _delete_keys_stmt(self, keys: Sequence[str]) -> Any:
        """Build the DELETE for the given keys within this namespace."""
        return delete(UpsertionRecord).where(
            and_(
                UpsertionRecord.key.in_(keys),
                UpsertionRecord.namespace == self.namespace,
            )
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_time(self) -> float:
        """Get the current server time as a timestamp.

        Please note it's critical that time is obtained from the server since
        we want a monotonic clock.

        Raises:
            AssertionError: If the engine is async, or the server returns a
                value that is not a float/Decimal timestamp.
            NotImplementedError: If the dialect is neither SQLite nor PostgreSQL.
        """
        with self._make_session() as session:
            raw = session.execute(self._server_time_query()).scalar()
        return self._as_float_timestamp(raw)

    async def aget_time(self) -> float:
        """Get the current server time as a timestamp (async engine).

        Please note it's critical that time is obtained from the server since
        we want a monotonic clock.

        Raises:
            AssertionError: If the engine is sync, or the server returns a
                value that is not a float/Decimal timestamp.
            NotImplementedError: If the dialect is neither SQLite nor PostgreSQL.
        """
        async with self._amake_session() as session:
            result = await session.execute(self._server_time_query())
            raw = result.scalar_one_or_none()
        return self._as_float_timestamp(raw)

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Upsert records into the database.

        Args:
            keys: Keys to insert or refresh.
            group_ids: Optional group id per key; must have the same length as
                ``keys``. Defaults to ``None`` for every key.
            time_at_least: If truthy, the server time must not be earlier than
                this value.

        Raises:
            ValueError: If ``group_ids`` and ``keys`` differ in length.
            AssertionError: If the server time is before ``time_at_least``.
            NotImplementedError: If the dialect is unsupported.
        """
        group_ids = self._resolve_group_ids(keys, group_ids)
        if not keys:
            # Nothing to upsert. SQLAlchemy does not treat ``values([])`` as a
            # no-op, so the INSERT must not be built from an empty batch.
            return

        # Get the current time from the server. This makes an extra round trip
        # but lets us compare against time_at_least and detect time sync issues.
        update_time = self.get_time()
        self._check_time_at_least(update_time, time_at_least)

        stmt = self._build_upsert_stmt(self._upsert_rows(keys, group_ids, update_time))
        with self._make_session() as session:
            session.execute(stmt)
            session.commit()

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Upsert records into the database (async engine).

        Args:
            keys: Keys to insert or refresh.
            group_ids: Optional group id per key; must have the same length as
                ``keys``. Defaults to ``None`` for every key.
            time_at_least: If truthy, the server time must not be earlier than
                this value.

        Raises:
            ValueError: If ``group_ids`` and ``keys`` differ in length.
            AssertionError: If the server time is before ``time_at_least``.
            NotImplementedError: If the dialect is unsupported.
        """
        group_ids = self._resolve_group_ids(keys, group_ids)
        if not keys:
            # Nothing to upsert. SQLAlchemy does not treat ``values([])`` as a
            # no-op, so the INSERT must not be built from an empty batch.
            return

        # Get the current time from the server. This makes an extra round trip
        # but lets us compare against time_at_least and detect time sync issues.
        update_time = await self.aget_time()
        self._check_time_at_least(update_time, time_at_least)

        stmt = self._build_upsert_stmt(self._upsert_rows(keys, group_ids, update_time))
        async with self._amake_session() as session:
            await session.execute(stmt)
            await session.commit()

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Return, for each key in order, whether it exists in this namespace."""
        with self._make_session() as session:
            found_keys = set(
                session.execute(self._existing_keys_query(keys)).scalars().all()
            )
        return [k in found_keys for k in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Return, for each key in order, whether it exists (async engine)."""
        async with self._amake_session() as session:
            result = await session.execute(self._existing_keys_query(keys))
            found_keys = set(result.scalars().all())
        return [k in found_keys for k in keys]

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List keys in this namespace, filtered by update time and group.

        ``before``/``after`` are exclusive bounds on ``updated_at``; falsy
        filter values (including ``limit=0``) disable that filter.
        """
        query = self._list_keys_query(
            before=before, after=after, group_ids=group_ids, limit=limit
        )
        with self._make_session() as session:
            return list(session.execute(query).scalars().all())

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List keys in this namespace (async engine); see ``list_keys``."""
        query = self._list_keys_query(
            before=before, after=after, group_ids=group_ids, limit=limit
        )
        async with self._amake_session() as session:
            result = await session.execute(query)
            return list(result.scalars().all())

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Delete the given keys from this namespace."""
        with self._make_session() as session:
            session.execute(self._delete_keys_stmt(keys))
            session.commit()

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Delete the given keys from this namespace (async engine)."""
        async with self._amake_session() as session:
            await session.execute(self._delete_keys_stmt(keys))
            await session.commit()