diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index 6f13f805..e49abbeb 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -287,14 +287,14 @@ "What are some example tracks by composer Johann Sebastian Bach? \n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3;\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach')]\u001b[0m\n", - "Answer:\u001b[32;1m\u001b[1;3m Examples of tracks by Johann Sebastian Bach include 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'.\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m Examples of tracks by composer Johann Sebastian Bach are 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "' Examples of tracks by Johann Sebastian Bach include \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', and \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\'.'" + "' Examples of tracks by composer Johann Sebastian Bach are \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', and \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\'.'" ] }, "execution_count": 11, @@ -317,13 +317,13 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "9a22ee47", "metadata": {}, "outputs": [], "source": [ "db = SQLDatabase.from_uri(\n", - " \"sqlite:///../../../../notebooks/Chinook.db\", \n", + " \"sqlite:///../../../../notebooks/Chinook.db\",\n", " include_tables=['Track'], # we include only one table to save tokens in the prompt :)\n", " sample_rows_in_table_info=2)" ] @@ -338,7 +338,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "9de86267", "metadata": {}, "outputs": [ @@ -346,9 +346,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example of 2 rows from this table (long strings are truncated):\n", - "1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n", - "2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n" + "\n", + " Table data will be described in the following format:\n", + "\n", + " Table 'table name' has columns: {column1 name: (column1 type, [list of example values for column1]),\n", + " column2 name: (column2 type, [list of example values for column2], ...)\n", + "\n", + " These are the tables you can use, together with their column information:\n", + "\n", + " Table 'Track' has columns: {'TrackId': ['INTEGER', ['1', '2']], 'Name': ['NVARCHAR(200)', ['For Those About To Rock (We Salute You)', 'Balls to the Wall']], 'AlbumId': ['INTEGER', ['1', '2']], 'MediaTypeId': ['INTEGER', ['1', '2']], 'GenreId': ['INTEGER', ['1', '1']], 'Composer': ['NVARCHAR(220)', ['Angus Young, Malcolm Young, Brian Johnson', 'None']], 'Milliseconds': ['INTEGER', ['343719', '342562']], 'Bytes': ['INTEGER', ['11170334', '5510424']], 'UnitPrice': ['NUMERIC(10, 2)', ['0.99', '0.99']]}\n" ] } ], @@ -358,7 +364,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "bcb7a489", "metadata": {}, "outputs": [], @@ -368,7 +374,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "id": "81e05d82", "metadata": {}, "outputs": [ @@ -380,8 +386,8 @@ "\n", "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", "What are some example tracks by Bach? \n", - "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer LIKE '%Bach%' LIMIT 5;\u001b[0m\n", - "SQLResult: \u001b[33;1m\u001b[1;3m[('American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach'), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata', 'Johann Sebastian Bach')]\u001b[0m\n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name FROM Track WHERE Composer LIKE '%Bach%' LIMIT 5;\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[('American Woman',), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata',)]\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', and 'Toccata and Fugue in D Minor, BWV 565: I. Toccata'.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -392,7 +398,7 @@ "' Some example tracks by Bach are \\'American Woman\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\', and \\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\\'.'" ] }, - "execution_count": 15, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -500,7 +506,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/langchain/chains/sql_database/prompt.py b/langchain/chains/sql_database/prompt.py index 2b2a97cf..127579e2 100644 --- a/langchain/chains/sql_database/prompt.py +++ b/langchain/chains/sql_database/prompt.py @@ -15,7 +15,7 @@ SQLQuery: "SQL Query to run" SQLResult: "Result of the SQLQuery" Answer: "Final answer here" -Only use the following tables: +Only use the tables listed below. {table_info} diff --git a/langchain/sql_database.py b/langchain/sql_database.py index a03147d9..a9ce0c4f 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -1,11 +1,25 @@ """SQLAlchemy wrapper around a database.""" from __future__ import annotations +import ast +from collections import defaultdict from typing import Any, Iterable, List, Optional from sqlalchemy import create_engine, inspect from sqlalchemy.engine import Engine +_TEMPLATE_PREFIX = """Table data will be described in the following format: + +Table 'table name' has columns: { +column1 name: (column1 type, [list of example values for column1]), +column2 name: (column2 type, [list of example values for column2]), +... +} + +These are the tables you can use, together with their column information: + +""" + class SQLDatabase: """SQLAlchemy wrapper around a database.""" @@ -77,38 +91,33 @@ class SQLDatabase: raise ValueError(f"table_names {missing_tables} not found in database") all_table_names = table_names - template = "Table '{table_name}' has columns: {columns}." - tables = [] for table_name in all_table_names: - columns = [] + columns = defaultdict(list) for column in self._inspector.get_columns(table_name, schema=self._schema): - columns.append(f"{column['name']} ({str(column['type'])})") - column_str = ", ".join(columns) - table_str = template.format(table_name=table_name, columns=column_str) + columns[f"{column['name']}"].append(str(column["type"])) if self._sample_rows_in_table_info: - row_template = ( - " Here is an example of {n_rows} rows from this table " - "(long strings are truncated):\n" - "{sample_rows}" - ) sample_rows = self.run( f"SELECT * FROM '{table_name}' LIMIT " f"{self._sample_rows_in_table_info}" ) - sample_rows = eval(sample_rows) - if len(sample_rows) > 0: - n_rows = len(sample_rows) - sample_rows = "\n".join( - [" ".join([str(i)[:100] for i in row]) for row in sample_rows] - ) - table_str += row_template.format( - n_rows=n_rows, sample_rows=sample_rows + + sample_rows_ls = ast.literal_eval(sample_rows) + sample_rows_ls = list( + map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls) + ) + + for e, col in enumerate(columns): + columns[col].append( + [row[e] for row in sample_rows_ls] # type: ignore ) + table_str = f"Table '{table_name}' has columns: " + str(dict(columns)) tables.append(table_str) - return "\n".join(tables) + + final_str = _TEMPLATE_PREFIX + "\n".join(tables) + return final_str def run(self, command: str) -> str: """Execute a SQL command and return a string representing the results. diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py index a6d21549..96ac6d51 100644 --- a/tests/unit_tests/test_sql_database.py +++ b/tests/unit_tests/test_sql_database.py @@ -1,8 +1,9 @@ +# flake8: noqa=E501 """Test SQL database wrapper.""" from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert -from langchain.sql_database import SQLDatabase +from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase metadata_obj = MetaData() @@ -27,10 +28,10 @@ def test_table_info() -> None: metadata_obj.create_all(engine) db = SQLDatabase(engine) output = db.table_info + output = output[len(_TEMPLATE_PREFIX) :] expected_output = ( - "Table 'company' has columns: company_id (INTEGER), " - "company_location (VARCHAR).", - "Table 'user' has columns: user_id (INTEGER), user_name (VARCHAR(16)).", + "Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR(16)']}", + "Table 'company' has columns: {'company_id': ['INTEGER'], 'company_location': ['VARCHAR']}", ) assert sorted(output.split("\n")) == sorted(expected_output) @@ -50,14 +51,12 @@ def test_table_info_w_sample_rows() -> None: db = SQLDatabase(engine, sample_rows_in_table_info=2) output = db.table_info + output = output[len(_TEMPLATE_PREFIX) :] expected_output = ( - "Table 'company' has columns: company_id (INTEGER), " - "company_location (VARCHAR).\n" - "Table 'user' has columns: user_id (INTEGER), " - "user_name (VARCHAR(16)). Here is an example of 2 rows " - "from this table (long strings are truncated):\n13 Harrison\n14 Chase" + "Table 'user' has columns: {'user_id': ['INTEGER', ['13', '14']], 'user_name': ['VARCHAR(16)', ['Harrison', 'Chase']]}", + "Table 'company' has columns: {'company_id': ['INTEGER', []], 'company_location': ['VARCHAR', []]}", ) - assert sorted(output.split("\n")) == sorted(expected_output.split("\n")) + assert sorted(output.split("\n")) == sorted(expected_output) def test_sql_database_run() -> None: diff --git a/tests/unit_tests/test_sql_database_schema.py b/tests/unit_tests/test_sql_database_schema.py index 6b15e600..3173ce09 100644 --- a/tests/unit_tests/test_sql_database_schema.py +++ b/tests/unit_tests/test_sql_database_schema.py @@ -16,7 +16,7 @@ from sqlalchemy import ( schema, ) -from langchain.sql_database import SQLDatabase +from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase metadata_obj = MetaData() @@ -46,10 +46,11 @@ def test_table_info() -> None: metadata_obj.create_all(engine) db = SQLDatabase(engine, schema="schema_a") output = db.table_info + output = output[len(_TEMPLATE_PREFIX) :] expected_output = ( - "Table 'user' has columns: user_id (INTEGER), user_name (VARCHAR).", + "Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR']}" ) - assert sorted(output.split("\n")) == sorted(expected_output) + assert output == expected_output def test_sql_database_run() -> None: