@ -1,23 +1,12 @@
""" SQLAlchemy wrapper around a database. """
""" SQLAlchemy wrapper around a database. """
from __future__ import annotations
from __future__ import annotations
import ast
from typing import Any , Iterable , List , Optional
from typing import Any , Iterable , List , Optional
from sqlalchemy import create_engine, inspect
from sqlalchemy import MetaData, create_engine, inspect , select
from sqlalchemy . engine import Engine
from sqlalchemy . engine import Engine
from sqlalchemy . exc import ProgrammingError
_TEMPLATE_PREFIX = """ Table data will be described in the following format:
from sqlalchemy . schema import CreateTable
Table ' table name ' has columns : {
column1 name : ( column1 type , [ list of example values for column1 ] ) ,
column2 name : ( column2 type , [ list of example values for column2 ] ) ,
. . .
}
These are the tables you can use , together with their column information :
"""
class SQLDatabase :
class SQLDatabase :
@ -27,6 +16,7 @@ class SQLDatabase:
self ,
self ,
engine : Engine ,
engine : Engine ,
schema : Optional [ str ] = None ,
schema : Optional [ str ] = None ,
metadata : Optional [ MetaData ] = None ,
ignore_tables : Optional [ List [ str ] ] = None ,
ignore_tables : Optional [ List [ str ] ] = None ,
include_tables : Optional [ List [ str ] ] = None ,
include_tables : Optional [ List [ str ] ] = None ,
sample_rows_in_table_info : int = 3 ,
sample_rows_in_table_info : int = 3 ,
@ -53,8 +43,15 @@ class SQLDatabase:
raise ValueError (
raise ValueError (
f " ignore_tables { missing_tables } not found in database "
f " ignore_tables { missing_tables } not found in database "
)
)
if not isinstance ( sample_rows_in_table_info , int ) :
raise TypeError ( " sample_rows_in_table_info must be an integer " )
self . _sample_rows_in_table_info = sample_rows_in_table_info
self . _sample_rows_in_table_info = sample_rows_in_table_info
self . _metadata = metadata or MetaData ( )
self . _metadata . reflect ( bind = self . _engine )
@classmethod
@classmethod
def from_uri ( cls , database_uri : str , * * kwargs : Any ) - > SQLDatabase :
def from_uri ( cls , database_uri : str , * * kwargs : Any ) - > SQLDatabase :
""" Construct a SQLAlchemy engine from URI. """
""" Construct a SQLAlchemy engine from URI. """
@ -93,52 +90,53 @@ class SQLDatabase:
raise ValueError ( f " table_names { missing_tables } not found in database " )
raise ValueError ( f " table_names { missing_tables } not found in database " )
all_table_names = table_names
all_table_names = table_names
meta_tables = [
tbl
for tbl in self . _metadata . sorted_tables
if tbl . name in set ( all_table_names )
]
tables = [ ]
tables = [ ]
for table_name in all_table_names :
for table in meta_tables :
columns = [ ]
# add create table command
if self . dialect in ( " sqlite " , " duckdb " ) :
create_table = str ( CreateTable ( table ) . compile ( self . _engine ) )
create_table = self . run (
(
if self . _sample_rows_in_table_info :
" SELECT sql FROM sqlite_master WHERE "
# build the select command
f " type= ' table ' AND name= ' { table_name } ' "
command = select ( table ) . limit ( self . _sample_rows_in_table_info )
) ,
fetch = " one " ,
# save the command in string format
)
select_star = (
else :
f " SELECT * FROM ' { table . name } ' LIMIT "
create_table = self . run (
f " { self . _sample_rows_in_table_info } "
f " SHOW CREATE TABLE ` { table_name } `; " ,
)
)
for column in self . _inspector . get_columns ( table_name , schema = self . _schema ) :
# save the columns in string format
columns . append ( column [ " name " ] )
columns _str = " " . join ( [ col . name for col in table . columns ] )
if self . _sample_rows_in_table_info :
# get the sample rows
if self . dialect in ( " sqlite " , " duckdb " ) :
with self . _engine . connect ( ) as connection :
select_star = (
sample_rows = connection . execute ( command )
f " SELECT * FROM ' { table_name } ' LIMIT "
f " { self . _sample_rows_in_table_info } "
try :
)
# shorten values in the smaple rows
else :
sample_rows = list (
select_star = (
map ( lambda ls : [ str ( i ) [ : 100 ] for i in ls ] , sample_rows )
f " SELECT * FROM ` { table_name } ` LIMIT "
f " { self . _sample_rows_in_table_info } "
)
)
sample_rows = self . run ( select_star )
# save the sample rows in string format
sample_rows_str = " \n " . join ( [ " " . join ( row ) for row in sample_rows ] )
sample_rows_ls = ast . literal_eval ( sample_rows )
# in some dialects when there are no rows in the table a
sample_rows_ls = list (
# 'ProgrammingError' is returned
map ( lambda ls : [ str ( i ) [ : 100 ] for i in ls ] , sample_rows_ls )
except ProgrammingError :
)
sample_rows_str = " "
columns_str = " " . join ( columns )
sample_rows_str = " \n " . join ( [ " " . join ( row ) for row in sample_rows_ls ] )
# build final info for table
tables . append (
tables . append (
create_table
create_table
+ " \n \n "
+ select_star
+ select_star
+ " \n "
+ " ; \n "
+ columns_str
+ columns_str
+ " \n "
+ " \n "
+ sample_rows_str
+ sample_rows_str
@ -147,7 +145,7 @@ class SQLDatabase:
else :
else :
tables . append ( create_table )
tables . append ( create_table )
final_str = " \n \n \n " . join ( tables )
final_str = " \n \n " . join ( tables )
return final_str
return final_str
def run ( self , command : str , fetch : str = " all " ) - > str :
def run ( self , command : str , fetch : str = " all " ) - > str :