core[patch]: fix _sql_record_manager mypy for #17048 (#17073)

- **Description:** Add relevant type annotations for relevant session
and query objects to resolve mypy errors when `# type: ignore` comments
are removed.
  - **Issue:** #17048
  - **Dependencies:** None,
  - **Twitter handle:** [clesiemo3](https://twitter.com/clesiemo3)
 
I attempted to solve the `UpsertionRecord` ignore but it would require
added a deprecated plugin or moving completely to sqlalchemy 2.0+ from
my understanding. I'm assuming this is not something desired at this
point in time.
pull/17078/head
Jimmy Moore 4 months ago committed by GitHub
parent 3d5e988c55
commit 912210ac19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -39,7 +39,7 @@ from sqlalchemy.ext.asyncio import (
create_async_engine,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Query, Session, sessionmaker
from langchain.indexes.base import RecordManager
@ -284,31 +284,35 @@ class SQLRecordManager(RecordManager):
with self._make_session() as session:
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
UpsertionRecord
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=sqlite_insert_stmt.excluded.updated_at,
group_id=sqlite_insert_stmt.excluded.group_id,
),
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# Note: uses postgresql insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
records_to_upsert
)
stmt = pg_insert_stmt.on_conflict_do_update(
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=pg_insert_stmt.excluded.updated_at,
group_id=pg_insert_stmt.excluded.group_id,
),
)
else:
@ -359,31 +363,35 @@ class SQLRecordManager(RecordManager):
async with self._amake_session() as session:
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
UpsertionRecord
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=sqlite_insert_stmt.excluded.updated_at,
group_id=sqlite_insert_stmt.excluded.group_id,
),
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
records_to_upsert
)
stmt = pg_insert_stmt.on_conflict_do_update(
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
updated_at=pg_insert_stmt.excluded.updated_at,
group_id=pg_insert_stmt.excluded.group_id,
),
)
else:
@ -394,18 +402,15 @@ class SQLRecordManager(RecordManager):
def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the given keys exist in the SQLite database."""
session: Session
with self._make_session() as session:
records = (
# mypy does not recognize .all()
session.query(UpsertionRecord.key) # type: ignore[attr-defined]
.filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
filtered_query: Query = session.query(UpsertionRecord.key).filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
.all()
)
records = filtered_query.all()
found_keys = set(r.key for r in records)
return [k in found_keys for k in keys]
@ -438,28 +443,22 @@ class SQLRecordManager(RecordManager):
limit: Optional[int] = None,
) -> List[str]:
"""List records in the SQLite database based on the provided date range."""
session: Session
with self._make_session() as session:
query = session.query(UpsertionRecord).filter(
query: Query = session.query(UpsertionRecord).filter(
UpsertionRecord.namespace == self.namespace
)
# mypy does not recognize .all() or .filter()
if after:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at > after
)
query = query.filter(UpsertionRecord.updated_at > after)
if before:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at < before
)
query = query.filter(UpsertionRecord.updated_at < before)
if group_ids:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.group_id.in_(group_ids)
)
query = query.filter(UpsertionRecord.group_id.in_(group_ids))
if limit:
query = query.limit(limit) # type: ignore[attr-defined]
records = query.all() # type: ignore[attr-defined]
query = query.limit(limit)
records = query.all()
return [r.key for r in records]
async def alist_keys(
@ -471,40 +470,37 @@ class SQLRecordManager(RecordManager):
limit: Optional[int] = None,
) -> List[str]:
"""List records in the SQLite database based on the provided date range."""
session: AsyncSession
async with self._amake_session() as session:
query = select(UpsertionRecord.key).filter(
query: Query = select(UpsertionRecord.key).filter(
UpsertionRecord.namespace == self.namespace
)
# mypy does not recognize .all() or .filter()
if after:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at > after
)
query = query.filter(UpsertionRecord.updated_at > after)
if before:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at < before
)
query = query.filter(UpsertionRecord.updated_at < before)
if group_ids:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.group_id.in_(group_ids)
)
query = query.filter(UpsertionRecord.group_id.in_(group_ids))
if limit:
query = query.limit(limit) # type: ignore[attr-defined]
query = query.limit(limit)
records = (await session.execute(query)).scalars().all()
return list(records)
def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete records from the SQLite database."""
session: Session
with self._make_session() as session:
# mypy does not recognize .delete()
session.query(UpsertionRecord).filter(
filtered_query: Query = session.query(UpsertionRecord).filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
).delete() # type: ignore[attr-defined]
)
filtered_query.delete()
session.commit()
async def adelete_keys(self, keys: Sequence[str]) -> None:

Loading…
Cancel
Save