|
|
|
@ -17,40 +17,30 @@ class SQLDatabase:
|
|
|
|
|
ignore_tables: Optional[List[str]] = None,
|
|
|
|
|
include_tables: Optional[List[str]] = None,
|
|
|
|
|
sample_rows_in_table_info: int = 0,
|
|
|
|
|
# TODO: deprecate.
|
|
|
|
|
sample_row_in_table_info: bool = False,
|
|
|
|
|
):
|
|
|
|
|
"""Create engine from database URI."""
|
|
|
|
|
if sample_row_in_table_info and sample_rows_in_table_info > 0:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Only one of `sample_row_in_table_info` "
|
|
|
|
|
"and `sample_rows_in_table_info` should be set"
|
|
|
|
|
)
|
|
|
|
|
self._engine = engine
|
|
|
|
|
self._schema = schema
|
|
|
|
|
if include_tables and ignore_tables:
|
|
|
|
|
raise ValueError("Cannot specify both include_tables and ignore_tables")
|
|
|
|
|
|
|
|
|
|
self._inspector = inspect(self._engine)
|
|
|
|
|
self._all_tables = self._inspector.get_table_names(schema=schema)
|
|
|
|
|
self._include_tables = include_tables or []
|
|
|
|
|
self._all_tables = set(self._inspector.get_table_names(schema=schema))
|
|
|
|
|
self._include_tables = set(include_tables) if include_tables else set()
|
|
|
|
|
if self._include_tables:
|
|
|
|
|
missing_tables = set(self._include_tables).difference(self._all_tables)
|
|
|
|
|
missing_tables = self._include_tables - self._all_tables
|
|
|
|
|
if missing_tables:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"include_tables {missing_tables} not found in database"
|
|
|
|
|
)
|
|
|
|
|
self._ignore_tables = ignore_tables or []
|
|
|
|
|
self._ignore_tables = set(ignore_tables) if ignore_tables else set()
|
|
|
|
|
if self._ignore_tables:
|
|
|
|
|
missing_tables = set(self._ignore_tables).difference(self._all_tables)
|
|
|
|
|
missing_tables = self._ignore_tables - self._all_tables
|
|
|
|
|
if missing_tables:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"ignore_tables {missing_tables} not found in database"
|
|
|
|
|
)
|
|
|
|
|
self._sample_rows_in_table_info = sample_rows_in_table_info
|
|
|
|
|
# TODO: deprecate
|
|
|
|
|
if sample_row_in_table_info:
|
|
|
|
|
self._sample_rows_in_table_info = 1
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
|
|
|
|
@ -66,7 +56,7 @@ class SQLDatabase:
|
|
|
|
|
"""Get names of tables available."""
|
|
|
|
|
if self._include_tables:
|
|
|
|
|
return self._include_tables
|
|
|
|
|
return set(self._all_tables) - set(self._ignore_tables)
|
|
|
|
|
return self._all_tables - self._ignore_tables
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def table_info(self) -> str:
|
|
|
|
|