sql: do not hard code the LIMIT clause in the table_info section (#1563)

We are seeing a lot of issues in Discord in which the LLM does not use the
correct row-limiting clause for different SQL dialects, e.g., it uses `LIMIT`
for mssql instead of `TOP`, or instead of `ROWNUM` for Oracle, etc.
I think this could be because we specify the `LIMIT` statement in the
example rows portion of `table_info`, so the LLM sees the `LIMIT`
statement used in the prompt.
Since we can't specify each dialect's method here, I think it's fine to
just replace the `SELECT... LIMIT 3;` statement with `3 rows from
table_name table:`, and wrap everything in a block comment directly
following the `CREATE` statement. The Rajkumar et al. paper also wrapped the
example rows and `SELECT` statement in a block comment.
Thoughts @fpingham?
tool-patch
Jon Luo 2 years ago committed by GitHub
parent 9ee2713272
commit 0a1b1806e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -377,18 +377,19 @@
"\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n", "\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n",
"\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n", "\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n",
")\n", ")\n",
"\n", "/*\n",
"SELECT * FROM 'Track' LIMIT 2;\n", "2 rows from Track table:\n",
"TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n", "TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n",
"1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n", "1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n",
"2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n" "2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n",
"*/\n"
] ]
}, },
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"/home/jon/projects/langchain/langchain/sql_database.py:121: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n", "/home/jon/projects/langchain/langchain/sql_database.py:135: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n",
" sample_rows = connection.execute(command)\n" " sample_rows = connection.execute(command)\n"
] ]
} }
@ -467,12 +468,13 @@
"\t\"Composer\" NVARCHAR(220),\n", "\t\"Composer\" NVARCHAR(220),\n",
"\tPRIMARY KEY (\"TrackId\")\n", "\tPRIMARY KEY (\"TrackId\")\n",
")\n", ")\n",
"\n", "/*\n",
"SELECT * FROM 'Track' LIMIT 3;\n", "3 rows from Track table:\n",
"TrackId\tName\tComposer\n", "TrackId\tName\tComposer\n",
"1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n", "1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n",
"2\tBalls to the Wall\tNone\n", "2\tBalls to the Wall\tNone\n",
"3\tMy favorite song ever\tThe coolest composer of all time\"\"\"\n", "3\tMy favorite song ever\tThe coolest composer of all time\n",
"*/\"\"\"\n",
"}" "}"
] ]
}, },
@ -492,11 +494,12 @@
"\t\"Name\" NVARCHAR(120), \n", "\t\"Name\" NVARCHAR(120), \n",
"\tPRIMARY KEY (\"PlaylistId\")\n", "\tPRIMARY KEY (\"PlaylistId\")\n",
")\n", ")\n",
"\n", "/*\n",
"SELECT * FROM 'Playlist' LIMIT 2;\n", "2 rows from Playlist table:\n",
"PlaylistId\tName\n", "PlaylistId\tName\n",
"1\tMusic\n", "1\tMusic\n",
"2\tMovies\n", "2\tMovies\n",
"*/\n",
"\n", "\n",
"CREATE TABLE Track (\n", "CREATE TABLE Track (\n",
"\t\"TrackId\" INTEGER NOT NULL, \n", "\t\"TrackId\" INTEGER NOT NULL, \n",
@ -504,12 +507,13 @@
"\t\"Composer\" NVARCHAR(220),\n", "\t\"Composer\" NVARCHAR(220),\n",
"\tPRIMARY KEY (\"TrackId\")\n", "\tPRIMARY KEY (\"TrackId\")\n",
")\n", ")\n",
"\n", "/*\n",
"SELECT * FROM 'Track' LIMIT 3;\n", "3 rows from Track table:\n",
"TrackId\tName\tComposer\n", "TrackId\tName\tComposer\n",
"1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n", "1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n",
"2\tBalls to the Wall\tNone\n", "2\tBalls to the Wall\tNone\n",
"3\tMy favorite song ever\tThe coolest composer of all time\n" "3\tMy favorite song ever\tThe coolest composer of all time\n",
"*/\n"
] ]
} }
], ],

@ -126,12 +126,6 @@ class SQLDatabase:
# build the select command # build the select command
command = select(table).limit(self._sample_rows_in_table_info) command = select(table).limit(self._sample_rows_in_table_info)
# save the command in string format
select_star = (
f"SELECT * FROM '{table.name}' LIMIT "
f"{self._sample_rows_in_table_info}"
)
# save the columns in string format # save the columns in string format
columns_str = "\t".join([col.name for col in table.columns]) columns_str = "\t".join([col.name for col in table.columns])
@ -152,16 +146,18 @@ class SQLDatabase:
except ProgrammingError: except ProgrammingError:
sample_rows_str = "" sample_rows_str = ""
# build final info for table table_info = (
tables.append( f"{create_table.rstrip()}\n"
create_table f"/*\n"
+ select_star f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
+ ";\n" f"{columns_str}\n"
+ columns_str f"{sample_rows_str}\n"
+ "\n" f"*/"
+ sample_rows_str
) )
# build final info for table
tables.append(table_info)
else: else:
tables.append(create_table) tables.append(create_table)

@ -34,9 +34,10 @@ def test_table_info() -> None:
user_name VARCHAR(16) NOT NULL, user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id) PRIMARY KEY (user_id)
) )
/*
SELECT * FROM 'user' LIMIT 3; 3 rows from user table:
user_id user_name user_id user_name
/*
CREATE TABLE company ( CREATE TABLE company (
@ -44,9 +45,10 @@ def test_table_info() -> None:
company_location VARCHAR NOT NULL, company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id) PRIMARY KEY (company_id)
) )
/*
SELECT * FROM 'company' LIMIT 3; 3 rows from company table:
company_id company_location company_id company_location
*/
""" """
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
@ -74,21 +76,22 @@ def test_table_info_w_sample_rows() -> None:
company_location VARCHAR NOT NULL, company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id) PRIMARY KEY (company_id)
) )
/*
SELECT * FROM 'company' LIMIT 2; 2 rows from company table:
company_id company_location company_id company_location
*/
CREATE TABLE user ( CREATE TABLE user (
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL, user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id) PRIMARY KEY (user_id)
) )
/*
SELECT * FROM 'user' LIMIT 2; 2 rows from user table:
user_id user_name user_id user_name
13 Harrison 13 Harrison
14 Chase 14 Chase
*/
""" """
assert sorted(output.split()) == sorted(expected_output.split()) assert sorted(output.split()) == sorted(expected_output.split())

@ -54,9 +54,10 @@ def test_table_info() -> None:
user_name VARCHAR NOT NULL, user_name VARCHAR NOT NULL,
PRIMARY KEY (user_id) PRIMARY KEY (user_id)
) )
/*
SELECT * FROM 'user' LIMIT 3; 3 rows from user table:
user_id user_name user_id user_name
*/
""" """
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))

Loading…
Cancel
Save