|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
"""SQLAlchemy wrapper around a database."""
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import warnings
|
|
|
|
|
from typing import Any, Iterable, List, Optional
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import MetaData, create_engine, inspect, select, text
|
|
|
|
@ -44,6 +45,8 @@ class SQLDatabase:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"ignore_tables {missing_tables} not found in database"
|
|
|
|
|
)
|
|
|
|
|
usable_tables = self.get_usable_table_names()
|
|
|
|
|
self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
|
|
|
|
|
|
|
|
|
|
if not isinstance(sample_rows_in_table_info, int):
|
|
|
|
|
raise TypeError("sample_rows_in_table_info must be an integer")
|
|
|
|
@ -66,7 +69,9 @@ class SQLDatabase:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self._metadata = metadata or MetaData()
|
|
|
|
|
self._metadata.reflect(bind=self._engine)
|
|
|
|
|
self._metadata.reflect(
|
|
|
|
|
bind=self._engine, only=self._usable_tables, schema=self._schema
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_uri(
|
|
|
|
@ -81,12 +86,19 @@ class SQLDatabase:
|
|
|
|
|
"""Return string representation of dialect to use."""
|
|
|
|
|
return self._engine.dialect.name
|
|
|
|
|
|
|
|
|
|
def get_table_names(self) -> Iterable[str]:
|
|
|
|
|
def get_usable_table_names(self) -> Iterable[str]:
|
|
|
|
|
"""Get names of tables available."""
|
|
|
|
|
if self._include_tables:
|
|
|
|
|
return self._include_tables
|
|
|
|
|
return self._all_tables - self._ignore_tables
|
|
|
|
|
|
|
|
|
|
def get_table_names(self) -> Iterable[str]:
|
|
|
|
|
"""Get names of tables available."""
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"This method is deprecated - please use `get_usable_table_names`."
|
|
|
|
|
)
|
|
|
|
|
return self.get_usable_table_names()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def table_info(self) -> str:
|
|
|
|
|
"""Information about all tables in the database."""
|
|
|
|
@ -102,7 +114,7 @@ class SQLDatabase:
|
|
|
|
|
appended to each table description. This can increase performance as
|
|
|
|
|
demonstrated in the paper.
|
|
|
|
|
"""
|
|
|
|
|
all_table_names = self.get_table_names()
|
|
|
|
|
all_table_names = self.get_usable_table_names()
|
|
|
|
|
if table_names is not None:
|
|
|
|
|
missing_tables = set(table_names).difference(all_table_names)
|
|
|
|
|
if missing_tables:
|
|
|
|
|