"""SQLAlchemy wrapper around a database."""
from __future__ import annotations

from typing import Any, Iterable, List, Optional

from sqlalchemy import MetaData, Table, create_engine, inspect, select
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.schema import CreateTable


class SQLDatabase:
    """SQLAlchemy wrapper around a database."""

    # Maximum number of characters kept from each sample value when rendering
    # table info, so very wide values do not blow up the description.
    _SAMPLE_VALUE_MAX_CHARS = 100

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ):
        """Wrap an existing SQLAlchemy engine.

        Args:
            engine: Engine used for inspection, reflection and queries.
            schema: Optional schema whose tables are inspected and reflected;
                it is also set as the ``search_path`` before each ``run`` call.
            metadata: Optional ``MetaData`` to reflect into; a new one is
                created when omitted.
            ignore_tables: Table names to hide. Mutually exclusive with
                ``include_tables``.
            include_tables: If given, only these tables are exposed.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description; ``0`` disables sampling.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or either one names a table not found in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an ``int``.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)
        self._all_tables = set(self._inspector.get_table_names(schema=schema))

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")
        self._sample_rows_in_table_info = sample_rows_in_table_info

        self._metadata = metadata if metadata is not None else MetaData()
        # Reflect from the same schema the inspector used above, and only the
        # tables this wrapper exposes: without `schema=` a schema-scoped
        # database reflects nothing, and ignored tables need not be loaded.
        self._metadata.reflect(
            bind=self._engine,
            schema=self._schema,
            only=sorted(self.get_table_names()),
        )

    @classmethod
    def from_uri(
        cls,
        database_uri: str,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """Construct a SQLAlchemy engine from URI.

        Args:
            database_uri: SQLAlchemy database URL.
            engine_args: Optional keyword arguments forwarded to
                ``create_engine`` (e.g. pool settings, ``connect_args``).
            **kwargs: Forwarded to the ``SQLDatabase`` constructor.
        """
        return cls(create_engine(database_uri, **(engine_args or {})), **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns a new set each call so callers cannot mutate internal state.
        """
        if self._include_tables:
            return set(self._include_tables)
        return self._all_tables - self._ignore_tables

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info` is non-zero, that many sample rows will
        be appended to each table description. This can increase performance
        as demonstrated in the paper.

        Args:
            table_names: Optional subset of the available tables to describe;
                all available tables are described when omitted.

        Returns:
            The per-table descriptions joined by blank lines.

        Raises:
            ValueError: If ``table_names`` contains unavailable tables.
        """
        all_table_names = self.get_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        # Build the lookup set once instead of once per reflected table.
        wanted_names = set(all_table_names)
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in wanted_names
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            if self._sample_rows_in_table_info:
                tables.append(create_table + self._get_sample_rows(table))
            else:
                tables.append(create_table)
        return "\n\n".join(tables)

    def _get_sample_rows(self, table: Table) -> str:
        """Render the sample-query / column / sample-rows section for a table."""
        limit = self._sample_rows_in_table_info
        # Display-only statement. Quote the name with the dialect's own rules:
        # single quotes would make it an SQL string literal, not an identifier.
        quoted_name = self._engine.dialect.identifier_preparer.quote(table.name)
        select_star = f"SELECT * FROM {quoted_name} LIMIT {limit}"
        # save the columns in string format
        columns_str = " ".join(col.name for col in table.columns)

        command = select(table).limit(limit)
        try:
            with self._engine.connect() as connection:
                # shorten values in the sample rows
                sample_rows = [
                    [str(value)[: self._SAMPLE_VALUE_MAX_CHARS] for value in row]
                    for row in connection.execute(command)
                ]
            sample_rows_str = "\n".join(" ".join(row) for row in sample_rows)
        except ProgrammingError:
            # in some dialects when there are no rows in the table a
            # 'ProgrammingError' is returned
            sample_rows_str = ""

        return select_star + ";\n" + columns_str + "\n" + sample_rows_str

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        Args:
            command: Raw SQL passed to the driver, executed inside a
                transaction opened with ``engine.begin()``.
            fetch: ``"all"`` returns ``str`` of every row; ``"one"`` returns
                ``str`` of the first column of the first row, or ``""`` when
                the result set is empty.

        Raises:
            ValueError: If ``fetch`` is not ``"one"`` or ``"all"``.
        """
        # Validate before executing so a bad argument never runs (and commits)
        # a possibly side-effecting statement.
        if fetch not in ("all", "one"):
            raise ValueError("Fetch parameter must be either 'one' or 'all'")

        with self._engine.begin() as connection:
            if self._schema is not None:
                # NOTE(review): `SET search_path` is PostgreSQL syntax; other
                # dialects may reject it — confirm before using `schema` there.
                connection.exec_driver_sql(f"SET search_path TO {self._schema}")
            cursor = connection.exec_driver_sql(command)
            if not cursor.returns_rows:
                return ""
            if fetch == "all":
                return str(cursor.fetchall())
            first_row = cursor.fetchone()
            # fetchone() yields None on an empty result; indexing it would crash.
            if first_row is None:
                return ""
            return str(first_row[0])