mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
523 lines
20 KiB
Python
523 lines
20 KiB
Python
"""Implementation of a record management layer in SQLAlchemy.
|
|
|
|
The management layer uses SQLAlchemy to track upserted records.
|
|
|
|
Currently, this layer only works with SQLite; however, it should be adaptable
|
|
to other SQL implementations with minimal effort.
|
|
|
|
Currently, includes an implementation that uses SQLAlchemy which should
|
|
allow it to work with a variety of SQL as a backend.
|
|
|
|
* Each key is associated with an updated_at field.
|
|
* This field is updated whenever the key is updated.
|
|
* Keys can be listed based on the updated at field.
|
|
* Keys can be deleted.
|
|
"""
|
|
import contextlib
|
|
import decimal
|
|
import uuid
|
|
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence, Union
|
|
|
|
from sqlalchemy import (
|
|
URL,
|
|
Column,
|
|
Engine,
|
|
Float,
|
|
Index,
|
|
String,
|
|
UniqueConstraint,
|
|
and_,
|
|
create_engine,
|
|
delete,
|
|
select,
|
|
text,
|
|
)
|
|
from sqlalchemy.ext.asyncio import (
|
|
AsyncEngine,
|
|
AsyncSession,
|
|
async_sessionmaker,
|
|
create_async_engine,
|
|
)
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from langchain_community.indexes.base import RecordManager
|
|
|
|
Base = declarative_base()
|
|
|
|
|
|
class UpsertionRecord(Base):  # type: ignore[valid-type,misc]
    """Table used to keep track of when a key was last updated."""

    # ATTENTION:
    # Prior to modifying this table, please determine whether
    # we should create migrations for this table to make sure
    # users do not experience data loss.
    __tablename__ = "upsertion_record"

    # Surrogate primary key: a client-generated random UUID4 string.
    uuid = Column(
        String,
        index=True,
        default=lambda: str(uuid.uuid4()),
        primary_key=True,
        nullable=False,
    )
    # User-facing identifier of the upserted record.
    key = Column(String, index=True)
    # Using a non-normalized representation to handle `namespace` attribute.
    # If the need arises, this attribute can be pulled into a separate Collection
    # table at some time later.
    namespace = Column(String, index=True, nullable=False)
    # Optional grouping identifier; nullable since callers may omit group ids.
    group_id = Column(String, index=True, nullable=True)

    # The timestamp associated with the last record upsertion.
    updated_at = Column(Float, index=True)

    __table_args__ = (
        # (key, namespace) pairs are unique; the upsert logic relies on this
        # constraint for its ON CONFLICT handling.
        UniqueConstraint("key", "namespace", name="uix_key_namespace"),
        Index("ix_key_namespace", "key", "namespace"),
    )
|
|
|
|
|
|
class SQLRecordManager(RecordManager):
    """A SQL Alchemy based implementation of the record manager."""

    def __init__(
        self,
        namespace: str,
        *,
        engine: Optional[Union[Engine, AsyncEngine]] = None,
        db_url: Union[None, str, URL] = None,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        async_mode: bool = False,
    ) -> None:
        """Initialize the SQLRecordManager.

        This class serves as a manager persistence layer that uses an SQL
        backend to track upserted records. You should specify either a db_url
        to create an engine or provide an existing engine.

        Args:
            namespace: The namespace associated with this record manager.
            engine: An already existing SQL Alchemy engine.
                Default is None.
            db_url: A database connection string used to create
                an SQL Alchemy engine. Default is None.
            engine_kwargs: Additional keyword arguments
                to be passed when creating the engine. Default is an empty dictionary.
            async_mode: Whether to create an async engine.
                Driver should support async operations.
                It only applies if db_url is provided.
                Default is False.

        Raises:
            ValueError: If both db_url and engine are provided or neither.
            AssertionError: If something unexpected happens during engine configuration.
        """
        super().__init__(namespace=namespace)
        # Exactly one of db_url / engine must be given.
        if db_url is None and engine is None:
            raise ValueError("Must specify either db_url or engine")

        if db_url is not None and engine is not None:
            raise ValueError("Must specify either db_url or engine, not both")

        selected_engine: Union[Engine, AsyncEngine]
        if db_url:
            # Build a fresh engine; async_mode selects the engine flavor.
            factory = create_async_engine if async_mode else create_engine
            selected_engine = factory(db_url, **(engine_kwargs or {}))
        elif engine:
            # Caller supplied an engine; async_mode is ignored in this case.
            selected_engine = engine
        else:
            raise AssertionError("Something went wrong with configuration of engine.")

        # The session factory flavor must match the engine flavor.
        session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
        if isinstance(selected_engine, AsyncEngine):
            session_maker = async_sessionmaker(bind=selected_engine)
        else:
            session_maker = sessionmaker(bind=selected_engine)

        self.engine = selected_engine
        # Dialect name (e.g. "sqlite", "postgresql") drives SQL generation later.
        self.dialect = selected_engine.dialect.name
        self.session_factory = session_maker
