mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Harrison/table index (#2526)
Co-authored-by: Alvaro Sevilla <alvaro@chainalysis.com>
This commit is contained in:
parent
704b0feb38
commit
15cdfa9e7f
@ -4,12 +4,19 @@ from __future__ import annotations
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import MetaData, create_engine, inspect, select, text
|
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||||
from sqlalchemy.schema import CreateTable
|
from sqlalchemy.schema import CreateTable
|
||||||
|
|
||||||
|
|
||||||
|
def _format_index(index: dict) -> str:
    """Format one index description as a single human-readable line.

    Args:
        index: Mapping that provides the ``"name"``, ``"unique"`` and
            ``"column_names"`` keys. Any other keys are ignored.

    Returns:
        A string shaped like
        ``Name: <name>, Unique: <unique>, Columns: <column_names>``.

    Raises:
        KeyError: If one of the three keys is missing from ``index``.
    """
    return (
        f'Name: {index["name"]}, Unique: {index["unique"]},'
        # ``!s`` is the f-string spelling of an explicit ``str()`` call.
        f' Columns: {index["column_names"]!s}'
    )
|
||||||
|
|
||||||
|
|
||||||
class SQLDatabase:
|
class SQLDatabase:
|
||||||
"""SQLAlchemy wrapper around a database."""
|
"""SQLAlchemy wrapper around a database."""
|
||||||
|
|
||||||
@ -21,6 +28,7 @@ class SQLDatabase:
|
|||||||
ignore_tables: Optional[List[str]] = None,
|
ignore_tables: Optional[List[str]] = None,
|
||||||
include_tables: Optional[List[str]] = None,
|
include_tables: Optional[List[str]] = None,
|
||||||
sample_rows_in_table_info: int = 3,
|
sample_rows_in_table_info: int = 3,
|
||||||
|
indexes_in_table_info: bool = False,
|
||||||
custom_table_info: Optional[dict] = None,
|
custom_table_info: Optional[dict] = None,
|
||||||
view_support: Optional[bool] = False,
|
view_support: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
@ -60,6 +68,7 @@ class SQLDatabase:
|
|||||||
raise TypeError("sample_rows_in_table_info must be an integer")
|
raise TypeError("sample_rows_in_table_info must be an integer")
|
||||||
|
|
||||||
self._sample_rows_in_table_info = sample_rows_in_table_info
|
self._sample_rows_in_table_info = sample_rows_in_table_info
|
||||||
|
self._indexes_in_table_info = indexes_in_table_info
|
||||||
|
|
||||||
self._custom_table_info = custom_table_info
|
self._custom_table_info = custom_table_info
|
||||||
if self._custom_table_info:
|
if self._custom_table_info:
|
||||||
@ -148,49 +157,57 @@ class SQLDatabase:
|
|||||||
|
|
||||||
# add create table command
|
# add create table command
|
||||||
create_table = str(CreateTable(table).compile(self._engine))
|
create_table = str(CreateTable(table).compile(self._engine))
|
||||||
|
table_info = f"{create_table.rstrip()}"
|
||||||
|
has_extra_info = (
|
||||||
|
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||||
|
)
|
||||||
|
if has_extra_info:
|
||||||
|
table_info += "\n\n/*"
|
||||||
|
if self._indexes_in_table_info:
|
||||||
|
table_info += f"\n{self._get_table_indexes(table)}\n"
|
||||||
if self._sample_rows_in_table_info:
|
if self._sample_rows_in_table_info:
|
||||||
# build the select command
|
table_info += f"\n{self._get_sample_rows(table)}\n"
|
||||||
command = select([table]).limit(self._sample_rows_in_table_info)
|
if has_extra_info:
|
||||||
|
table_info += "*/"
|
||||||
# save the columns in string format
|
tables.append(table_info)
|
||||||
columns_str = "\t".join([col.name for col in table.columns])
|
|
||||||
|
|
||||||
try:
|
|
||||||
# get the sample rows
|
|
||||||
with self._engine.connect() as connection:
|
|
||||||
sample_rows = connection.execute(command)
|
|
||||||
# shorten values in the sample rows
|
|
||||||
sample_rows = list(
|
|
||||||
map(lambda ls: [str(i)[:100] for i in ls], sample_rows)
|
|
||||||
)
|
|
||||||
|
|
||||||
# save the sample rows in string format
|
|
||||||
sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
|
|
||||||
|
|
||||||
# in some dialects when there are no rows in the table a
|
|
||||||
# 'ProgrammingError' is returned
|
|
||||||
except ProgrammingError:
|
|
||||||
sample_rows_str = ""
|
|
||||||
|
|
||||||
table_info = (
|
|
||||||
f"{create_table.rstrip()}\n"
|
|
||||||
f"/*\n"
|
|
||||||
f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
|
|
||||||
f"{columns_str}\n"
|
|
||||||
f"{sample_rows_str}\n"
|
|
||||||
f"*/"
|
|
||||||
)
|
|
||||||
|
|
||||||
# build final info for table
|
|
||||||
tables.append(table_info)
|
|
||||||
|
|
||||||
else:
|
|
||||||
tables.append(create_table)
|
|
||||||
|
|
||||||
final_str = "\n\n".join(tables)
|
final_str = "\n\n".join(tables)
|
||||||
return final_str
|
return final_str
|
||||||
|
|
||||||
|
def _get_table_indexes(self, table: Table) -> str:
    """Build the "Table Indexes:" section for ``table``.

    The index descriptions are fetched from ``self._inspector`` by table
    name, and each one is rendered as a single line by ``_format_index``.

    Args:
        table: Table whose indexes should be listed.

    Returns:
        The header line ``Table Indexes:`` followed by one line per index.
    """
    index_lines = [
        _format_index(entry) for entry in self._inspector.get_indexes(table.name)
    ]
    listing = "\n".join(index_lines)
    return f"Table Indexes:\n{listing}"
|
||||||
|
|
||||||
|
def _get_sample_rows(self, table: Table) -> str:
    """Render a tab-separated preview of the first rows of ``table``.

    Selects up to ``self._sample_rows_in_table_info`` rows through
    ``self._engine`` and lays them out under a header line and a line of
    column names.

    Args:
        table: Table to sample.

    Returns:
        ``"<n> rows from <table> table:"``, then the tab-joined column names,
        then one tab-joined line per sampled row. The row section is empty
        when the query raises ``ProgrammingError``.
    """
    row_limit = self._sample_rows_in_table_info

    # NOTE(review): ``select([table])`` is the SQLAlchemy 1.x list calling
    # style; confirm the pinned SQLAlchemy version before upgrading.
    query = select([table]).limit(row_limit)

    # Column names, kept in string form for the header line.
    header = "\t".join(column.name for column in table.columns)

    try:
        with self._engine.connect() as conn:
            # Stringify every value and cap it at 100 characters so that
            # wide cells do not bloat the preview.
            truncated = [
                [str(value)[:100] for value in row] for row in conn.execute(query)
            ]
    # In some dialects a 'ProgrammingError' is raised when the table has
    # no rows; the preview is then left empty.
    except ProgrammingError:
        body = ""
    else:
        body = "\n".join("\t".join(cells) for cells in truncated)

    return (
        f"{row_limit} rows from {table.name} table:\n"
        f"{header}\n"
        f"{body}"
    )
|
||||||
|
|
||||||
def run(self, command: str, fetch: str = "all") -> str:
|
def run(self, command: str, fetch: str = "all") -> str:
|
||||||
"""Execute a SQL command and return a string representing the results.
|
"""Execute a SQL command and return a string representing the results.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user