From 1fdd9bd980e09d9f80a0ded6aacfbd7329f66ce0 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 13 Feb 2024 04:58:34 +0100 Subject: [PATCH] community/SQLDatabase: Generalize and trim software tests (#16659) - **Description:** Improve test cases for `SQLDatabase` adapter component, see [suggestion](https://github.com/langchain-ai/langchain/pull/16655#pullrequestreview-1846749474). - **Depends on:** GH-16655 - **Addressed to:** @baskaryan, @cbornet, @eyurtsev _Remark: This PR is stacked upon GH-16655, so that one will need to go in first._ Edit: Thank you for bringing in GH-17191, @eyurtsev. This is a little aftermath, improving/streamlining the corresponding test cases. --- .../tests/unit_tests/test_sql_database.py | 101 +++++++++--------- 1 file changed, 52 insertions(+), 49 deletions(-) diff --git a/libs/community/tests/unit_tests/test_sql_database.py b/libs/community/tests/unit_tests/test_sql_database.py index 42aaef7134..4293265a1d 100644 --- a/libs/community/tests/unit_tests/test_sql_database.py +++ b/libs/community/tests/unit_tests/test_sql_database.py @@ -10,7 +10,6 @@ from sqlalchemy import ( String, Table, Text, - create_engine, insert, select, ) @@ -35,11 +34,19 @@ company = Table( ) -def test_table_info() -> None: - """Test that table info is constructed properly.""" - engine = create_engine("sqlite:///:memory:") +@pytest.fixture +def engine() -> sa.Engine: + return sa.create_engine("sqlite:///:memory:") + + +@pytest.fixture +def db(engine: sa.Engine) -> SQLDatabase: metadata_obj.create_all(engine) - db = SQLDatabase(engine) + return SQLDatabase(engine) + + +def test_table_info(db: SQLDatabase) -> None: + """Test that table info is constructed properly.""" output = db.table_info expected_output = """ CREATE TABLE user ( @@ -68,20 +75,19 @@ def test_table_info() -> None: assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) -def test_table_info_w_sample_rows() -> None: +def test_table_info_w_sample_rows(db: SQLDatabase) -> None: """Test that table info is constructed properly.""" - engine = create_engine("sqlite:///:memory:") - metadata_obj.create_all(engine) + + # Provision. values = [ {"user_id": 13, "user_name": "Harrison", "user_bio": "bio"}, {"user_id": 14, "user_name": "Chase", "user_bio": "bio"}, ] stmt = insert(user).values(values) - with engine.begin() as conn: - conn.execute(stmt) - - db = SQLDatabase(engine, sample_rows_in_table_info=2) + db._execute(stmt) + # Query and verify. + db = SQLDatabase(db._engine, sample_rows_in_table_info=2) output = db.table_info expected_output = """ @@ -112,16 +118,16 @@ def test_table_info_w_sample_rows() -> None: assert sorted(output.split()) == sorted(expected_output.split()) -def test_sql_database_run_fetch_all() -> None: +def test_sql_database_run_fetch_all(db: SQLDatabase) -> None: """Verify running SQL expressions returning results as strings.""" - engine = create_engine("sqlite:///:memory:") - metadata_obj.create_all(engine) + + # Provision. stmt = insert(user).values( user_id=13, user_name="Harrison", user_bio="That is my Bio " * 24 ) - with engine.begin() as conn: - conn.execute(stmt) - db = SQLDatabase(engine) + db._execute(stmt) + + # Query and verify. command = "select user_id, user_name, user_bio from user where user_id = 13" partial_output = db.run(command) user_bio = "That is my Bio " * 19 + "That is my..." @@ -135,60 +141,57 @@ def test_sql_database_run_fetch_all() -> None: assert full_output == expected_full_output -def test_sql_database_run_fetch_result() -> None: +def test_sql_database_run_fetch_result(db: SQLDatabase) -> None: """Verify running SQL expressions returning results as SQLAlchemy `Result` instances.""" - engine = create_engine("sqlite:///:memory:") - metadata_obj.create_all(engine) - stmt = insert(user).values(user_id=17, user_name="hwchase") - with engine.begin() as conn: - conn.execute(stmt) - db = SQLDatabase(engine) - command = "select user_id, user_name, user_bio from user where user_id = 17" + # Provision. + stmt = insert(user).values(user_id=17, user_name="hwchase") + db._execute(stmt) + + # Query and verify. + command = "select user_id, user_name, user_bio from user where user_id = 17" result = db.run(command, fetch="cursor", include_columns=True) expected = [{"user_id": 17, "user_name": "hwchase", "user_bio": None}] assert isinstance(result, Result) assert result.mappings().fetchall() == expected -def test_sql_database_run_with_parameters() -> None: +def test_sql_database_run_with_parameters(db: SQLDatabase) -> None: """Verify running SQL expressions with query parameters.""" - engine = create_engine("sqlite:///:memory:") - metadata_obj.create_all(engine) - stmt = insert(user).values(user_id=17, user_name="hwchase") - with engine.begin() as conn: - conn.execute(stmt) - db = SQLDatabase(engine) - command = "select user_id, user_name, user_bio from user where user_id = :user_id" + # Provision. + stmt = insert(user).values(user_id=17, user_name="hwchase") + db._execute(stmt) + + # Query and verify. + command = "select user_id, user_name, user_bio from user where user_id = :user_id" full_output = db.run(command, parameters={"user_id": 17}, include_columns=True) expected_full_output = "[{'user_id': 17, 'user_name': 'hwchase', 'user_bio': None}]" assert full_output == expected_full_output -def test_sql_database_run_sqlalchemy_selectable() -> None: +def test_sql_database_run_sqlalchemy_selectable(db: SQLDatabase) -> None: """Verify running SQL expressions using SQLAlchemy selectable.""" - engine = create_engine("sqlite:///:memory:") - metadata_obj.create_all(engine) - stmt = insert(user).values(user_id=17, user_name="hwchase") - with engine.begin() as conn: - conn.execute(stmt) - db = SQLDatabase(engine) - command = select(user).where(user.c.user_id == 17) + # Provision. + stmt = insert(user).values(user_id=17, user_name="hwchase") + db._execute(stmt) + + # Query and verify. + command = select(user).where(user.c.user_id == 17) full_output = db.run(command, include_columns=True) expected_full_output = "[{'user_id': 17, 'user_name': 'hwchase', 'user_bio': None}]" assert full_output == expected_full_output -def test_sql_database_run_update() -> None: +def test_sql_database_run_update(db: SQLDatabase) -> None: """Test commands which return no rows return an empty string.""" - engine = create_engine("sqlite:///:memory:") - metadata_obj.create_all(engine) + + # Provision. stmt = insert(user).values(user_id=13, user_name="Harrison") - with engine.begin() as conn: - conn.execute(stmt) - db = SQLDatabase(engine) + db._execute(stmt) + + # Query and verify. command = "update user set user_name='Updated' where user_id = 13" output = db.run(command) expected_output = "" @@ -198,7 +201,7 @@ def test_sql_database_run_update() -> None: def test_sql_database_schema_translate_map() -> None: """Verify using statement-specific execution options.""" - engine = create_engine("sqlite:///:memory:") + engine = sa.create_engine("sqlite:///:memory:") db = SQLDatabase(engine) # Define query using SQLAlchemy selectable.