SQL Query Prompt update + added _execute method for SQLDatabase (#8100)

- Description: This pull request (PR) includes two minor changes:

1. Updated the default prompt for SQL Query Checker: The current prompt
does not clearly specify the final response that the LLM (Language
Model) should provide when checking for the query if `use_query_checker`
is enabled in SQLDatabase Chain. As a result, the LLM adds extra words
like "Here is your updated query" to the response. However, this causes
a syntax error when executing the SQL command in SQLDatabaseChain, as
these additional words are also included in the SQL query.

2. Moved the query's execution part into a separate method for
SQLDatabase: The purpose of this change is to provide users with more
flexibility when obtaining the result of an SQL query in the original
form returned by sqlalchemy. In the previous implementation, the run
method returned the results as a string. By creating a distinct method
for execution, users can now receive the results in original format,
which proves helpful in various scenarios. For example, during the
development of a tool, I found it advantageous to obtain results in
original format rather than a string, as currently done by the run
method.

- Tag maintainer: @hinthornw

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Mohammad Mohtashim 2023-08-01 04:28:08 +05:00 committed by GitHub
parent 844eca98d5
commit 144b4c0c78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 25 deletions

View File

@ -11,4 +11,8 @@ Double check the {dialect} query above for common mistakes, including:
- Casting to the correct data type - Casting to the correct data type
- Using the proper columns for joins - Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.""" If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
Output the final SQL query only.
SQL Query: """

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional, Sequence
import sqlalchemy import sqlalchemy
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
@ -368,12 +368,11 @@ class SQLDatabase:
f"{sample_rows_str}" f"{sample_rows_str}"
) )
def run(self, command: str, fetch: str = "all") -> str: def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence:
"""Execute a SQL command and return a string representing the results. """
Executes SQL command through underlying engine.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
If the statement returns no rows, an empty list is returned.
""" """
with self._engine.begin() as connection: with self._engine.begin() as connection:
if self._schema is not None: if self._schema is not None:
@ -395,26 +394,30 @@ class SQLDatabase:
result = cursor.fetchone() # type: ignore result = cursor.fetchone() # type: ignore
else: else:
raise ValueError("Fetch parameter must be either 'one' or 'all'") raise ValueError("Fetch parameter must be either 'one' or 'all'")
return result
return []
# Convert columns values to string to avoid issues with sqlalchmey def run(self, command: str, fetch: str = "all") -> str:
# trunacating text """Execute a SQL command and return a string representing the results.
if isinstance(result, list):
return str(
[
tuple(
truncate_word(c, length=self._max_string_length)
for c in r
)
for r in result
]
)
return str( If the statement returns rows, a string of the results is returned.
tuple( If the statement returns no rows, an empty string is returned.
truncate_word(c, length=self._max_string_length) for c in result """
) result = self._execute(command, fetch)
) # Convert columns values to string to avoid issues with sqlalchemy
return "" # truncating text
if not result:
return ""
elif isinstance(result, list):
res: Sequence = [
tuple(truncate_word(c, length=self._max_string_length) for c in r)
for r in result
]
else:
res = tuple(
truncate_word(c, length=self._max_string_length) for c in result
)
return str(res)
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables. """Get information about specified tables.