|
|
|
|
def create_schema(self) -> None:
|
|
"""Create the database schema."""
|
|
if isinstance(self.engine, AsyncEngine):
|
|
raise AssertionError("This method is not supported for async engines.")
|
|
|
|
Base.metadata.create_all(self.engine)
|
|
|
|
async def acreate_schema(self) -> None:
|
|
"""Create the database schema."""
|
|
|
|
if not isinstance(self.engine, AsyncEngine):
|
|
raise AssertionError("This method is not supported for sync engines.")
|
|
|
|
async with self.engine.begin() as session:
|
|
await session.run_sync(Base.metadata.create_all)
|
|
|
|
@contextlib.contextmanager
|
|
def _make_session(self) -> Generator[Session, None, None]:
|
|
"""Create a session and close it after use."""
|
|
|
|
if isinstance(self.session_factory, async_sessionmaker):
|
|
raise AssertionError("This method is not supported for async engines.")
|
|
|
|
session = self.session_factory()
|
|
try:
|
|
yield session
|
|
finally:
|
|
session.close()
|
|
|
|
@contextlib.asynccontextmanager
|
|
async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
|
|
"""Create a session and close it after use."""
|
|
|
|
if not isinstance(self.session_factory, async_sessionmaker):
|
|
raise AssertionError("This method is not supported for sync engines.")
|
|
|
|
async with self.session_factory() as session:
|
|
yield session
|
|
|
|
def get_time(self) -> float:
|
|
"""Get the current server time as a timestamp.
|
|
|
|
Please note it's critical that time is obtained from the server since
|
|
we want a monotonic clock.
|
|
"""
|
|
with self._make_session() as session:
|
|
# * SQLite specific implementation, can be changed based on dialect.
|
|
# * For SQLite, unlike unixepoch it will work with older versions of SQLite.
|
|
# ----
|
|
# julianday('now'): Julian day number for the current date and time.
|
|
# The Julian day is a continuous count of days, starting from a
|
|
# reference date (Julian day number 0).
|
|
# 2440587.5 - constant represents the Julian day number for January 1, 1970
|
|
# 86400.0 - constant represents the number of seconds
|
|
# in a day (24 hours * 60 minutes * 60 seconds)
|
|
if self.dialect == "sqlite":
|
|
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
|
|
elif self.dialect == "postgresql":
|
|
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
|
|
else:
|
|
raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
|
|
|
|
dt = session.execute(query).scalar()
|
|
if isinstance(dt, decimal.Decimal):
|
|
dt = float(dt)
|
|
if not isinstance(dt, float):
|
|
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
|
|
return dt
|
|
|
|
async def aget_time(self) -> float:
|
|
"""Get the current server time as a timestamp.
|
|
|
|
Please note it's critical that time is obtained from the server since
|
|
we want a monotonic clock.
|
|
"""
|
|
async with self._amake_session() as session:
|
|
# * SQLite specific implementation, can be changed based on dialect.
|
|
# * For SQLite, unlike unixepoch it will work with older versions of SQLite.
|
|
# ----
|
|
# julianday('now'): Julian day number for the current date and time.
|
|
# The Julian day is a continuous count of days, starting from a
|
|
# reference date (Julian day number 0).
|
|
# 2440587.5 - constant represents the Julian day number for January 1, 1970
|
|
# 86400.0 - constant represents the number of seconds
|
|
# in a day (24 hours * 60 minutes * 60 seconds)
|
|
if self.dialect == "sqlite":
|
|
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
|
|
elif self.dialect == "postgresql":
|
|
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
|
|
else:
|
|
raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
|
|
|
|
dt = (await session.execute(query)).scalar_one_or_none()
|
|
|
|
if isinstance(dt, decimal.Decimal):
|
|
dt = float(dt)
|
|
if not isinstance(dt, float):
|
|
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
|
|
return dt
|
|
|
|
def update(
|
|
self,
|
|
keys: Sequence[str],
|
|
*,
|
|
group_ids: Optional[Sequence[Optional[str]]] = None,
|
|
time_at_least: Optional[float] = None,
|
|
) -> None:
|
|
"""Upsert records into the SQLite database."""
|
|
if group_ids is None:
|
|
group_ids = [None] * len(keys)
|
|
|
|
if len(keys) != len(group_ids):
|
|
raise ValueError(
|
|
f"Number of keys ({len(keys)}) does not match number of "
|
|
f"group_ids ({len(group_ids)})"
|
|
)
|
|
|
|
# Get the current time from the server.
|
|
# This makes an extra round trip to the server, should not be a big deal
|
|
# if the batch size is large enough.
|
|
# Getting the time here helps us compare it against the time_at_least
|
|
# and raise an error if there is a time sync issue.
|
|
# Here, we're just being extra careful to minimize the chance of
|
|
# data loss due to incorrectly deleting records.
|
|
update_time = self.get_time()
|
|
|
|
if time_at_least and update_time < time_at_least:
|
|
# Safeguard against time sync issues
|
|
raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")
|
|
|
|
records_to_upsert = [
|
|
{
|
|
"key": key,
|
|
"namespace": self.namespace,
|
|
"updated_at": update_time,
|
|
"group_id": group_id,
|
|
}
|
|
for key, group_id in zip(keys, group_ids)
|
|
]
|
|
|
|
with self._make_session() as session:
|
|
if self.dialect == "sqlite":
|
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
|
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
|
# This code needs to be generalized a bit to work with more dialects.
|
|
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
|
|
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
|
[UpsertionRecord.key, UpsertionRecord.namespace],
|
|
set_=dict(
|
|
# attr-defined type ignore
|
|
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
|
group_id=insert_stmt.excluded.group_id, # type: ignore
|
|
),
|
|
)
|
|
elif self.dialect == "postgresql":
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
|
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
|
# This code needs to be generalized a bit to work with more dialects.
|
|
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
|
|
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
|
"uix_key_namespace", # Name of constraint
|
|
set_=dict(
|
|
# attr-defined type ignore
|
|
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
|
group_id=insert_stmt.excluded.group_id, # type: ignore
|
|
),
|
|
)
|
|
else:
|
|
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
|
|
|
|
session.execute(stmt)
|
|
session.commit()
|
|
|
|
async def aupdate(
|
|
self,
|
|
keys: Sequence[str],
|
|
*,
|
|
group_ids: Optional[Sequence[Optional[str]]] = None,
|
|
time_at_least: Optional[float] = None,
|
|
) -> None:
|
|
"""Upsert records into the SQLite database."""
|
|
if group_ids is None:
|
|
group_ids = [None] * len(keys)
|
|
|
|
if len(keys) != len(group_ids):
|
|
raise ValueError(
|
|
f"Number of keys ({len(keys)}) does not match number of "
|
|
f"group_ids ({len(group_ids)})"
|
|
)
|
|
|
|
# Get the current time from the server.
|
|
# This makes an extra round trip to the server, should not be a big deal
|
|
# if the batch size is large enough.
|
|
# Getting the time here helps us compare it against the time_at_least
|
|
# and raise an error if there is a time sync issue.
|
|
# Here, we're just being extra careful to minimize the chance of
|
|
# data loss due to incorrectly deleting records.
|
|
update_time = await self.aget_time()
|
|
|
|
if time_at_least and update_time < time_at_least:
|
|
# Safeguard against time sync issues
|
|
raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")
|
|
|
|
records_to_upsert = [
|
|
{
|
|
"key": key,
|
|
"namespace": self.namespace,
|
|
"updated_at": update_time,
|
|
"group_id": group_id,
|
|
}
|
|
for key, group_id in zip(keys, group_ids)
|
|
]
|
|
|
|
async with self._amake_session() as session:
|
|
if self.dialect == "sqlite":
|
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
|
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
|
# This code needs to be generalized a bit to work with more dialects.
|
|
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
|
|
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
|
[UpsertionRecord.key, UpsertionRecord.namespace],
|
|
set_=dict(
|
|
# attr-defined type ignore
|
|
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
|
group_id=insert_stmt.excluded.group_id, # type: ignore
|
|
),
|
|
)
|
|
elif self.dialect == "postgresql":
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
|
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
|
# This code needs to be generalized a bit to work with more dialects.
|
|
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
|
|
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
|
"uix_key_namespace", # Name of constraint
|
|
set_=dict(
|
|
# attr-defined type ignore
|
|
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
|
group_id=insert_stmt.excluded.group_id, # type: ignore
|
|
),
|
|
)
|
|
else:
|
|
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
|
|
|
|
await session.execute(stmt)
|
|
await session.commit()
|
|
|
|
def exists(self, keys: Sequence[str]) -> List[bool]:
|
|
"""Check if the given keys exist in the SQLite database."""
|
|
with self._make_session() as session:
|
|
records = (
|
|
# mypy does not recognize .all()
|
|
session.query(UpsertionRecord.key) # type: ignore[attr-defined]
|
|
.filter(
|
|
and_(
|
|
UpsertionRecord.key.in_(keys),
|
|
UpsertionRecord.namespace == self.namespace,
|
|
)
|
|
)
|
|
.all()
|
|
)
|
|
found_keys = set(r.key for r in records)
|
|
return [k in found_keys for k in keys]
|
|
|
|
async def aexists(self, keys: Sequence[str]) -> List[bool]:
|
|
"""Check if the given keys exist in the SQLite database."""
|
|
async with self._amake_session() as session:
|
|
records = (
|
|
(
|
|
await session.execute(
|
|
select(UpsertionRecord.key).where(
|
|
and_(
|
|
UpsertionRecord.key.in_(keys),
|
|
UpsertionRecord.namespace == self.namespace,
|
|
)
|
|
)
|
|
)
|
|
)
|
|
.scalars()
|
|
.all()
|
|
)
|
|
found_keys = set(records)
|
|
return [k in found_keys for k in keys]
|
|
|
|
def list_keys(
|
|
self,
|
|
*,
|
|
before: Optional[float] = None,
|
|
after: Optional[float] = None,
|
|
group_ids: Optional[Sequence[str]] = None,
|
|
limit: Optional[int] = None,
|
|
) -> List[str]:
|
|
"""List records in the SQLite database based on the provided date range."""
|
|
with self._make_session() as session:
|
|
query = session.query(UpsertionRecord).filter(
|
|
UpsertionRecord.namespace == self.namespace
|
|
)
|
|
|
|
# mypy does not recognize .all() or .filter()
|
|
if after:
|
|
query = query.filter( # type: ignore[attr-defined]
|
|
UpsertionRecord.updated_at > after
|
|
)
|
|
if before:
|
|
query = query.filter( # type: ignore[attr-defined]
|
|
UpsertionRecord.updated_at < before
|
|
)
|
|
if group_ids:
|
|
query = query.filter( # type: ignore[attr-defined]
|
|
UpsertionRecord.group_id.in_(group_ids)
|
|
)
|
|
|
|
if limit:
|
|
query = query.limit(limit) # type: ignore[attr-defined]
|
|
records = query.all() # type: ignore[attr-defined]
|
|
return [r.key for r in records]
|
|
|
|
async def alist_keys(
|
|
self,
|
|
*,
|
|
before: Optional[float] = None,
|
|
after: Optional[float] = None,
|
|
group_ids: Optional[Sequence[str]] = None,
|
|
limit: Optional[int] = None,
|
|
) -> List[str]:
|
|
"""List records in the SQLite database based on the provided date range."""
|
|
async with self._amake_session() as session:
|
|
query = select(UpsertionRecord.key).filter(
|
|
UpsertionRecord.namespace == self.namespace
|
|
)
|
|
|
|
# mypy does not recognize .all() or .filter()
|
|
if after:
|
|
query = query.filter( # type: ignore[attr-defined]
|
|
UpsertionRecord.updated_at > after
|
|
)
|
|
if before:
|
|
query = query.filter( # type: ignore[attr-defined]
|
|
UpsertionRecord.updated_at < before
|
|
)
|
|
if group_ids:
|
|
query = query.filter( # type: ignore[attr-defined]
|
|
UpsertionRecord.group_id.in_(group_ids)
|
|
)
|
|
|
|
if limit:
|
|
query = query.limit(limit) # type: ignore[attr-defined]
|
|
records = (await session.execute(query)).scalars().all()
|
|
return list(records)
|
|
|
|
def delete_keys(self, keys: Sequence[str]) -> None:
|
|
"""Delete records from the SQLite database."""
|
|
with self._make_session() as session:
|
|
# mypy does not recognize .delete()
|
|
session.query(UpsertionRecord).filter(
|
|
and_(
|
|
UpsertionRecord.key.in_(keys),
|
|
UpsertionRecord.namespace == self.namespace,
|
|
)
|
|
).delete() # type: ignore[attr-defined]
|
|
session.commit()
|
|
|
|
async def adelete_keys(self, keys: Sequence[str]) -> None:
|
|
"""Delete records from the SQLite database."""
|
|
async with self._amake_session() as session:
|
|
await session.execute(
|
|
delete(UpsertionRecord).where(
|
|
and_(
|
|
UpsertionRecord.key.in_(keys),
|
|
UpsertionRecord.namespace == self.namespace,
|
|
)
|
|
)
|
|
)
|
|
|
|
await session.commit()
|