diff --git a/langchain/sql_database.py b/langchain/sql_database.py index 4193e619..b68b523a 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -4,12 +4,19 @@ from __future__ import annotations import warnings from typing import Any, Iterable, List, Optional -from sqlalchemy import MetaData, create_engine, inspect, select, text +from sqlalchemy import MetaData, Table, create_engine, inspect, select, text from sqlalchemy.engine import Engine from sqlalchemy.exc import ProgrammingError, SQLAlchemyError from sqlalchemy.schema import CreateTable +def _format_index(index: dict) -> str: + return ( + f'Name: {index["name"]}, Unique: {index["unique"]},' + f' Columns: {str(index["column_names"])}' + ) + + class SQLDatabase: """SQLAlchemy wrapper around a database.""" @@ -21,6 +28,7 @@ class SQLDatabase: ignore_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None, sample_rows_in_table_info: int = 3, + indexes_in_table_info: bool = False, custom_table_info: Optional[dict] = None, view_support: Optional[bool] = False, ): @@ -60,6 +68,7 @@ class SQLDatabase: raise TypeError("sample_rows_in_table_info must be an integer") self._sample_rows_in_table_info = sample_rows_in_table_info + self._indexes_in_table_info = indexes_in_table_info self._custom_table_info = custom_table_info if self._custom_table_info: @@ -148,48 +157,56 @@ class SQLDatabase: # add create table command create_table = str(CreateTable(table).compile(self._engine)) - + table_info = f"{create_table.rstrip()}" + has_extra_info = ( + self._indexes_in_table_info or self._sample_rows_in_table_info + ) + if has_extra_info: + table_info += "\n\n/*" + if self._indexes_in_table_info: + table_info += f"\n{self._get_table_indexes(table)}\n" if self._sample_rows_in_table_info: - # build the select command - command = select([table]).limit(self._sample_rows_in_table_info) - - # save the columns in string format - columns_str = "\t".join([col.name for col in table.columns]) - - try: - # get the sample rows - with self._engine.connect() as connection: - sample_rows = connection.execute(command) - # shorten values in the sample rows - sample_rows = list( - map(lambda ls: [str(i)[:100] for i in ls], sample_rows) - ) - - # save the sample rows in string format - sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) - - # in some dialects when there are no rows in the table a - # 'ProgrammingError' is returned - except ProgrammingError: - sample_rows_str = "" - - table_info = ( - f"{create_table.rstrip()}\n" - f"/*\n" - f"{self._sample_rows_in_table_info} rows from {table.name} table:\n" - f"{columns_str}\n" - f"{sample_rows_str}\n" - f"*/" + table_info += f"\n{self._get_sample_rows(table)}\n" + if has_extra_info: + table_info += "*/" + tables.append(table_info) + final_str = "\n\n".join(tables) + return final_str + + def _get_table_indexes(self, table: Table) -> str: + indexes = self._inspector.get_indexes(table.name) + indexes_formatted = "\n".join(map(_format_index, indexes)) + return f"Table Indexes:\n{indexes_formatted}" + + def _get_sample_rows(self, table: Table) -> str: + # build the select command + command = select([table]).limit(self._sample_rows_in_table_info) + + # save the columns in string format + columns_str = "\t".join([col.name for col in table.columns]) + + try: + # get the sample rows + with self._engine.connect() as connection: + sample_rows = connection.execute(command) + # shorten values in the sample rows + sample_rows = list( + map(lambda ls: [str(i)[:100] for i in ls], sample_rows) ) - # build final info for table - tables.append(table_info) + # save the sample rows in string format + sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) - else: - tables.append(create_table) + # in some dialects when there are no rows in the table a + # 'ProgrammingError' is returned + except ProgrammingError: + sample_rows_str = "" - final_str = "\n\n".join(tables) - return final_str + return ( + f"{self._sample_rows_in_table_info} rows from {table.name} table:\n" + f"{columns_str}\n" + f"{sample_rows_str}" + ) def run(self, command: str, fetch: str = "all") -> str: """Execute a SQL command and return a string representing the results.