From 5e10e19bfeafe2e813ee6ac6c9acfef9028d9680 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 15 Feb 2023 23:53:37 -0800 Subject: [PATCH] Harrison/align table (#1081) Co-authored-by: Francisco Ingham --- docs/modules/chains/examples/sqlite.ipynb | 51 ++++++++++++++---- langchain/sql_database.py | 55 +++++++++++++------ tests/unit_tests/test_sql_database.py | 57 ++++++++++++++++---- tests/unit_tests/test_sql_database_schema.py | 16 +++--- 4 files changed, 137 insertions(+), 42 deletions(-) diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index e49abbebd0..909b70278d 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "d0e27d88", "metadata": { "pycharm": { @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "72ede462", "metadata": { "pycharm": { @@ -346,15 +346,46 @@ "name": "stdout", "output_type": "stream", "text": [ + "CREATE TABLE [Album]\n", + "(\n", + " [AlbumId] INTEGER NOT NULL,\n", + " [Title] NVARCHAR(160) NOT NULL,\n", + " [ArtistId] INTEGER NOT NULL,\n", + " CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]),\n", + " FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) \n", + "\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n", + ")\n", "\n", - " Table data will be described in the following format:\n", + "SELECT * FROM 'Album' LIMIT 2\n", + "AlbumId Title ArtistId\n", + "1 For Those About To Rock We Salute You 1\n", + "2 Balls to the Wall 2\n", "\n", - " Table 'table name' has columns: {column1 name: (column1 type, [list of example values for column1]),\n", - " column2 name: (column2 type, [list of example values for column2], ...)\n", "\n", - " These are the tables you can use, together with their column information:\n", + "CREATE TABLE [Track]\n", + "(\n", + " [TrackId] INTEGER NOT NULL,\n", + " [Name] NVARCHAR(200) NOT NULL,\n", + " [AlbumId] INTEGER,\n", + " [MediaTypeId] INTEGER NOT NULL,\n", + " [GenreId] INTEGER,\n", + " [Composer] NVARCHAR(220),\n", + " [Milliseconds] INTEGER NOT NULL,\n", + " [Bytes] INTEGER,\n", + " [UnitPrice] NUMERIC(10,2) NOT NULL,\n", + " CONSTRAINT [PK_Track] PRIMARY KEY ([TrackId]),\n", + " FOREIGN KEY ([AlbumId]) REFERENCES [Album] ([AlbumId]) \n", + "\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n", + " FOREIGN KEY ([GenreId]) REFERENCES [Genre] ([GenreId]) \n", + "\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n", + " FOREIGN KEY ([MediaTypeId]) REFERENCES [MediaType] ([MediaTypeId]) \n", + "\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n", + ")\n", "\n", - " Table 'Track' has columns: {'TrackId': ['INTEGER', ['1', '2']], 'Name': ['NVARCHAR(200)', ['For Those About To Rock (We Salute You)', 'Balls to the Wall']], 'AlbumId': ['INTEGER', ['1', '2']], 'MediaTypeId': ['INTEGER', ['1', '2']], 'GenreId': ['INTEGER', ['1', '1']], 'Composer': ['NVARCHAR(220)', ['Angus Young, Malcolm Young, Brian Johnson', 'None']], 'Milliseconds': ['INTEGER', ['343719', '342562']], 'Bytes': ['INTEGER', ['11170334', '5510424']], 'UnitPrice': ['NUMERIC(10, 2)', ['0.99', '0.99']]}\n" + "SELECT * FROM 'Track' LIMIT 2\n", + "TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice\n", + "1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n", + "2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n" ] } ], @@ -492,9 +523,9 @@ "lastKernelId": null }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "langchain", "language": "python", - "name": "python3" + "name": "langchain" }, "language_info": { "codemirror_mode": { @@ -506,7 +537,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/langchain/sql_database.py b/langchain/sql_database.py index a9ce0c4ff2..078c527463 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -2,7 +2,6 @@ from __future__ import annotations import ast -from collections import defaultdict from typing import Any, Iterable, List, Optional from sqlalchemy import create_engine, inspect @@ -30,7 +29,7 @@ class SQLDatabase: schema: Optional[str] = None, ignore_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None, - sample_rows_in_table_info: int = 0, + sample_rows_in_table_info: int = 3, ): """Create engine from database URI.""" self._engine = engine @@ -80,9 +79,12 @@ class SQLDatabase: def get_table_info(self, table_names: Optional[List[str]] = None) -> str: """Get information about specified tables. + Follows best practices as specified in: Rajkumar et al, 2022 + (https://arxiv.org/abs/2204.00498) + If `sample_rows_in_table_info`, the specified number of sample rows will be appended to each table description. This can increase performance as - demonstrated by Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498). + demonstrated in the paper. """ all_table_names = self.get_table_names() if table_names is not None: @@ -93,33 +95,51 @@ class SQLDatabase: tables = [] for table_name in all_table_names: - columns = defaultdict(list) + columns = [] + create_table = self.run( + ( + "SELECT sql FROM sqlite_master WHERE " + f"type='table' AND name='{table_name}'" + ), + fetch="one", + ) + for column in self._inspector.get_columns(table_name, schema=self._schema): - columns[f"{column['name']}"].append(str(column["type"])) + columns.append(column["name"]) if self._sample_rows_in_table_info: - sample_rows = self.run( + select_star = ( f"SELECT * FROM '{table_name}' LIMIT " f"{self._sample_rows_in_table_info}" ) + sample_rows = self.run(select_star) + sample_rows_ls = ast.literal_eval(sample_rows) sample_rows_ls = list( map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls) ) - for e, col in enumerate(columns): - columns[col].append( - [row[e] for row in sample_rows_ls] # type: ignore - ) + columns_str = " ".join(columns) + sample_rows_str = "\n".join([" ".join(row) for row in sample_rows_ls]) - table_str = f"Table '{table_name}' has columns: " + str(dict(columns)) - tables.append(table_str) + tables.append( + create_table + + "\n\n" + + select_star + + "\n" + + columns_str + + "\n" + + sample_rows_str + ) - final_str = _TEMPLATE_PREFIX + "\n".join(tables) + else: + tables.append(create_table) + + final_str = "\n\n\n".join(tables) return final_str - def run(self, command: str) -> str: + def run(self, command: str, fetch: str = "all") -> str: """Execute a SQL command and return a string representing the results. If the statement returns rows, a string of the results is returned. @@ -130,6 +150,11 @@ class SQLDatabase: connection.exec_driver_sql(f"SET search_path TO {self._schema}") cursor = connection.exec_driver_sql(command) if cursor.returns_rows: - result = cursor.fetchall() + if fetch == "all": + result = cursor.fetchall() + elif fetch == "one": + result = cursor.fetchone()[0] + else: + raise ValueError("Fetch parameter must be either 'one' or 'all'") return str(result) return "" diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py index 96ac6d519b..9e696624b3 100644 --- a/tests/unit_tests/test_sql_database.py +++ b/tests/unit_tests/test_sql_database.py @@ -28,12 +28,28 @@ def test_table_info() -> None: metadata_obj.create_all(engine) db = SQLDatabase(engine) output = db.table_info - output = output[len(_TEMPLATE_PREFIX) :] - expected_output = ( - "Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR(16)']}", - "Table 'company' has columns: {'company_id': ['INTEGER'], 'company_location': ['VARCHAR']}", + expected_output = """ + CREATE TABLE user ( + user_id INTEGER NOT NULL, + user_name VARCHAR(16) NOT NULL, + PRIMARY KEY (user_id) ) - assert sorted(output.split("\n")) == sorted(expected_output) + + SELECT * FROM 'user' LIMIT 3 + user_id user_name + + + CREATE TABLE company ( + company_id INTEGER NOT NULL, + company_location VARCHAR NOT NULL, + PRIMARY KEY (company_id) + ) + + SELECT * FROM 'company' LIMIT 3 + company_id company_location + """ + + assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) def test_table_info_w_sample_rows() -> None: @@ -51,12 +67,31 @@ def test_table_info_w_sample_rows() -> None: db = SQLDatabase(engine, sample_rows_in_table_info=2) output = db.table_info - output = output[len(_TEMPLATE_PREFIX) :] - expected_output = ( - "Table 'user' has columns: {'user_id': ['INTEGER', ['13', '14']], 'user_name': ['VARCHAR(16)', ['Harrison', 'Chase']]}", - "Table 'company' has columns: {'company_id': ['INTEGER', []], 'company_location': ['VARCHAR', []]}", - ) - assert sorted(output.split("\n")) == sorted(expected_output) + + expected_output = """ + CREATE TABLE company ( + company_id INTEGER NOT NULL, + company_location VARCHAR NOT NULL, + PRIMARY KEY (company_id) +) + + SELECT * FROM 'company' LIMIT 2 + company_id company_location + + + CREATE TABLE user ( + user_id INTEGER NOT NULL, + user_name VARCHAR(16) NOT NULL, + PRIMARY KEY (user_id) + ) + + SELECT * FROM 'user' LIMIT 2 + user_id user_name + 13 Harrison + 14 Chase + """ + + assert sorted(output.split()) == sorted(expected_output.split()) def test_sql_database_run() -> None: diff --git a/tests/unit_tests/test_sql_database_schema.py b/tests/unit_tests/test_sql_database_schema.py index 3173ce09b1..3c8b0b7381 100644 --- a/tests/unit_tests/test_sql_database_schema.py +++ b/tests/unit_tests/test_sql_database_schema.py @@ -1,3 +1,4 @@ +# flake8: noqa """Test SQL database wrapper with schema support. Using DuckDB as SQLite does not support schemas. @@ -16,7 +17,7 @@ from sqlalchemy import ( schema, ) -from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase +from langchain.sql_database import SQLDatabase metadata_obj = MetaData() @@ -46,11 +47,14 @@ def test_table_info() -> None: metadata_obj.create_all(engine) db = SQLDatabase(engine, schema="schema_a") output = db.table_info - output = output[len(_TEMPLATE_PREFIX) :] - expected_output = ( - "Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR']}" - ) - assert output == expected_output + expected_output = """ + CREATE TABLE schema_a."user"(user_id INTEGER, user_name VARCHAR NOT NULL, PRIMARY KEY(user_id)); + + SELECT * FROM 'user' LIMIT 3 + user_id user_name + """ + + assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) def test_sql_database_run() -> None: