core[patch]: fix _sql_record_manager mypy for #17048 (#17073)

- **Description:** Add relevant type annotations for relevant session
and query objects to resolve mypy errors when `# type: ignore` comments
are removed.
  - **Issue:** #17048
  - **Dependencies:** None,
  - **Twitter handle:** [clesiemo3](https://twitter.com/clesiemo3)
 
I attempted to solve the `UpsertionRecord` ignore but it would require
added a deprecated plugin or moving completely to sqlalchemy 2.0+ from
my understanding. I'm assuming this is not something desired at this
point in time.
pull/17078/head
Jimmy Moore 5 months ago committed by GitHub
parent 3d5e988c55
commit 912210ac19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -39,7 +39,7 @@ from sqlalchemy.ext.asyncio import (
create_async_engine, create_async_engine,
) )
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Query, Session, sessionmaker
from langchain.indexes.base import RecordManager from langchain.indexes.base import RecordManager
@ -284,31 +284,35 @@ class SQLRecordManager(RecordManager):
with self._make_session() as session: with self._make_session() as session:
if self.dialect == "sqlite": if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work. # Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects. # This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert) sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] UpsertionRecord
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace], [UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict( set_=dict(
# attr-defined type ignore updated_at=sqlite_insert_stmt.excluded.updated_at,
updated_at=insert_stmt.excluded.updated_at, # type: ignore group_id=sqlite_insert_stmt.excluded.group_id,
group_id=insert_stmt.excluded.group_id, # type: ignore
), ),
) )
elif self.dialect == "postgresql": elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses SQLite insert to make on_conflict_do_update work. # Note: uses postgresql insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects. # This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert) pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] records_to_upsert
)
stmt = pg_insert_stmt.on_conflict_do_update(
"uix_key_namespace", # Name of constraint "uix_key_namespace", # Name of constraint
set_=dict( set_=dict(
# attr-defined type ignore updated_at=pg_insert_stmt.excluded.updated_at,
updated_at=insert_stmt.excluded.updated_at, # type: ignore group_id=pg_insert_stmt.excluded.group_id,
group_id=insert_stmt.excluded.group_id, # type: ignore
), ),
) )
else: else:
@ -359,31 +363,35 @@ class SQLRecordManager(RecordManager):
async with self._amake_session() as session: async with self._amake_session() as session:
if self.dialect == "sqlite": if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work. # Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects. # This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert) sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] UpsertionRecord
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace], [UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict( set_=dict(
# attr-defined type ignore updated_at=sqlite_insert_stmt.excluded.updated_at,
updated_at=insert_stmt.excluded.updated_at, # type: ignore group_id=sqlite_insert_stmt.excluded.group_id,
group_id=insert_stmt.excluded.group_id, # type: ignore
), ),
) )
elif self.dialect == "postgresql": elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses SQLite insert to make on_conflict_do_update work. # Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects. # This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert) pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] records_to_upsert
)
stmt = pg_insert_stmt.on_conflict_do_update(
"uix_key_namespace", # Name of constraint "uix_key_namespace", # Name of constraint
set_=dict( set_=dict(
# attr-defined type ignore updated_at=pg_insert_stmt.excluded.updated_at,
updated_at=insert_stmt.excluded.updated_at, # type: ignore group_id=pg_insert_stmt.excluded.group_id,
group_id=insert_stmt.excluded.group_id, # type: ignore
), ),
) )
else: else:
@ -394,18 +402,15 @@ class SQLRecordManager(RecordManager):
def exists(self, keys: Sequence[str]) -> List[bool]: def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the given keys exist in the SQLite database.""" """Check if the given keys exist in the SQLite database."""
session: Session
with self._make_session() as session: with self._make_session() as session:
records = ( filtered_query: Query = session.query(UpsertionRecord.key).filter(
# mypy does not recognize .all() and_(
session.query(UpsertionRecord.key) # type: ignore[attr-defined] UpsertionRecord.key.in_(keys),
.filter( UpsertionRecord.namespace == self.namespace,
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
) )
.all()
) )
records = filtered_query.all()
found_keys = set(r.key for r in records) found_keys = set(r.key for r in records)
return [k in found_keys for k in keys] return [k in found_keys for k in keys]
@ -438,28 +443,22 @@ class SQLRecordManager(RecordManager):
limit: Optional[int] = None, limit: Optional[int] = None,
) -> List[str]: ) -> List[str]:
"""List records in the SQLite database based on the provided date range.""" """List records in the SQLite database based on the provided date range."""
session: Session
with self._make_session() as session: with self._make_session() as session:
query = session.query(UpsertionRecord).filter( query: Query = session.query(UpsertionRecord).filter(
UpsertionRecord.namespace == self.namespace UpsertionRecord.namespace == self.namespace
) )
# mypy does not recognize .all() or .filter()
if after: if after:
query = query.filter( # type: ignore[attr-defined] query = query.filter(UpsertionRecord.updated_at > after)
UpsertionRecord.updated_at > after
)
if before: if before:
query = query.filter( # type: ignore[attr-defined] query = query.filter(UpsertionRecord.updated_at < before)
UpsertionRecord.updated_at < before
)
if group_ids: if group_ids:
query = query.filter( # type: ignore[attr-defined] query = query.filter(UpsertionRecord.group_id.in_(group_ids))
UpsertionRecord.group_id.in_(group_ids)
)
if limit: if limit:
query = query.limit(limit) # type: ignore[attr-defined] query = query.limit(limit)
records = query.all() # type: ignore[attr-defined] records = query.all()
return [r.key for r in records] return [r.key for r in records]
async def alist_keys( async def alist_keys(
@ -471,40 +470,37 @@ class SQLRecordManager(RecordManager):
limit: Optional[int] = None, limit: Optional[int] = None,
) -> List[str]: ) -> List[str]:
"""List records in the SQLite database based on the provided date range.""" """List records in the SQLite database based on the provided date range."""
session: AsyncSession
async with self._amake_session() as session: async with self._amake_session() as session:
query = select(UpsertionRecord.key).filter( query: Query = select(UpsertionRecord.key).filter(
UpsertionRecord.namespace == self.namespace UpsertionRecord.namespace == self.namespace
) )
# mypy does not recognize .all() or .filter() # mypy does not recognize .all() or .filter()
if after: if after:
query = query.filter( # type: ignore[attr-defined] query = query.filter(UpsertionRecord.updated_at > after)
UpsertionRecord.updated_at > after
)
if before: if before:
query = query.filter( # type: ignore[attr-defined] query = query.filter(UpsertionRecord.updated_at < before)
UpsertionRecord.updated_at < before
)
if group_ids: if group_ids:
query = query.filter( # type: ignore[attr-defined] query = query.filter(UpsertionRecord.group_id.in_(group_ids))
UpsertionRecord.group_id.in_(group_ids)
)
if limit: if limit:
query = query.limit(limit) # type: ignore[attr-defined] query = query.limit(limit)
records = (await session.execute(query)).scalars().all() records = (await session.execute(query)).scalars().all()
return list(records) return list(records)
def delete_keys(self, keys: Sequence[str]) -> None: def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete records from the SQLite database.""" """Delete records from the SQLite database."""
session: Session
with self._make_session() as session: with self._make_session() as session:
# mypy does not recognize .delete() filtered_query: Query = session.query(UpsertionRecord).filter(
session.query(UpsertionRecord).filter(
and_( and_(
UpsertionRecord.key.in_(keys), UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace, UpsertionRecord.namespace == self.namespace,
) )
).delete() # type: ignore[attr-defined] )
filtered_query.delete()
session.commit() session.commit()
async def adelete_keys(self, keys: Sequence[str]) -> None: async def adelete_keys(self, keys: Sequence[str]) -> None:

Loading…
Cancel
Save