diff --git a/langchain/memory/chat_message_histories/sql.py b/langchain/memory/chat_message_histories/sql.py index e3770133..40f8691c 100644 --- a/langchain/memory/chat_message_histories/sql.py +++ b/langchain/memory/chat_message_histories/sql.py @@ -3,7 +3,8 @@ import logging from typing import List from sqlalchemy import Column, Integer, Text, create_engine -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker from langchain.schema import ( AIMessage, diff --git a/langchain/sql_database.py b/langchain/sql_database.py index fc0a5098..337af0e8 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -13,7 +13,7 @@ from sqlalchemy import ( select, text, ) -from sqlalchemy.engine import CursorResult, Engine +from sqlalchemy.engine import Engine from sqlalchemy.exc import ProgrammingError, SQLAlchemyError from sqlalchemy.schema import CreateTable @@ -196,7 +196,7 @@ class SQLDatabase: try: # get the sample rows with self._engine.connect() as connection: - sample_rows_result: CursorResult = connection.execute(command) + sample_rows_result = connection.execute(command) # type: ignore # shorten values in the sample rows sample_rows = list( map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result) diff --git a/langchain/vectorstores/analyticdb.py b/langchain/vectorstores/analyticdb.py index 02f3be63..bb32549b 100644 --- a/langchain/vectorstores/analyticdb.py +++ b/langchain/vectorstores/analyticdb.py @@ -8,7 +8,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple import sqlalchemy from sqlalchemy import REAL, Index from sqlalchemy.dialects.postgresql import ARRAY, JSON, UUID -from sqlalchemy.orm import Session, declarative_base, relationship +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session, relationship from sqlalchemy.sql.expression import func from langchain.docstore.document import Document