remove sample_row_in_table_info and simplify set operations in SQLDB (#932)

-Address TODO: deprecate for sample_row_in_table_info
-Simplify set operations by casting to sets to not need multiple set
casts + .difference() calls
makefile-update-1
Kevin Huo 1 year ago committed by GitHub
parent e323d0cfb1
commit 512c523368
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,40 +17,30 @@ class SQLDatabase:
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 0,
# TODO: deprecate.
sample_row_in_table_info: bool = False,
):
"""Create engine from database URI."""
if sample_row_in_table_info and sample_rows_in_table_info > 0:
raise ValueError(
"Only one of `sample_row_in_table_info` "
"and `sample_rows_in_table_info` should be set"
)
self._engine = engine
self._schema = schema
if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables")
self._inspector = inspect(self._engine)
self._all_tables = self._inspector.get_table_names(schema=schema)
self._include_tables = include_tables or []
self._all_tables = set(self._inspector.get_table_names(schema=schema))
self._include_tables = set(include_tables) if include_tables else set()
if self._include_tables:
missing_tables = set(self._include_tables).difference(self._all_tables)
missing_tables = self._include_tables - self._all_tables
if missing_tables:
raise ValueError(
f"include_tables {missing_tables} not found in database"
)
self._ignore_tables = ignore_tables or []
self._ignore_tables = set(ignore_tables) if ignore_tables else set()
if self._ignore_tables:
missing_tables = set(self._ignore_tables).difference(self._all_tables)
missing_tables = self._ignore_tables - self._all_tables
if missing_tables:
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
self._sample_rows_in_table_info = sample_rows_in_table_info
# TODO: deprecate
if sample_row_in_table_info:
self._sample_rows_in_table_info = 1
@classmethod
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
@ -66,7 +56,7 @@ class SQLDatabase:
"""Get names of tables available."""
if self._include_tables:
return self._include_tables
return set(self._all_tables) - set(self._ignore_tables)
return self._all_tables - self._ignore_tables
@property
def table_info(self) -> str:

Loading…
Cancel
Save