This commit is contained in:
Eugene Yurtsev 2023-08-29 11:13:42 -04:00
parent e80834d783
commit 9efc29e3d1
2 changed files with 23 additions and 6 deletions

View File

@ -332,9 +332,9 @@ def index(
uids_to_delete = record_manager.list_keys(before=index_start_dt)
if uids_to_delete:
# Then delete from vector store.
vector_store.delete(uids_to_delete)
# First delete from vector store.
vector_store.delete(uids_to_delete)
# Then delete from record manager.
record_manager.delete_keys(uids_to_delete)
num_deleted = len(uids_to_delete)

View File

@ -15,8 +15,10 @@ allow it to work with a variety of SQL as a backend.
"""
import contextlib
import uuid
from typing import Any, Dict, Generator, List, Optional, Sequence
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
import decimal
from sqlalchemy import URL
from sqlalchemy import (
Column,
Engine,
@ -28,7 +30,6 @@ from sqlalchemy import (
create_engine,
text,
)
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
@ -77,7 +78,7 @@ class SQLRecordManager(RecordManager):
namespace: str,
*,
engine: Optional[Engine] = None,
db_url: Optional[str] = None,
db_url: Union[None, str, URL] = None,
engine_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the SQLRecordManager.
@ -114,6 +115,7 @@ class SQLRecordManager(RecordManager):
raise AssertionError("Something went wrong with configuration of engine.")
self.engine = _engine
self.dialect = _engine.dialect.name
self.session_factory = sessionmaker(bind=self.engine)
def create_schema(self) -> None:
@ -145,8 +147,16 @@ class SQLRecordManager(RecordManager):
# 2440587.5 - constant represents the Julian day number for January 1, 1970
# 86400.0 - constant represents the number of seconds
# in a day (24 hours * 60 minutes * 60 seconds)
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
if self.dialect == "sqlite":
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
elif self.dialect == "postgresql":
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
else:
raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
dt = session.execute(query).scalar()
if isinstance(dt, decimal.Decimal):
dt = float(dt)
if not isinstance(dt, float):
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
return dt
@ -191,6 +201,13 @@ class SQLRecordManager(RecordManager):
for key, group_id in zip(keys, group_ids)
]
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import insert
elif self.dialect == "postgresql":
from sqlalchemy.dialects.sqlite import insert
else:
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
with self._make_session() as session:
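# Note: uses the dialect-specific insert imported above to make on_conflict_do_update work.
# NOTE(review): the postgresql branch above currently imports the SQLite insert as well;
# confirm whether sqlalchemy.dialects.postgresql.insert was intended.
# This code needs to be generalized a bit to work with more dialects.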