Support SQL statements that return no results (#222)

Adds support for statements such as insert, update etc which do not
return any rows.

`engine.execute` is deprecated and so execution has been updated to use
`connection.exec_driver_sql` as-per:


https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.Engine.execute
harrison/track_intermediate_steps
Andrew Gleave 2 years ago committed by GitHub
parent d368c43648
commit ea67c049f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -66,6 +66,14 @@ class SQLDatabase:
return "\n".join(tables)
def run(self, command: str) -> str:
"""Execute a SQL command and return a string of the results."""
result = self._engine.execute(command).fetchall()
return str(result)
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
with self._engine.connect() as connection:
cursor = connection.exec_driver_sql(command)
if cursor.returns_rows:
result = cursor.fetchall()
return str(result)
return ""

@ -28,3 +28,20 @@ def test_sql_database_run() -> None:
output = db_chain.run("What company does Harrison work at?")
expected_output = " Harrison works at Foo."
assert output == expected_output
def test_sql_database_run_update() -> None:
"""Test that update commands run successfully and returned in correct format."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison", user_company="Foo")
with engine.connect() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
output = db_chain.run("Update Harrison's workplace to Bar")
expected_output = " Harrison's workplace has been updated to Bar."
assert output == expected_output
output = db_chain.run("What company does Harrison work at?")
expected_output = " Harrison works at Bar."
assert output == expected_output

@ -47,3 +47,17 @@ def test_sql_database_run() -> None:
output = db.run(command)
expected_output = "[('Harrison',)]"
assert output == expected_output
def test_sql_database_run_update() -> None:
"""Test commands which return no rows return an empty string."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
command = "update user set user_name='Updated' where user_id = 13"
output = db.run(command)
expected_output = ""
assert output == expected_output

Loading…
Cancel
Save