From 42d979efddb3a09bc982c97a6a3450f3829b70a4 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski <2348618+obi1kenobi@users.noreply.github.com> Date: Tue, 3 Oct 2023 15:19:08 -0400 Subject: [PATCH] Improve type hints and interface for SQL execution functionality. (#11353) The previous API of the `_execute()` function had a few rough edges that this PR addresses: - The `fetch` argument was type-hinted as being able to take any string, but any string other than `"all"` or `"one"` would `raise ValueError`. The new type hints explicitly declare that only those values are supported. - The return type was type-hinted as `Sequence`, but using `fetch = "one"` would actually return a single result item. The resulting type error was incorrectly suppressed using `# type: ignore`. We now always return a list. - Using `fetch = "one"` would return a single item if data was found, or an empty *list* if no data was found. This was confusing, and we now always return a list to simplify. - The return type was `Sequence[Any]` which was a bit difficult to use since it wasn't clear what one could do with the returned rows. Each returned row is now a `Dict[str, Any]` mapping column names to their values in the query result, so the new return type is `Sequence[Dict[str, Any]]`. I've updated the uses of this method elsewhere in the file to match the new behavior. 
--- .../langchain/utilities/sql_database.py | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index 13718c8c0c..cd5af96b91 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -2,7 +2,7 @@ from __future__ import annotations import warnings -from typing import Any, Iterable, List, Optional, Sequence +from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union import sqlalchemy from sqlalchemy import MetaData, Table, create_engine, inspect, select, text @@ -374,7 +374,11 @@ class SQLDatabase: f"{sample_rows_str}" ) - def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: + def _execute( + self, + command: str, + fetch: Union[Literal["all"], Literal["one"]] = "all", + ) -> Sequence[Dict[str, Any]]: """ Executes SQL command through underlying engine. @@ -397,15 +401,20 @@ class SQLDatabase: cursor = connection.execute(text(command)) if cursor.returns_rows: if fetch == "all": - result = cursor.fetchall() + result = [x._asdict() for x in cursor.fetchall()] elif fetch == "one": - result = cursor.fetchone() # type: ignore + first_result = cursor.fetchone() + result = [] if first_result is None else [first_result._asdict()] else: raise ValueError("Fetch parameter must be either 'one' or 'all'") return result return [] - def run(self, command: str, fetch: str = "all") -> str: + def run( + self, + command: str, + fetch: Union[Literal["all"], Literal["one"]] = "all", + ) -> str: """Execute a SQL command and return a string representing the results. If the statement returns rows, a string of the results is returned. 
@@ -414,18 +423,14 @@ class SQLDatabase: result = self._execute(command, fetch) # Convert columns values to string to avoid issues with sqlalchemy # truncating text - if not result: + res = [ + tuple(truncate_word(c, length=self._max_string_length) for c in r.values()) + for r in result + ] + if not res: return "" - elif isinstance(result, list): - res: Sequence = [ - tuple(truncate_word(c, length=self._max_string_length) for c in r) - for r in result - ] else: - res = tuple( - truncate_word(c, length=self._max_string_length) for c in result - ) - return str(res) + return str(res) def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: """Get information about specified tables. @@ -443,7 +448,11 @@ class SQLDatabase: """Format the error message""" return f"Error: {e}" - def run_no_throw(self, command: str, fetch: str = "all") -> str: + def run_no_throw( + self, + command: str, + fetch: Union[Literal["all"], Literal["one"]] = "all", + ) -> str: """Execute a SQL command and return a string representing the results. If the statement returns rows, a string of the results is returned.