diff --git a/langchain/sql_database.py b/langchain/sql_database.py index 8fe071d9..6c01b436 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -17,40 +17,30 @@ class SQLDatabase: ignore_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None, sample_rows_in_table_info: int = 0, - # TODO: deprecate. - sample_row_in_table_info: bool = False, ): """Create engine from database URI.""" - if sample_row_in_table_info and sample_rows_in_table_info > 0: - raise ValueError( - "Only one of `sample_row_in_table_info` " - "and `sample_rows_in_table_info` should be set" - ) self._engine = engine self._schema = schema if include_tables and ignore_tables: raise ValueError("Cannot specify both include_tables and ignore_tables") self._inspector = inspect(self._engine) - self._all_tables = self._inspector.get_table_names(schema=schema) - self._include_tables = include_tables or [] + self._all_tables = set(self._inspector.get_table_names(schema=schema)) + self._include_tables = set(include_tables) if include_tables else set() if self._include_tables: - missing_tables = set(self._include_tables).difference(self._all_tables) + missing_tables = self._include_tables - self._all_tables if missing_tables: raise ValueError( f"include_tables {missing_tables} not found in database" ) - self._ignore_tables = ignore_tables or [] + self._ignore_tables = set(ignore_tables) if ignore_tables else set() if self._ignore_tables: - missing_tables = set(self._ignore_tables).difference(self._all_tables) + missing_tables = self._ignore_tables - self._all_tables if missing_tables: raise ValueError( f"ignore_tables {missing_tables} not found in database" ) self._sample_rows_in_table_info = sample_rows_in_table_info - # TODO: deprecate - if sample_row_in_table_info: - self._sample_rows_in_table_info = 1 @classmethod def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase: @@ -66,7 +56,7 @@ class SQLDatabase: """Get names of tables available.""" if self._include_tables: return self._include_tables - return set(self._all_tables) - set(self._ignore_tables) + return self._all_tables - self._ignore_tables @property def table_info(self) -> str: