From f95cedc4434c8f74e4a03b29fac1aafbc48ebb8a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 6 Feb 2023 18:56:18 -0800 Subject: [PATCH] Harrison/sql rows (#915) Co-authored-by: Jon Luo <20971593+jzluo@users.noreply.github.com> --- docs/modules/chains/examples/sqlite.ipynb | 90 +++++++++++++++-------- langchain/sql_database.py | 43 ++++++++--- tests/unit_tests/test_sql_database.py | 14 ++-- 3 files changed, 103 insertions(+), 44 deletions(-) diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index aaea8fbf..6f13f805 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -57,7 +57,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "3d1e692e", "metadata": {}, @@ -94,15 +93,15 @@ "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", "How many employees are there? \n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n", - "SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n", - "Answer:\u001b[32;1m\u001b[1;3m There are 9 employees.\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m There are 8 employees.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "' There are 9 employees.'" + "' There are 8 employees.'" ] }, "execution_count": 4, @@ -177,15 +176,15 @@ "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", "How many employees are there in the foobar table? \n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n", - "SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n", - "Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m There are 8 employees in the foobar table.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "' There are 9 employees in the foobar table.'" + "' There are 8 employees in the foobar table.'" ] }, "execution_count": 7, @@ -219,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "78b6af4d", "metadata": {}, "outputs": [ @@ -232,18 +231,18 @@ "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", "How many employees are there in the foobar table? \n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n", - "SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n", - "Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m There are 8 employees in the foobar table.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "[' SELECT COUNT(*) FROM Employee;', '[(9,)]']" + "[' SELECT COUNT(*) FROM Employee;', '[(8,)]']" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -264,7 +263,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 10, "id": "6adaa799", "metadata": {}, "outputs": [], @@ -274,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "id": "edfc8a8e", "metadata": {}, "outputs": [ @@ -286,8 +285,8 @@ "\n", "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", "What are some example tracks by composer Johann Sebastian Bach? \n", - "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3;\u001b[0m\n", - "SQLResult: \u001b[33;1m\u001b[1;3m[('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',)]\u001b[0m\n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3;\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach')]\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m Examples of tracks by Johann Sebastian Bach include 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -298,7 +297,7 @@ "' Examples of tracks by Johann Sebastian Bach include \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', and \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\'.'" ] }, - "execution_count": 8, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -312,13 +311,13 @@ "id": "bcc5e936", "metadata": {}, "source": [ - "## Adding first row of each table\n", - "Sometimes, the format of the data is not obvious and it is optimal to include the first row of the table in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names." + "## Adding example rows from each table\n", + "Sometimes, the format of the data is not obvious and it is optimal to include a sample of rows from the tables in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names by providing two rows from the `Track` table." ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "9a22ee47", "metadata": {}, "outputs": [], @@ -326,12 +325,40 @@ "db = SQLDatabase.from_uri(\n", " \"sqlite:///../../../../notebooks/Chinook.db\", \n", " include_tables=['Track'], # we include only one table to save tokens in the prompt :)\n", - " sample_row_in_table_info=True)" + " sample_rows_in_table_info=2)" + ] + }, + { + "cell_type": "markdown", + "id": "952c0b4d", + "metadata": {}, + "source": [ + "The sample rows are added to the prompt after each corresponding table's column information:" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, + "id": "9de86267", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example of 2 rows from this table (long strings are truncated):\n", + "1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n", + "2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n" + ] + } + ], + "source": [ + "print(db.table_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, "id": "bcb7a489", "metadata": {}, "outputs": [], @@ -341,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "id": "81e05d82", "metadata": {}, "outputs": [ @@ -353,20 +380,19 @@ "\n", "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", "What are some example tracks by Bach? \n", - "SQLQuery:Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example row for this table (long strings are truncated): ['1', 'For Those About To Rock (We Salute You)', '1', '1', '1', 'Angus Young, Malcolm Young, Brian Johnson', '343719', '11170334', '0.99'].\n", - "\u001b[32;1m\u001b[1;3m SELECT TrackId, Name, Composer FROM Track WHERE Composer LIKE '%Bach%' ORDER BY Name LIMIT 5;\u001b[0m\n", - "SQLResult: \u001b[33;1m\u001b[1;3m[(1709, 'American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), (3408, 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), (3433, 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Johann Sebastian Bach'), (3407, 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), (3490, 'Partita in E Major, BWV 1006A: I. Prelude', 'Johann Sebastian Bach')]\u001b[0m\n", - "Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', and 'Partita in E Major, BWV 1006A: I. Prelude'.\u001b[0m\n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer LIKE '%Bach%' LIMIT 5;\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[('American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach'), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata', 'Johann Sebastian Bach')]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', and 'Toccata and Fugue in D Minor, BWV 565: I. Toccata'.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "' Some example tracks by Bach are \\'American Woman\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Concerto No.2 in F Major, BWV1047, I. Allegro\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', and \\'Partita in E Major, BWV 1006A: I. Prelude\\'.'" + "' Some example tracks by Bach are \\'American Woman\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\', and \\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\\'.'" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -455,6 +481,10 @@ } ], "metadata": { + "@webio": { + "lastCommId": null, + "lastKernelId": null + }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", @@ -470,7 +500,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.2" } }, "nbformat": 4, diff --git a/langchain/sql_database.py b/langchain/sql_database.py index 27fc98a6..8fe071d9 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -16,9 +16,16 @@ class SQLDatabase: schema: Optional[str] = None, ignore_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None, + sample_rows_in_table_info: int = 0, + # TODO: deprecate. sample_row_in_table_info: bool = False, ): """Create engine from database URI.""" + if sample_row_in_table_info and sample_rows_in_table_info > 0: + raise ValueError( + "Only one of `sample_row_in_table_info` " + "and `sample_rows_in_table_info` should be set" + ) self._engine = engine self._schema = schema if include_tables and ignore_tables: @@ -40,7 +47,10 @@ class SQLDatabase: raise ValueError( f"ignore_tables {missing_tables} not found in database" ) - self._sample_row_in_table_info = sample_row_in_table_info + self._sample_rows_in_table_info = sample_rows_in_table_info + # TODO: deprecate + if sample_row_in_table_info: + self._sample_rows_in_table_info = 1 @classmethod def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase: @@ -64,7 +74,12 @@ class SQLDatabase: return self.get_table_info() def get_table_info(self, table_names: Optional[List[str]] = None) -> str: - """Get information about specified tables.""" + """Get information about specified tables. + + If `sample_rows_in_table_info`, the specified number of sample rows will be + appended to each table description. This can increase performance as + demonstrated by Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498). + """ all_table_names = self.get_table_names() if table_names is not None: missing_tables = set(table_names).difference(all_table_names) @@ -83,15 +98,25 @@ class SQLDatabase: column_str = ", ".join(columns) table_str = template.format(table_name=table_name, columns=column_str) - if self._sample_row_in_table_info: + if self._sample_rows_in_table_info: row_template = ( - " Here is an example row for this table" - " (long strings are truncated): {sample_row}." + " Here is an example of {n_rows} rows from this table " + "(long strings are truncated):\n" + "{sample_rows}" + ) + sample_rows = self.run( + f"SELECT * FROM '{table_name}' LIMIT " + f"{self._sample_rows_in_table_info}" ) - sample_row = self.run(f"SELECT * FROM '{table_name}' LIMIT 1") - if len(eval(sample_row)) > 0: - sample_row = " ".join([str(i)[:100] for i in eval(sample_row)[0]]) - table_str += row_template.format(sample_row=sample_row) + sample_rows = eval(sample_rows) + if len(sample_rows) > 0: + n_rows = len(sample_rows) + sample_rows = "\n".join( + [" ".join([str(i)[:100] for i in row]) for row in sample_rows] + ) + table_str += row_template.format( + n_rows=n_rows, sample_rows=sample_rows + ) tables.append(table_str) return "\n".join(tables) diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py index d735d055..a6d21549 100644 --- a/tests/unit_tests/test_sql_database.py +++ b/tests/unit_tests/test_sql_database.py @@ -35,23 +35,27 @@ def test_table_info() -> None: assert sorted(output.split("\n")) == sorted(expected_output) -def test_table_info_w_sample_row() -> None: +def test_table_info_w_sample_rows() -> None: """Test that table info is constructed properly.""" engine = create_engine("sqlite:///:memory:") metadata_obj.create_all(engine) - stmt = insert(user).values(user_id=13, user_name="Harrison") + values = [ + {"user_id": 13, "user_name": "Harrison"}, + {"user_id": 14, "user_name": "Chase"}, + ] + stmt = insert(user).values(values) with engine.begin() as conn: conn.execute(stmt) - db = SQLDatabase(engine, sample_row_in_table_info=True) + db = SQLDatabase(engine, sample_rows_in_table_info=2) output = db.table_info expected_output = ( "Table 'company' has columns: company_id (INTEGER), " "company_location (VARCHAR).\n" "Table 'user' has columns: user_id (INTEGER), " - "user_name (VARCHAR(16)). Here is an example row " - "for this table (long strings are truncated): 13 Harrison." + "user_name (VARCHAR(16)). Here is an example of 2 rows " + "from this table (long strings are truncated):\n13 Harrison\n14 Chase" ) assert sorted(output.split("\n")) == sorted(expected_output.split("\n"))