Harrison/table index (#2526)

Co-authored-by: Alvaro Sevilla <alvaro@chainalysis.com>
doc
Harrison Chase 1 year ago committed by GitHub
parent 704b0feb38
commit 15cdfa9e7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,12 +4,19 @@ from __future__ import annotations
import warnings import warnings
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
from sqlalchemy import MetaData, create_engine, inspect, select, text from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable from sqlalchemy.schema import CreateTable
def _format_index(index: dict) -> str:
return (
f'Name: {index["name"]}, Unique: {index["unique"]},'
f' Columns: {str(index["column_names"])}'
)
class SQLDatabase: class SQLDatabase:
"""SQLAlchemy wrapper around a database.""" """SQLAlchemy wrapper around a database."""
@ -21,6 +28,7 @@ class SQLDatabase:
ignore_tables: Optional[List[str]] = None, ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3, sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None, custom_table_info: Optional[dict] = None,
view_support: Optional[bool] = False, view_support: Optional[bool] = False,
): ):
@ -60,6 +68,7 @@ class SQLDatabase:
raise TypeError("sample_rows_in_table_info must be an integer") raise TypeError("sample_rows_in_table_info must be an integer")
self._sample_rows_in_table_info = sample_rows_in_table_info self._sample_rows_in_table_info = sample_rows_in_table_info
self._indexes_in_table_info = indexes_in_table_info
self._custom_table_info = custom_table_info self._custom_table_info = custom_table_info
if self._custom_table_info: if self._custom_table_info:
@ -148,48 +157,56 @@ class SQLDatabase:
# add create table command # add create table command
create_table = str(CreateTable(table).compile(self._engine)) create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"
if self._indexes_in_table_info:
table_info += f"\n{self._get_table_indexes(table)}\n"
if self._sample_rows_in_table_info: if self._sample_rows_in_table_info:
# build the select command table_info += f"\n{self._get_sample_rows(table)}\n"
command = select([table]).limit(self._sample_rows_in_table_info) if has_extra_info:
table_info += "*/"
# save the columns in string format tables.append(table_info)
columns_str = "\t".join([col.name for col in table.columns]) final_str = "\n\n".join(tables)
return final_str
try:
# get the sample rows def _get_table_indexes(self, table: Table) -> str:
with self._engine.connect() as connection: indexes = self._inspector.get_indexes(table.name)
sample_rows = connection.execute(command) indexes_formatted = "\n".join(map(_format_index, indexes))
# shorten values in the sample rows return f"Table Indexes:\n{indexes_formatted}"
sample_rows = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows) def _get_sample_rows(self, table: Table) -> str:
) # build the select command
command = select([table]).limit(self._sample_rows_in_table_info)
# save the sample rows in string format
sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) # save the columns in string format
columns_str = "\t".join([col.name for col in table.columns])
# in some dialects when there are no rows in the table a
# 'ProgrammingError' is returned try:
except ProgrammingError: # get the sample rows
sample_rows_str = "" with self._engine.connect() as connection:
sample_rows = connection.execute(command)
table_info = ( # shorten values in the sample rows
f"{create_table.rstrip()}\n" sample_rows = list(
f"/*\n" map(lambda ls: [str(i)[:100] for i in ls], sample_rows)
f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
f"{columns_str}\n"
f"{sample_rows_str}\n"
f"*/"
) )
# build final info for table # save the sample rows in string format
tables.append(table_info) sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
else: # in some dialects when there are no rows in the table a
tables.append(create_table) # 'ProgrammingError' is returned
except ProgrammingError:
sample_rows_str = ""
final_str = "\n\n".join(tables) return (
return final_str f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
f"{columns_str}\n"
f"{sample_rows_str}"
)
def run(self, command: str, fetch: str = "all") -> str: def run(self, command: str, fetch: str = "all") -> str:
"""Execute a SQL command and return a string representing the results. """Execute a SQL command and return a string representing the results.

Loading…
Cancel
Save