From 9efc29e3d18eb324ea45026295272be9515e6a98 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 29 Aug 2023 11:13:42 -0400 Subject: [PATCH] x --- libs/langchain/langchain/indexes/_api.py | 4 +-- .../langchain/indexes/_sql_record_manager.py | 25 ++++++++++++++++--- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py index 47b9d33ea8..130a5c685d 100644 --- a/libs/langchain/langchain/indexes/_api.py +++ b/libs/langchain/langchain/indexes/_api.py @@ -332,9 +332,9 @@ def index( uids_to_delete = record_manager.list_keys(before=index_start_dt) if uids_to_delete: - # Then delete from vector store. - vector_store.delete(uids_to_delete) # First delete from record store. + vector_store.delete(uids_to_delete) + # Then delete from record manager. record_manager.delete_keys(uids_to_delete) num_deleted = len(uids_to_delete) diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py index 9cad02ef93..be793dcf57 100644 --- a/libs/langchain/langchain/indexes/_sql_record_manager.py +++ b/libs/langchain/langchain/indexes/_sql_record_manager.py @@ -15,8 +15,10 @@ allow it to work with a variety of SQL as a backend. 
""" import contextlib import uuid -from typing import Any, Dict, Generator, List, Optional, Sequence +from typing import Any, Dict, Generator, List, Optional, Sequence, Union +import decimal +from sqlalchemy import URL from sqlalchemy import ( Column, Engine, @@ -28,7 +30,6 @@ from sqlalchemy import ( create_engine, text, ) -from sqlalchemy.dialects.sqlite import insert from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, sessionmaker @@ -77,7 +78,7 @@ class SQLRecordManager(RecordManager): namespace: str, *, engine: Optional[Engine] = None, - db_url: Optional[str] = None, + db_url: Union[None, str, URL] = None, engine_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the SQLRecordManager. @@ -114,6 +115,7 @@ class SQLRecordManager(RecordManager): raise AssertionError("Something went wrong with configuration of engine.") self.engine = _engine + self.dialect = _engine.dialect.name self.session_factory = sessionmaker(bind=self.engine) def create_schema(self) -> None: @@ -145,8 +147,16 @@ class SQLRecordManager(RecordManager): # 2440587.5 - constant represents the Julian day number for January 1, 1970 # 86400.0 - constant represents the number of seconds # in a day (24 hours * 60 minutes * 60 seconds) - query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;") + if self.dialect == "sqlite": + query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;") + elif self.dialect == "postgresql": + query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);") + else: + raise NotImplementedError(f"Not implemented for dialect {self.dialect}") + dt = session.execute(query).scalar() + if isinstance(dt, decimal.Decimal): + dt = float(dt) if not isinstance(dt, float): raise AssertionError(f"Unexpected type for datetime: {type(dt)}") return dt @@ -191,6 +201,13 @@ class SQLRecordManager(RecordManager): for key, group_id in zip(keys, group_ids) ] + if self.dialect == "sqlite": + from sqlalchemy.dialects.sqlite import 
insert + elif self.dialect == "postgresql": + from sqlalchemy.dialects.postgresql import insert + else: + raise NotImplementedError(f"Unsupported dialect {self.dialect}") + with self._make_session() as session: # Note: uses SQLite insert to make on_conflict_do_update work. # This code needs to be generalized a bit to work with more dialects.