Improve type hints and interface for SQL execution functionality. (#11353)

The previous API of the `_execute()` function had a few rough edges that
this PR addresses:
- The `fetch` argument was type-hinted as being able to take any string,
but any string other than `"all"` or `"one"` would `raise ValueError`.
The new type hints explicitly declare that only those values are
supported.
- The return type was type-hinted as `Sequence` but using `fetch =
"one"` would actually return a single result item. This was incorrectly
suppressed using `# type: ignore`. We now always return a list.
- Using `fetch = "one"` would return a single item if data was found, or
an empty *list* if no data was found. This was confusing, and we now
always return a list to simplify.
- The return type was `Sequence[Any]` which was a bit difficult to use
since it wasn't clear what one could do with the returned rows. I'm
making the new type `Dict[str, Any]` that corresponds to the column
names and their values in the query.

I've updated the use of this method elsewhere in the file to match the
new behavior.
pull/11361/head
Predrag Gruevski 9 months ago committed by GitHub
parent 3bddd708f7
commit 42d979efdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,7 +2,7 @@
from __future__ import annotations
import warnings
from typing import Any, Iterable, List, Optional, Sequence
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
import sqlalchemy
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
@ -374,7 +374,11 @@ class SQLDatabase:
f"{sample_rows_str}"
)
def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence:
def _execute(
self,
command: str,
fetch: Union[Literal["all"], Literal["one"]] = "all",
) -> Sequence[Dict[str, Any]]:
"""
Executes SQL command through underlying engine.
@ -397,15 +401,20 @@ class SQLDatabase:
cursor = connection.execute(text(command))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
result = [x._asdict() for x in cursor.fetchall()]
elif fetch == "one":
result = cursor.fetchone() # type: ignore
first_result = cursor.fetchone()
result = [] if first_result is None else [first_result._asdict()]
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
return result
return []
def run(self, command: str, fetch: str = "all") -> str:
def run(
self,
command: str,
fetch: Union[Literal["all"], Literal["one"]] = "all",
) -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
@ -414,18 +423,14 @@ class SQLDatabase:
result = self._execute(command, fetch)
# Convert columns values to string to avoid issues with sqlalchemy
# truncating text
if not result:
res = [
tuple(truncate_word(c, length=self._max_string_length) for c in r.values())
for r in result
]
if not res:
return ""
elif isinstance(result, list):
res: Sequence = [
tuple(truncate_word(c, length=self._max_string_length) for c in r)
for r in result
]
else:
res = tuple(
truncate_word(c, length=self._max_string_length) for c in result
)
return str(res)
return str(res)
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables.
@ -443,7 +448,11 @@ class SQLDatabase:
"""Format the error message"""
return f"Error: {e}"
def run_no_throw(self, command: str, fetch: str = "all") -> str:
def run_no_throw(
self,
command: str,
fetch: Union[Literal["all"], Literal["one"]] = "all",
) -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.

Loading…
Cancel
Save