@ -39,7 +39,7 @@ from sqlalchemy.ext.asyncio import (
create_async_engine ,
)
from sqlalchemy . ext . declarative import declarative_base
from sqlalchemy . orm import Session, sessionmaker
from sqlalchemy . orm import Query, Session, sessionmaker
from langchain . indexes . base import RecordManager
@ -284,31 +284,35 @@ class SQLRecordManager(RecordManager):
with self . _make_session ( ) as session :
if self . dialect == " sqlite " :
from sqlalchemy . dialects . sqlite import Insert as SqliteInsertType
from sqlalchemy . dialects . sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert ( UpsertionRecord ) . values ( records_to_upsert )
stmt = insert_stmt . on_conflict_do_update ( # type: ignore[attr-defined]
sqlite_insert_stmt : SqliteInsertType = sqlite_insert (
UpsertionRecord
) . values ( records_to_upsert )
stmt = sqlite_insert_stmt . on_conflict_do_update (
[ UpsertionRecord . key , UpsertionRecord . namespace ] ,
set_ = dict (
# attr-defined type ignore
updated_at = insert_stmt . excluded . updated_at , # type: ignore
group_id = insert_stmt . excluded . group_id , # type: ignore
updated_at = sqlite_insert_stmt . excluded . updated_at ,
group_id = sqlite_insert_stmt . excluded . group_id ,
) ,
)
elif self . dialect == " postgresql " :
from sqlalchemy . dialects . postgresql import Insert as PgInsertType
from sqlalchemy . dialects . postgresql import insert as pg_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# Note: uses postgresql insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert ( UpsertionRecord ) . values ( records_to_upsert )
stmt = insert_stmt . on_conflict_do_update ( # type: ignore[attr-defined]
pg_insert_stmt : PgInsertType = pg_insert ( UpsertionRecord ) . values (
records_to_upsert
)
stmt = pg_insert_stmt . on_conflict_do_update (
" uix_key_namespace " , # Name of constraint
set_ = dict (
# attr-defined type ignore
updated_at = insert_stmt . excluded . updated_at , # type: ignore
group_id = insert_stmt . excluded . group_id , # type: ignore
updated_at = pg_insert_stmt . excluded . updated_at ,
group_id = pg_insert_stmt . excluded . group_id ,
) ,
)
else :
@ -359,31 +363,35 @@ class SQLRecordManager(RecordManager):
async with self . _amake_session ( ) as session :
if self . dialect == " sqlite " :
from sqlalchemy . dialects . sqlite import Insert as SqliteInsertType
from sqlalchemy . dialects . sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert ( UpsertionRecord ) . values ( records_to_upsert )
stmt = insert_stmt . on_conflict_do_update ( # type: ignore[attr-defined]
sqlite_insert_stmt : SqliteInsertType = sqlite_insert (
UpsertionRecord
) . values ( records_to_upsert )
stmt = sqlite_insert_stmt . on_conflict_do_update (
[ UpsertionRecord . key , UpsertionRecord . namespace ] ,
set_ = dict (
# attr-defined type ignore
updated_at = insert_stmt . excluded . updated_at , # type: ignore
group_id = insert_stmt . excluded . group_id , # type: ignore
updated_at = sqlite_insert_stmt . excluded . updated_at ,
group_id = sqlite_insert_stmt . excluded . group_id ,
) ,
)
elif self . dialect == " postgresql " :
from sqlalchemy . dialects . postgresql import Insert as PgInsertType
from sqlalchemy . dialects . postgresql import insert as pg_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert ( UpsertionRecord ) . values ( records_to_upsert )
stmt = insert_stmt . on_conflict_do_update ( # type: ignore[attr-defined]
pg_insert_stmt : PgInsertType = pg_insert ( UpsertionRecord ) . values (
records_to_upsert
)
stmt = pg_insert_stmt . on_conflict_do_update (
" uix_key_namespace " , # Name of constraint
set_ = dict (
# attr-defined type ignore
updated_at = insert_stmt . excluded . updated_at , # type: ignore
group_id = insert_stmt . excluded . group_id , # type: ignore
updated_at = pg_insert_stmt . excluded . updated_at ,
group_id = pg_insert_stmt . excluded . group_id ,
) ,
)
else :
@ -394,18 +402,15 @@ class SQLRecordManager(RecordManager):
def exists ( self , keys : Sequence [ str ] ) - > List [ bool ] :
""" Check if the given keys exist in the SQLite database. """
session : Session
with self . _make_session ( ) as session :
records = (
# mypy does not recognize .all()
session . query ( UpsertionRecord . key ) # type: ignore[attr-defined]
. filter (
and_ (
UpsertionRecord . key . in_ ( keys ) ,
UpsertionRecord . namespace == self . namespace ,
)
filtered_query : Query = session . query ( UpsertionRecord . key ) . filter (
and_ (
UpsertionRecord . key . in_ ( keys ) ,
UpsertionRecord . namespace == self . namespace ,
)
. all ( )
)
records = filtered_query . all ( )
found_keys = set ( r . key for r in records )
return [ k in found_keys for k in keys ]
@ -438,28 +443,22 @@ class SQLRecordManager(RecordManager):
limit : Optional [ int ] = None ,
) - > List [ str ] :
""" List records in the SQLite database based on the provided date range. """
session : Session
with self . _make_session ( ) as session :
query = session . query ( UpsertionRecord ) . filter (
query : Query = session . query ( UpsertionRecord ) . filter (
UpsertionRecord . namespace == self . namespace
)
# mypy does not recognize .all() or .filter()
if after :
query = query . filter ( # type: ignore[attr-defined]
UpsertionRecord . updated_at > after
)
query = query . filter ( UpsertionRecord . updated_at > after )
if before :
query = query . filter ( # type: ignore[attr-defined]
UpsertionRecord . updated_at < before
)
query = query . filter ( UpsertionRecord . updated_at < before )
if group_ids :
query = query . filter ( # type: ignore[attr-defined]
UpsertionRecord . group_id . in_ ( group_ids )
)
query = query . filter ( UpsertionRecord . group_id . in_ ( group_ids ) )
if limit :
query = query . limit ( limit ) # type: ignore[attr-defined]
records = query . all ( ) # type: ignore[attr-defined]
query = query . limit ( limit )
records = query . all ( )
return [ r . key for r in records ]
async def alist_keys (
@ -471,40 +470,37 @@ class SQLRecordManager(RecordManager):
limit : Optional [ int ] = None ,
) - > List [ str ] :
""" List records in the SQLite database based on the provided date range. """
session : AsyncSession
async with self . _amake_session ( ) as session :
query = select ( UpsertionRecord . key ) . filter (
query : Query = select ( UpsertionRecord . key ) . filter (
UpsertionRecord . namespace == self . namespace
)
# mypy does not recognize .all() or .filter()
if after :
query = query . filter ( # type: ignore[attr-defined]
UpsertionRecord . updated_at > after
)
query = query . filter ( UpsertionRecord . updated_at > after )
if before :
query = query . filter ( # type: ignore[attr-defined]
UpsertionRecord . updated_at < before
)
query = query . filter ( UpsertionRecord . updated_at < before )
if group_ids :
query = query . filter ( # type: ignore[attr-defined]
UpsertionRecord . group_id . in_ ( group_ids )
)
query = query . filter ( UpsertionRecord . group_id . in_ ( group_ids ) )
if limit :
query = query . limit ( limit ) # type: ignore[attr-defined]
query = query . limit ( limit )
records = ( await session . execute ( query ) ) . scalars ( ) . all ( )
return list ( records )
def delete_keys ( self , keys : Sequence [ str ] ) - > None :
""" Delete records from the SQLite database. """
session : Session
with self . _make_session ( ) as session :
# mypy does not recognize .delete()
session . query ( UpsertionRecord ) . filter (
filtered_query : Query = session . query ( UpsertionRecord ) . filter (
and_ (
UpsertionRecord . key . in_ ( keys ) ,
UpsertionRecord . namespace == self . namespace ,
)
) . delete ( ) # type: ignore[attr-defined]
)
filtered_query . delete ( )
session . commit ( )
async def adelete_keys ( self , keys : Sequence [ str ] ) - > None :