From 0a1b1806e91ce74b09ca96db01ab64b81f8d75bd Mon Sep 17 00:00:00 2001 From: Jon Luo <20971593+jzluo@users.noreply.github.com> Date: Tue, 14 Mar 2023 02:08:27 -0400 Subject: [PATCH] sql: do not hard code the LIMIT clause in the table_info section (#1563) Seeing a lot of issues in Discord in which the LLM is not using the correct LIMIT clause for different SQL dialects, i.e., it's using `LIMIT` for mssql instead of `TOP`, or instead of `ROWNUM` for Oracle, etc. I think this could be due to us specifying the LIMIT statement in the example rows portion of `table_info`. So the LLM is seeing the `LIMIT` statement used in the prompt. Since we can't specify each dialect's method here, I think it's fine to just replace the `SELECT... LIMIT 3;` statement with `3 rows from table_name table:`, and wrap everything in a block comment directly following the `CREATE` statement. The Rajkumar et al. paper wrapped the example rows and `SELECT` statement in a block comment as well anyway. Thoughts @fpingham? 
--- docs/modules/chains/examples/sqlite.ipynb | 28 +++++++++++--------- langchain/sql_database.py | 24 +++++++---------- tests/unit_tests/test_sql_database.py | 21 ++++++++------- tests/unit_tests/test_sql_database_schema.py | 5 ++-- 4 files changed, 41 insertions(+), 37 deletions(-) diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index c85d3f10..2f614c8c 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -377,18 +377,19 @@ "\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n", "\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n", ")\n", - "\n", - "SELECT * FROM 'Track' LIMIT 2;\n", + "/*\n", + "2 rows from Track table:\n", "TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n", "1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n", - "2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n" + "2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n", + "*/\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "/home/jon/projects/langchain/langchain/sql_database.py:121: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n", + "/home/jon/projects/langchain/langchain/sql_database.py:135: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. 
Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n", " sample_rows = connection.execute(command)\n" ] } @@ -467,12 +468,13 @@ "\t\"Composer\" NVARCHAR(220),\n", "\tPRIMARY KEY (\"TrackId\")\n", ")\n", - "\n", - "SELECT * FROM 'Track' LIMIT 3;\n", + "/*\n", + "3 rows from Track table:\n", "TrackId\tName\tComposer\n", "1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n", "2\tBalls to the Wall\tNone\n", - "3\tMy favorite song ever\tThe coolest composer of all time\"\"\"\n", + "3\tMy favorite song ever\tThe coolest composer of all time\n", + "*/\"\"\"\n", "}" ] }, @@ -492,11 +494,12 @@ "\t\"Name\" NVARCHAR(120), \n", "\tPRIMARY KEY (\"PlaylistId\")\n", ")\n", - "\n", - "SELECT * FROM 'Playlist' LIMIT 2;\n", + "/*\n", + "2 rows from Playlist table:\n", "PlaylistId\tName\n", "1\tMusic\n", "2\tMovies\n", + "*/\n", "\n", "CREATE TABLE Track (\n", "\t\"TrackId\" INTEGER NOT NULL, \n", @@ -504,12 +507,13 @@ "\t\"Composer\" NVARCHAR(220),\n", "\tPRIMARY KEY (\"TrackId\")\n", ")\n", - "\n", - "SELECT * FROM 'Track' LIMIT 3;\n", + "/*\n", + "3 rows from Track table:\n", "TrackId\tName\tComposer\n", "1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n", "2\tBalls to the Wall\tNone\n", - "3\tMy favorite song ever\tThe coolest composer of all time\n" + "3\tMy favorite song ever\tThe coolest composer of all time\n", + "*/\n" ] } ], diff --git a/langchain/sql_database.py b/langchain/sql_database.py index 0537b73d..18f76bd2 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -126,12 +126,6 @@ class SQLDatabase: # build the select command command = select(table).limit(self._sample_rows_in_table_info) - # save the command in string format - select_star = ( - f"SELECT * FROM '{table.name}' LIMIT " - f"{self._sample_rows_in_table_info}" - ) - # save the columns in string format columns_str = "\t".join([col.name for col in table.columns]) @@ 
-152,16 +146,18 @@ class SQLDatabase: except ProgrammingError: sample_rows_str = "" - # build final info for table - tables.append( - create_table - + select_star - + ";\n" - + columns_str - + "\n" - + sample_rows_str + table_info = ( + f"{create_table.rstrip()}\n" + f"/*\n" + f"{self._sample_rows_in_table_info} rows from {table.name} table:\n" + f"{columns_str}\n" + f"{sample_rows_str}\n" + f"*/" ) + # build final info for table + tables.append(table_info) + else: tables.append(create_table) diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py index c503c7fb..3da40c5a 100644 --- a/tests/unit_tests/test_sql_database.py +++ b/tests/unit_tests/test_sql_database.py @@ -34,9 +34,10 @@ def test_table_info() -> None: user_name VARCHAR(16) NOT NULL, PRIMARY KEY (user_id) ) - - SELECT * FROM 'user' LIMIT 3; + /* + 3 rows from user table: user_id user_name + /* CREATE TABLE company ( @@ -44,9 +45,10 @@ def test_table_info() -> None: company_location VARCHAR NOT NULL, PRIMARY KEY (company_id) ) - - SELECT * FROM 'company' LIMIT 3; + /* + 3 rows from company table: company_id company_location + */ """ assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) @@ -74,21 +76,22 @@ def test_table_info_w_sample_rows() -> None: company_location VARCHAR NOT NULL, PRIMARY KEY (company_id) ) - - SELECT * FROM 'company' LIMIT 2; + /* + 2 rows from company table: company_id company_location - + */ CREATE TABLE user ( user_id INTEGER NOT NULL, user_name VARCHAR(16) NOT NULL, PRIMARY KEY (user_id) ) - - SELECT * FROM 'user' LIMIT 2; + /* + 2 rows from user table: user_id user_name 13 Harrison 14 Chase + */ """ assert sorted(output.split()) == sorted(expected_output.split()) diff --git a/tests/unit_tests/test_sql_database_schema.py b/tests/unit_tests/test_sql_database_schema.py index 6251b098..58a0ea37 100644 --- a/tests/unit_tests/test_sql_database_schema.py +++ b/tests/unit_tests/test_sql_database_schema.py @@ -54,9 +54,10 
@@ def test_table_info() -> None: user_name VARCHAR NOT NULL, PRIMARY KEY (user_id) ) - - SELECT * FROM 'user' LIMIT 3; + /* + 3 rows from user table: user_id user_name + */ """ assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))