diff --git a/langchain/cache.py b/langchain/cache.py index 52c272c4..c81142cc 100644 --- a/langchain/cache.py +++ b/langchain/cache.py @@ -4,8 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple from sqlalchemy import Column, Integer, String, create_engine, select from sqlalchemy.engine.base import Engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, declarative_base from langchain.schema import Generation diff --git a/langchain/sql_database.py b/langchain/sql_database.py index d4695c0b..2d5a8405 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -86,7 +86,7 @@ class SQLDatabase: If the statement returns rows, a string of the results is returned. If the statement returns no rows, an empty string is returned. """ - with self._engine.connect() as connection: + with self._engine.begin() as connection: if self._schema is not None: connection.exec_driver_sql(f"SET search_path TO {self._schema}") cursor = connection.exec_driver_sql(command) diff --git a/tests/unit_tests/llms/test_base.py b/tests/unit_tests/llms/test_base.py index 9726b304..8838997d 100644 --- a/tests/unit_tests/llms/test_base.py +++ b/tests/unit_tests/llms/test_base.py @@ -1,6 +1,6 @@ """Test base LLM functionality.""" from sqlalchemy import Column, Integer, Sequence, String, create_engine -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import declarative_base import langchain from langchain.cache import InMemoryCache, SQLAlchemyCache diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py index d9ce84f5..4c5d208f 100644 --- a/tests/unit_tests/test_sql_database.py +++ b/tests/unit_tests/test_sql_database.py @@ -40,7 +40,7 @@ def test_sql_database_run() -> None: engine = create_engine("sqlite:///:memory:") metadata_obj.create_all(engine) stmt = insert(user).values(user_id=13, user_name="Harrison") - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(stmt) db = SQLDatabase(engine) command = "select user_name from user where user_id = 13" @@ -54,7 +54,7 @@ def test_sql_database_run_update() -> None: engine = create_engine("sqlite:///:memory:") metadata_obj.create_all(engine) stmt = insert(user).values(user_id=13, user_name="Harrison") - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(stmt) db = SQLDatabase(engine) command = "update user set user_name='Updated' where user_id = 13" diff --git a/tests/unit_tests/test_sql_database_schema.py b/tests/unit_tests/test_sql_database_schema.py index 16f54532..6b15e600 100644 --- a/tests/unit_tests/test_sql_database_schema.py +++ b/tests/unit_tests/test_sql_database_schema.py @@ -57,7 +57,7 @@ def test_sql_database_run() -> None: engine = create_engine("duckdb:///:memory:") metadata_obj.create_all(engine) stmt = insert(user).values(user_id=13, user_name="Harrison") - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(stmt) db = SQLDatabase(engine, schema="schema_a") command = 'select user_name from "user" where user_id = 13'