Harrison/sql alchemy (#2216)

Co-authored-by: Jason B. Hart <jasonbhart@users.noreply.github.com>
doc
Harrison Chase 1 year ago committed by GitHub
parent 1ddd6dbf0b
commit 609b14a570
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -162,7 +162,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
return [self.output_key, "intermediate_steps"]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_table_names = self.sql_chain.database.get_table_names()
_table_names = self.sql_chain.database.get_usable_table_names()
table_names = ", ".join(_table_names)
llm_inputs = {
"query": inputs[self.input_key],

@ -1,6 +1,7 @@
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations
import warnings
from typing import Any, Iterable, List, Optional
from sqlalchemy import MetaData, create_engine, inspect, select, text
@ -44,6 +45,8 @@ class SQLDatabase:
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
usable_tables = self.get_usable_table_names()
self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
if not isinstance(sample_rows_in_table_info, int):
raise TypeError("sample_rows_in_table_info must be an integer")
@ -66,7 +69,9 @@ class SQLDatabase:
)
self._metadata = metadata or MetaData()
self._metadata.reflect(bind=self._engine)
self._metadata.reflect(
bind=self._engine, only=self._usable_tables, schema=self._schema
)
@classmethod
def from_uri(
@ -81,12 +86,19 @@ class SQLDatabase:
"""Return string representation of dialect to use."""
return self._engine.dialect.name
def get_table_names(self) -> Iterable[str]:
def get_usable_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
if self._include_tables:
return self._include_tables
return self._all_tables - self._ignore_tables
def get_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
warnings.warn(
"This method is deprecated - please use `get_usable_table_names`."
)
return self.get_usable_table_names()
@property
def table_info(self) -> str:
"""Information about all tables in the database."""
@ -102,7 +114,7 @@ class SQLDatabase:
appended to each table description. This can increase performance as
demonstrated in the paper.
"""
all_table_names = self.get_table_names()
all_table_names = self.get_usable_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:

@ -69,7 +69,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
def _run(self, tool_input: str = "") -> str:
"""Get the schema for a specific table."""
return ", ".join(self.db.get_table_names())
return ", ".join(self.db.get_usable_table_names())
async def _arun(self, tool_input: str = "") -> str:
raise NotImplementedError("ListTablesSqlDbTool does not support async")

Loading…
Cancel
Save