Harrison/sql rows (#915)

Co-authored-by: Jon Luo <20971593+jzluo@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-02-06 18:56:18 -08:00 committed by GitHub
parent ba5a2f06b9
commit f95cedc443
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 44 deletions

View File

@ -57,7 +57,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3d1e692e",
"metadata": {},
@ -94,15 +93,15 @@
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees.\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 8 employees.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' There are 9 employees.'"
"' There are 8 employees.'"
]
},
"execution_count": 4,
@ -177,15 +176,15 @@
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there in the foobar table? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 8 employees in the foobar table.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' There are 9 employees in the foobar table.'"
"' There are 8 employees in the foobar table.'"
]
},
"execution_count": 7,
@ -219,7 +218,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"id": "78b6af4d",
"metadata": {},
"outputs": [
@ -232,18 +231,18 @@
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there in the foobar table? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 8 employees in the foobar table.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"[' SELECT COUNT(*) FROM Employee;', '[(9,)]']"
"[' SELECT COUNT(*) FROM Employee;', '[(8,)]']"
]
},
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -264,7 +263,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 10,
"id": "6adaa799",
"metadata": {},
"outputs": [],
@ -274,7 +273,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 11,
"id": "edfc8a8e",
"metadata": {},
"outputs": [
@ -286,8 +285,8 @@
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"What are some example tracks by composer Johann Sebastian Bach? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',)]\u001b[0m\n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach')]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m Examples of tracks by Johann Sebastian Bach include 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -298,7 +297,7 @@
"' Examples of tracks by Johann Sebastian Bach include \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', and \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\'.'"
]
},
"execution_count": 8,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@ -312,13 +311,13 @@
"id": "bcc5e936",
"metadata": {},
"source": [
"## Adding first row of each table\n",
"Sometimes, the format of the data is not obvious and it is optimal to include the first row of the table in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names."
"## Adding example rows from each table\n",
"Sometimes, the format of the data is not obvious and it is optimal to include a sample of rows from the tables in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names by providing two rows from the `Track` table."
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "9a22ee47",
"metadata": {},
"outputs": [],
@ -326,12 +325,40 @@
"db = SQLDatabase.from_uri(\n",
" \"sqlite:///../../../../notebooks/Chinook.db\", \n",
" include_tables=['Track'], # we include only one table to save tokens in the prompt :)\n",
" sample_row_in_table_info=True)"
" sample_rows_in_table_info=2)"
]
},
{
"cell_type": "markdown",
"id": "952c0b4d",
"metadata": {},
"source": [
"The sample rows are added to the prompt after each corresponding table's column information:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "9de86267",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example of 2 rows from this table (long strings are truncated):\n",
"1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n",
"2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n"
]
}
],
"source": [
"print(db.table_info)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "bcb7a489",
"metadata": {},
"outputs": [],
@ -341,7 +368,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"id": "81e05d82",
"metadata": {},
"outputs": [
@ -353,20 +380,19 @@
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"What are some example tracks by Bach? \n",
"SQLQuery:Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example row for this table (long strings are truncated): ['1', 'For Those About To Rock (We Salute You)', '1', '1', '1', 'Angus Young, Malcolm Young, Brian Johnson', '343719', '11170334', '0.99'].\n",
"\u001b[32;1m\u001b[1;3m SELECT TrackId, Name, Composer FROM Track WHERE Composer LIKE '%Bach%' ORDER BY Name LIMIT 5;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(1709, 'American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), (3408, 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), (3433, 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Johann Sebastian Bach'), (3407, 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), (3490, 'Partita in E Major, BWV 1006A: I. Prelude', 'Johann Sebastian Bach')]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', and 'Partita in E Major, BWV 1006A: I. Prelude'.\u001b[0m\n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer LIKE '%Bach%' LIMIT 5;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[('American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach'), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata', 'Johann Sebastian Bach')]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', and 'Toccata and Fugue in D Minor, BWV 565: I. Toccata'.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' Some example tracks by Bach are \\'American Woman\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Concerto No.2 in F Major, BWV1047, I. Allegro\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', and \\'Partita in E Major, BWV 1006A: I. Prelude\\'.'"
"' Some example tracks by Bach are \\'American Woman\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\', and \\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\\'.'"
]
},
"execution_count": 13,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@ -455,6 +481,10 @@
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
@ -470,7 +500,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.2"
}
},
"nbformat": 4,

View File

@ -16,9 +16,16 @@ class SQLDatabase:
schema: Optional[str] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 0,
# TODO: deprecate.
sample_row_in_table_info: bool = False,
):
"""Create engine from database URI."""
if sample_row_in_table_info and sample_rows_in_table_info > 0:
raise ValueError(
"Only one of `sample_row_in_table_info` "
"and `sample_rows_in_table_info` should be set"
)
self._engine = engine
self._schema = schema
if include_tables and ignore_tables:
@ -40,7 +47,10 @@ class SQLDatabase:
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
self._sample_row_in_table_info = sample_row_in_table_info
self._sample_rows_in_table_info = sample_rows_in_table_info
# TODO: deprecate
if sample_row_in_table_info:
self._sample_rows_in_table_info = 1
@classmethod
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
@ -64,7 +74,12 @@ class SQLDatabase:
return self.get_table_info()
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables."""
"""Get information about specified tables.
If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as
demonstrated by Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498).
"""
all_table_names = self.get_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
@ -83,15 +98,25 @@ class SQLDatabase:
column_str = ", ".join(columns)
table_str = template.format(table_name=table_name, columns=column_str)
if self._sample_row_in_table_info:
if self._sample_rows_in_table_info:
row_template = (
" Here is an example row for this table"
" (long strings are truncated): {sample_row}."
" Here is an example of {n_rows} rows from this table "
"(long strings are truncated):\n"
"{sample_rows}"
)
sample_row = self.run(f"SELECT * FROM '{table_name}' LIMIT 1")
if len(eval(sample_row)) > 0:
sample_row = " ".join([str(i)[:100] for i in eval(sample_row)[0]])
table_str += row_template.format(sample_row=sample_row)
sample_rows = self.run(
f"SELECT * FROM '{table_name}' LIMIT "
f"{self._sample_rows_in_table_info}"
)
sample_rows = eval(sample_rows)
if len(sample_rows) > 0:
n_rows = len(sample_rows)
sample_rows = "\n".join(
[" ".join([str(i)[:100] for i in row]) for row in sample_rows]
)
table_str += row_template.format(
n_rows=n_rows, sample_rows=sample_rows
)
tables.append(table_str)
return "\n".join(tables)

View File

@ -35,23 +35,27 @@ def test_table_info() -> None:
assert sorted(output.split("\n")) == sorted(expected_output)
def test_table_info_w_sample_row() -> None:
def test_table_info_w_sample_rows() -> None:
"""Test that table info is constructed properly."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison")
values = [
{"user_id": 13, "user_name": "Harrison"},
{"user_id": 14, "user_name": "Chase"},
]
stmt = insert(user).values(values)
with engine.begin() as conn:
conn.execute(stmt)
db = SQLDatabase(engine, sample_row_in_table_info=True)
db = SQLDatabase(engine, sample_rows_in_table_info=2)
output = db.table_info
expected_output = (
"Table 'company' has columns: company_id (INTEGER), "
"company_location (VARCHAR).\n"
"Table 'user' has columns: user_id (INTEGER), "
"user_name (VARCHAR(16)). Here is an example row "
"for this table (long strings are truncated): 13 Harrison."
"user_name (VARCHAR(16)). Here is an example of 2 rows "
"from this table (long strings are truncated):\n13 Harrison\n14 Chase"
)
assert sorted(output.split("\n")) == sorted(expected_output.split("\n"))