Rename SQLDatabase.get_table_names to get_usable_table_names.

The old name is kept as a wrapper that emits a warning and delegates to the
new method. Metadata reflection is restricted to the usable tables and the
configured schema, and the chain and tool callers switch to the new name.
Two tool description strings appear to change only in trailing whitespace.

NOTE(review): line breaks, indentation and blank context lines were restored
from a whitespace-collapsed copy; hunk line counts match the @@ headers, but
exact leading/trailing whitespace should be confirmed against the target files.

diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py
index 6f959014..a91c6e91 100644
--- a/langchain/chains/sql_database/base.py
+++ b/langchain/chains/sql_database/base.py
@@ -162,7 +162,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
         return [self.output_key, "intermediate_steps"]
 
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
-        _table_names = self.sql_chain.database.get_table_names()
+        _table_names = self.sql_chain.database.get_usable_table_names()
         table_names = ", ".join(_table_names)
         llm_inputs = {
             "query": inputs[self.input_key],
diff --git a/langchain/sql_database.py b/langchain/sql_database.py
index 844084b7..234395a2 100644
--- a/langchain/sql_database.py
+++ b/langchain/sql_database.py
@@ -1,6 +1,7 @@
 """SQLAlchemy wrapper around a database."""
 from __future__ import annotations
 
+import warnings
 from typing import Any, Iterable, List, Optional
 
 from sqlalchemy import MetaData, create_engine, inspect, select, text
@@ -44,6 +45,8 @@ class SQLDatabase:
                 raise ValueError(
                     f"ignore_tables {missing_tables} not found in database"
                 )
+        usable_tables = self.get_usable_table_names()
+        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
 
         if not isinstance(sample_rows_in_table_info, int):
             raise TypeError("sample_rows_in_table_info must be an integer")
@@ -66,7 +69,9 @@ class SQLDatabase:
             )
 
         self._metadata = metadata or MetaData()
-        self._metadata.reflect(bind=self._engine)
+        self._metadata.reflect(
+            bind=self._engine, only=self._usable_tables, schema=self._schema
+        )
 
     @classmethod
     def from_uri(
@@ -81,12 +86,19 @@ class SQLDatabase:
         """Return string representation of dialect to use."""
         return self._engine.dialect.name
 
-    def get_table_names(self) -> Iterable[str]:
+    def get_usable_table_names(self) -> Iterable[str]:
         """Get names of tables available."""
         if self._include_tables:
             return self._include_tables
         return self._all_tables - self._ignore_tables
 
+    def get_table_names(self) -> Iterable[str]:
+        """Get names of tables available."""
+        warnings.warn(
+            "This method is deprecated - please use `get_usable_table_names`."
+        )
+        return self.get_usable_table_names()
+
     @property
     def table_info(self) -> str:
         """Information about all tables in the database."""
@@ -102,7 +114,7 @@ class SQLDatabase:
         appended to each table description. This can increase performance as
         demonstrated in the paper.
         """
-        all_table_names = self.get_table_names()
+        all_table_names = self.get_usable_table_names()
         if table_names is not None:
             missing_tables = set(table_names).difference(all_table_names)
             if missing_tables:
diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py
index 2c555d17..d4b72cd4 100644
--- a/langchain/tools/sql_database/tool.py
+++ b/langchain/tools/sql_database/tool.py
@@ -30,7 +30,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
     name = "query_sql_db"
     description = """
     Input to this tool is a detailed and correct SQL query, output is a result from the database.
-    If the query is not correct, an error message will be returned. 
+    If the query is not correct, an error message will be returned.
     If an error is returned, rewrite the query, check the query, and try again.
     """
 
@@ -49,7 +49,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
     description = """
     Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
     Be sure that the tables actually exist by calling list_tables_sql_db first!
-    
+
     Example Input: "table1, table2, table3"
     """
 
@@ -69,7 +69,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
 
     def _run(self, tool_input: str = "") -> str:
         """Get the schema for a specific table."""
-        return ", ".join(self.db.get_table_names())
+        return ", ".join(self.db.get_usable_table_names())
 
     async def _arun(self, tool_input: str = "") -> str:
         raise NotImplementedError("ListTablesSqlDbTool does not support async")