Fix sqlalchemy warnings when running tests (#733)

This has been bugging me when running my own tests that call langchain
methods :P
ankush/async-llmchain
Amos Ng 1 year ago committed by GitHub
parent bd0bf4e0a9
commit fa6826e417
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,8 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import Column, Integer, String, create_engine, select from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm import Session
from langchain.schema import Generation from langchain.schema import Generation

@ -86,7 +86,7 @@ class SQLDatabase:
If the statement returns rows, a string of the results is returned. If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned. If the statement returns no rows, an empty string is returned.
""" """
with self._engine.connect() as connection: with self._engine.begin() as connection:
if self._schema is not None: if self._schema is not None:
connection.exec_driver_sql(f"SET search_path TO {self._schema}") connection.exec_driver_sql(f"SET search_path TO {self._schema}")
cursor = connection.exec_driver_sql(command) cursor = connection.exec_driver_sql(command)

@ -1,6 +1,6 @@
"""Test base LLM functionality.""" """Test base LLM functionality."""
from sqlalchemy import Column, Integer, Sequence, String, create_engine from sqlalchemy import Column, Integer, Sequence, String, create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import declarative_base
import langchain import langchain
from langchain.cache import InMemoryCache, SQLAlchemyCache from langchain.cache import InMemoryCache, SQLAlchemyCache

@ -40,7 +40,7 @@ def test_sql_database_run() -> None:
engine = create_engine("sqlite:///:memory:") engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison") stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn: with engine.begin() as conn:
conn.execute(stmt) conn.execute(stmt)
db = SQLDatabase(engine) db = SQLDatabase(engine)
command = "select user_name from user where user_id = 13" command = "select user_name from user where user_id = 13"
@ -54,7 +54,7 @@ def test_sql_database_run_update() -> None:
engine = create_engine("sqlite:///:memory:") engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison") stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn: with engine.begin() as conn:
conn.execute(stmt) conn.execute(stmt)
db = SQLDatabase(engine) db = SQLDatabase(engine)
command = "update user set user_name='Updated' where user_id = 13" command = "update user set user_name='Updated' where user_id = 13"

@ -57,7 +57,7 @@ def test_sql_database_run() -> None:
engine = create_engine("duckdb:///:memory:") engine = create_engine("duckdb:///:memory:")
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison") stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn: with engine.begin() as conn:
conn.execute(stmt) conn.execute(stmt)
db = SQLDatabase(engine, schema="schema_a") db = SQLDatabase(engine, schema="schema_a")
command = 'select user_name from "user" where user_id = 13' command = 'select user_name from "user" where user_id = 13'

Loading…
Cancel
Save