SQLAlchemy commands used in table info (#1135)

This approach has several advantages:

* improves the readability of the code
* removes incompatibilities between SQL dialects
* fixes a bug with `datetime` values in rows and `ast.literal_eval`

Huge thanks and credits to @jzluo for finding the weaknesses in the
current approach and for the thoughtful discussion on the best way to
implement this.

---------

Co-authored-by: Francisco Ingham <>
Co-authored-by: Jon Luo <20971593+jzluo@users.noreply.github.com>
This commit is contained in:
Francisco Ingham 2023-02-18 15:58:29 -03:00 committed by GitHub
parent 483821ea3b
commit 3f29742adc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 63 additions and 60 deletions

View File

@ -1,23 +1,12 @@
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations
import ast
from typing import Any, Iterable, List, Optional
from sqlalchemy import create_engine, inspect
from sqlalchemy import MetaData, create_engine, inspect, select
from sqlalchemy.engine import Engine
_TEMPLATE_PREFIX = """Table data will be described in the following format:
Table 'table name' has columns: {
column1 name: (column1 type, [list of example values for column1]),
column2 name: (column2 type, [list of example values for column2]),
...
}
These are the tables you can use, together with their column information:
"""
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.schema import CreateTable
class SQLDatabase:
@ -27,6 +16,7 @@ class SQLDatabase:
self,
engine: Engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
@ -53,8 +43,15 @@ class SQLDatabase:
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
if not isinstance(sample_rows_in_table_info, int):
raise TypeError("sample_rows_in_table_info must be an integer")
self._sample_rows_in_table_info = sample_rows_in_table_info
self._metadata = metadata or MetaData()
self._metadata.reflect(bind=self._engine)
@classmethod
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
"""Construct a SQLAlchemy engine from URI."""
@ -93,52 +90,53 @@ class SQLDatabase:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
tables = []
for table_name in all_table_names:
columns = []
if self.dialect in ("sqlite", "duckdb"):
create_table = self.run(
(
"SELECT sql FROM sqlite_master WHERE "
f"type='table' AND name='{table_name}'"
),
fetch="one",
)
else:
create_table = self.run(
f"SHOW CREATE TABLE `{table_name}`;",
)
meta_tables = [
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
]
for column in self._inspector.get_columns(table_name, schema=self._schema):
columns.append(column["name"])
tables = []
for table in meta_tables:
# add create table command
create_table = str(CreateTable(table).compile(self._engine))
if self._sample_rows_in_table_info:
if self.dialect in ("sqlite", "duckdb"):
select_star = (
f"SELECT * FROM '{table_name}' LIMIT "
f"{self._sample_rows_in_table_info}"
)
else:
select_star = (
f"SELECT * FROM `{table_name}` LIMIT "
f"{self._sample_rows_in_table_info}"
)
# build the select command
command = select(table).limit(self._sample_rows_in_table_info)
sample_rows = self.run(select_star)
sample_rows_ls = ast.literal_eval(sample_rows)
sample_rows_ls = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
# save the command in string format
select_star = (
f"SELECT * FROM '{table.name}' LIMIT "
f"{self._sample_rows_in_table_info}"
)
columns_str = " ".join(columns)
sample_rows_str = "\n".join([" ".join(row) for row in sample_rows_ls])
# save the columns in string format
columns_str = " ".join([col.name for col in table.columns])
# get the sample rows
with self._engine.connect() as connection:
sample_rows = connection.execute(command)
try:
# shorten values in the sample rows
sample_rows = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows)
)
# save the sample rows in string format
sample_rows_str = "\n".join([" ".join(row) for row in sample_rows])
# in some dialects, when there are no rows in the table, a
# 'ProgrammingError' is raised
except ProgrammingError:
sample_rows_str = ""
# build final info for table
tables.append(
create_table
+ "\n\n"
+ select_star
+ "\n"
+ ";\n"
+ columns_str
+ "\n"
+ sample_rows_str
@ -147,7 +145,7 @@ class SQLDatabase:
else:
tables.append(create_table)
final_str = "\n\n\n".join(tables)
final_str = "\n\n".join(tables)
return final_str
def run(self, command: str, fetch: str = "all") -> str:

View File

@ -3,7 +3,7 @@
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase
from langchain.sql_database import SQLDatabase
metadata_obj = MetaData()
@ -35,7 +35,7 @@ def test_table_info() -> None:
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 3
SELECT * FROM 'user' LIMIT 3;
user_id user_name
@ -45,7 +45,7 @@ def test_table_info() -> None:
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 3
SELECT * FROM 'company' LIMIT 3;
company_id company_location
"""
@ -75,7 +75,7 @@ def test_table_info_w_sample_rows() -> None:
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 2
SELECT * FROM 'company' LIMIT 2;
company_id company_location
@ -85,7 +85,7 @@ def test_table_info_w_sample_rows() -> None:
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 2
SELECT * FROM 'user' LIMIT 2;
user_id user_name
13 Harrison
14 Chase

View File

@ -45,12 +45,17 @@ def test_table_info() -> None:
"""Test that table info is constructed properly."""
engine = create_engine("duckdb:///:memory:")
metadata_obj.create_all(engine)
db = SQLDatabase(engine, schema="schema_a")
db = SQLDatabase(engine, schema="schema_a", metadata=metadata_obj)
output = db.table_info
expected_output = """
CREATE TABLE schema_a."user"(user_id INTEGER, user_name VARCHAR NOT NULL, PRIMARY KEY(user_id));
CREATE TABLE schema_a."user" (
user_id INTEGER NOT NULL,
user_name VARCHAR NOT NULL,
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 3
SELECT * FROM 'user' LIMIT 3;
user_id user_name
"""