From 3f29742adca46b4ed520669a12386542ca2c50b3 Mon Sep 17 00:00:00 2001 From: Francisco Ingham Date: Sat, 18 Feb 2023 15:58:29 -0300 Subject: [PATCH] Sql alchemy commands used in table info (#1135) This approach has several advantages: * it improves the readability of the code * removes incompatibilities between SQL dialects * fixes a bug with `datetime` values in rows and `ast.literal_eval` Huge thanks and credits to @jzluo for finding the weaknesses in the current approach and for the thoughtful discussion on the best way to implement this. --------- Co-authored-by: Francisco Ingham <> Co-authored-by: Jon Luo <20971593+jzluo@users.noreply.github.com> --- langchain/sql_database.py | 98 ++++++++++---------- tests/unit_tests/test_sql_database.py | 12 +-- tests/unit_tests/test_sql_database_schema.py | 11 ++- 3 files changed, 62 insertions(+), 59 deletions(-) diff --git a/langchain/sql_database.py b/langchain/sql_database.py index c23d6ab6..aedde352 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -1,23 +1,12 @@ """SQLAlchemy wrapper around a database.""" from __future__ import annotations -import ast from typing import Any, Iterable, List, Optional -from sqlalchemy import create_engine, inspect +from sqlalchemy import MetaData, create_engine, inspect, select from sqlalchemy.engine import Engine - -_TEMPLATE_PREFIX = """Table data will be described in the following format: - -Table 'table name' has columns: { -column1 name: (column1 type, [list of example values for column1]), -column2 name: (column2 type, [list of example values for column2]), -... 
-} - -These are the tables you can use, together with their column information: - -""" +from sqlalchemy.exc import ProgrammingError +from sqlalchemy.schema import CreateTable class SQLDatabase: @@ -27,6 +16,7 @@ class SQLDatabase: self, engine: Engine, schema: Optional[str] = None, + metadata: Optional[MetaData] = None, ignore_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None, sample_rows_in_table_info: int = 3, @@ -53,8 +43,15 @@ class SQLDatabase: raise ValueError( f"ignore_tables {missing_tables} not found in database" ) + + if not isinstance(sample_rows_in_table_info, int): + raise TypeError("sample_rows_in_table_info must be an integer") + self._sample_rows_in_table_info = sample_rows_in_table_info + self._metadata = metadata or MetaData() + self._metadata.reflect(bind=self._engine) + @classmethod def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase: """Construct a SQLAlchemy engine from URI.""" @@ -93,52 +90,53 @@ class SQLDatabase: raise ValueError(f"table_names {missing_tables} not found in database") all_table_names = table_names + meta_tables = [ + tbl + for tbl in self._metadata.sorted_tables + if tbl.name in set(all_table_names) + ] + tables = [] - for table_name in all_table_names: - columns = [] - if self.dialect in ("sqlite", "duckdb"): - create_table = self.run( - ( - "SELECT sql FROM sqlite_master WHERE " - f"type='table' AND name='{table_name}'" - ), - fetch="one", - ) - else: - create_table = self.run( - f"SHOW CREATE TABLE `{table_name}`;", + for table in meta_tables: + # add create table command + create_table = str(CreateTable(table).compile(self._engine)) + + if self._sample_rows_in_table_info: + # build the select command + command = select(table).limit(self._sample_rows_in_table_info) + + # save the command in string format + select_star = ( + f"SELECT * FROM '{table.name}' LIMIT " + f"{self._sample_rows_in_table_info}" ) - for column in self._inspector.get_columns(table_name, schema=self._schema): 
- columns.append(column["name"]) + # save the columns in string format + columns_str = " ".join([col.name for col in table.columns]) - if self._sample_rows_in_table_info: - if self.dialect in ("sqlite", "duckdb"): - select_star = ( - f"SELECT * FROM '{table_name}' LIMIT " - f"{self._sample_rows_in_table_info}" - ) - else: - select_star = ( - f"SELECT * FROM `{table_name}` LIMIT " - f"{self._sample_rows_in_table_info}" + # get the sample rows + with self._engine.connect() as connection: + sample_rows = connection.execute(command) + + try: + # shorten values in the sample rows + sample_rows = list( + map(lambda ls: [str(i)[:100] for i in ls], sample_rows) ) - sample_rows = self.run(select_star) + # save the sample rows in string format + sample_rows_str = "\n".join([" ".join(row) for row in sample_rows]) - sample_rows_ls = ast.literal_eval(sample_rows) - sample_rows_ls = list( - map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls) - ) - - columns_str = " ".join(columns) - sample_rows_str = "\n".join([" ".join(row) for row in sample_rows_ls]) + # in some dialects when there are no rows in the table a + # 'ProgrammingError' is raised + except ProgrammingError: + sample_rows_str = "" + # build final info for table tables.append( create_table - + "\n\n" + select_star - + "\n" + + ";\n" + columns_str + "\n" + sample_rows_str @@ -147,7 +145,7 @@ class SQLDatabase: else: tables.append(create_table) - final_str = "\n\n\n".join(tables) + final_str = "\n\n".join(tables) return final_str def run(self, command: str, fetch: str = "all") -> str: diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py index 9e696624..c503c7fb 100644 --- a/tests/unit_tests/test_sql_database.py +++ b/tests/unit_tests/test_sql_database.py @@ -3,7 +3,7 @@ from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert -from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase +from langchain.sql_database import SQLDatabase 
metadata_obj = MetaData() @@ -29,13 +29,13 @@ def test_table_info() -> None: db = SQLDatabase(engine) output = db.table_info expected_output = """ - CREATE TABLE user ( + CREATE TABLE user ( user_id INTEGER NOT NULL, user_name VARCHAR(16) NOT NULL, PRIMARY KEY (user_id) ) - SELECT * FROM 'user' LIMIT 3 + SELECT * FROM 'user' LIMIT 3; user_id user_name @@ -45,7 +45,7 @@ def test_table_info() -> None: PRIMARY KEY (company_id) ) - SELECT * FROM 'company' LIMIT 3 + SELECT * FROM 'company' LIMIT 3; company_id company_location """ @@ -75,7 +75,7 @@ def test_table_info_w_sample_rows() -> None: PRIMARY KEY (company_id) ) - SELECT * FROM 'company' LIMIT 2 + SELECT * FROM 'company' LIMIT 2; company_id company_location @@ -85,7 +85,7 @@ def test_table_info_w_sample_rows() -> None: PRIMARY KEY (user_id) ) - SELECT * FROM 'user' LIMIT 2 + SELECT * FROM 'user' LIMIT 2; user_id user_name 13 Harrison 14 Chase diff --git a/tests/unit_tests/test_sql_database_schema.py b/tests/unit_tests/test_sql_database_schema.py index 3c8b0b73..6251b098 100644 --- a/tests/unit_tests/test_sql_database_schema.py +++ b/tests/unit_tests/test_sql_database_schema.py @@ -45,12 +45,17 @@ def test_table_info() -> None: """Test that table info is constructed properly.""" engine = create_engine("duckdb:///:memory:") metadata_obj.create_all(engine) - db = SQLDatabase(engine, schema="schema_a") + + db = SQLDatabase(engine, schema="schema_a", metadata=metadata_obj) output = db.table_info expected_output = """ - CREATE TABLE schema_a."user"(user_id INTEGER, user_name VARCHAR NOT NULL, PRIMARY KEY(user_id)); + CREATE TABLE schema_a."user" ( + user_id INTEGER NOT NULL, + user_name VARCHAR NOT NULL, + PRIMARY KEY (user_id) + ) - SELECT * FROM 'user' LIMIT 3 + SELECT * FROM 'user' LIMIT 3; user_id user_name """