sql: do not hard code the LIMIT clause in the table_info section (#1563)

Seeing a lot of issues in Discord in which the LLM is not using the
correct LIMIT clause for different SQL dialects. ie, it's using `LIMIT`
for mssql instead of `TOP`, or instead of `ROWNUM` for Oracle, etc.
I think this could be due to us specifying the LIMIT statement in the
example rows portion of `table_info`. So the LLM is seeing the `LIMIT`
statement used in the prompt.
Since we can't specify each dialect's method here, I think it's fine to
just replace the `SELECT... LIMIT 3;` statement with `3 rows from
table_name table:`, and wrap everything in a block comment directly
following the `CREATE` statement. The Rajkumar et al paper wrapped the
example rows and `SELECT` statement in a block comment as well anyway.
Thoughts @fpingham?
tool-patch
Jon Luo 1 year ago committed by GitHub
parent 9ee2713272
commit 0a1b1806e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -377,18 +377,19 @@
"\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n",
"\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n",
")\n",
"\n",
"SELECT * FROM 'Track' LIMIT 2;\n",
"/*\n",
"2 rows from Track table:\n",
"TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n",
"1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n",
"2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n"
"2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n",
"*/\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jon/projects/langchain/langchain/sql_database.py:121: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n",
"/home/jon/projects/langchain/langchain/sql_database.py:135: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n",
" sample_rows = connection.execute(command)\n"
]
}
@ -467,12 +468,13 @@
"\t\"Composer\" NVARCHAR(220),\n",
"\tPRIMARY KEY (\"TrackId\")\n",
")\n",
"\n",
"SELECT * FROM 'Track' LIMIT 3;\n",
"/*\n",
"3 rows from Track table:\n",
"TrackId\tName\tComposer\n",
"1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n",
"2\tBalls to the Wall\tNone\n",
"3\tMy favorite song ever\tThe coolest composer of all time\"\"\"\n",
"3\tMy favorite song ever\tThe coolest composer of all time\n",
"*/\"\"\"\n",
"}"
]
},
@ -492,11 +494,12 @@
"\t\"Name\" NVARCHAR(120), \n",
"\tPRIMARY KEY (\"PlaylistId\")\n",
")\n",
"\n",
"SELECT * FROM 'Playlist' LIMIT 2;\n",
"/*\n",
"2 rows from Playlist table:\n",
"PlaylistId\tName\n",
"1\tMusic\n",
"2\tMovies\n",
"*/\n",
"\n",
"CREATE TABLE Track (\n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
@ -504,12 +507,13 @@
"\t\"Composer\" NVARCHAR(220),\n",
"\tPRIMARY KEY (\"TrackId\")\n",
")\n",
"\n",
"SELECT * FROM 'Track' LIMIT 3;\n",
"/*\n",
"3 rows from Track table:\n",
"TrackId\tName\tComposer\n",
"1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n",
"2\tBalls to the Wall\tNone\n",
"3\tMy favorite song ever\tThe coolest composer of all time\n"
"3\tMy favorite song ever\tThe coolest composer of all time\n",
"*/\n"
]
}
],

@ -126,12 +126,6 @@ class SQLDatabase:
# build the select command
command = select(table).limit(self._sample_rows_in_table_info)
# save the command in string format
select_star = (
f"SELECT * FROM '{table.name}' LIMIT "
f"{self._sample_rows_in_table_info}"
)
# save the columns in string format
columns_str = "\t".join([col.name for col in table.columns])
@ -152,16 +146,18 @@ class SQLDatabase:
except ProgrammingError:
sample_rows_str = ""
# build final info for table
tables.append(
create_table
+ select_star
+ ";\n"
+ columns_str
+ "\n"
+ sample_rows_str
table_info = (
f"{create_table.rstrip()}\n"
f"/*\n"
f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
f"{columns_str}\n"
f"{sample_rows_str}\n"
f"*/"
)
# build final info for table
tables.append(table_info)
else:
tables.append(create_table)

@ -34,9 +34,10 @@ def test_table_info() -> None:
user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 3;
/*
3 rows from user table:
user_id user_name
/*
CREATE TABLE company (
@ -44,9 +45,10 @@ def test_table_info() -> None:
company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 3;
/*
3 rows from company table:
company_id company_location
*/
"""
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
@ -74,21 +76,22 @@ def test_table_info_w_sample_rows() -> None:
company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 2;
/*
2 rows from company table:
company_id company_location
*/
CREATE TABLE user (
user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 2;
/*
2 rows from user table:
user_id user_name
13 Harrison
14 Chase
*/
"""
assert sorted(output.split()) == sorted(expected_output.split())

@ -54,9 +54,10 @@ def test_table_info() -> None:
user_name VARCHAR NOT NULL,
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 3;
/*
3 rows from user table:
user_id user_name
*/
"""
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))

Loading…
Cancel
Save