Harrison/sql alchemy (#2216)

Co-authored-by: Jason B. Hart <jasonbhart@users.noreply.github.com>
doc
Harrison Chase 1 year ago committed by GitHub
parent 1ddd6dbf0b
commit 609b14a570
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -162,7 +162,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
return [self.output_key, "intermediate_steps"]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_table_names = self.sql_chain.database.get_table_names()
_table_names = self.sql_chain.database.get_usable_table_names()
table_names = ", ".join(_table_names)
llm_inputs = {
"query": inputs[self.input_key],

@ -1,6 +1,7 @@
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations
import warnings
from typing import Any, Iterable, List, Optional
from sqlalchemy import MetaData, create_engine, inspect, select, text
@ -44,6 +45,8 @@ class SQLDatabase:
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
usable_tables = self.get_usable_table_names()
self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
if not isinstance(sample_rows_in_table_info, int):
raise TypeError("sample_rows_in_table_info must be an integer")
@ -66,7 +69,9 @@ class SQLDatabase:
)
self._metadata = metadata or MetaData()
self._metadata.reflect(bind=self._engine)
self._metadata.reflect(
bind=self._engine, only=self._usable_tables, schema=self._schema
)
@classmethod
def from_uri(
@ -81,12 +86,19 @@ class SQLDatabase:
"""Return string representation of dialect to use."""
return self._engine.dialect.name
def get_table_names(self) -> Iterable[str]:
def get_usable_table_names(self) -> Iterable[str]:
    """Return the names of the tables this wrapper exposes.

    A non-empty ``_include_tables`` is returned as-is. Otherwise the result
    is ``_all_tables`` with every entry of ``_ignore_tables`` removed.
    """
    included = self._include_tables
    if not included:
        # No explicit include list: fall back to everything not ignored.
        return self._all_tables - self._ignore_tables
    return included
def get_table_names(self) -> Iterable[str]:
    """Get names of tables available.

    Deprecated alias of ``get_usable_table_names``: emits a
    ``DeprecationWarning`` and then delegates to that method.

    Returns:
        Whatever ``get_usable_table_names`` returns.
    """
    # An explicit category lets callers filter the warning as a deprecation;
    # stacklevel=2 attributes it to the calling code instead of this line.
    warnings.warn(
        "This method is deprecated - please use `get_usable_table_names`.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.get_usable_table_names()
@property
def table_info(self) -> str:
"""Information about all tables in the database."""
@ -102,7 +114,7 @@ class SQLDatabase:
appended to each table description. This can increase performance as
demonstrated in the paper.
"""
all_table_names = self.get_table_names()
all_table_names = self.get_usable_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:

@ -30,7 +30,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
name = "query_sql_db"
description = """
Input to this tool is a detailed and correct SQL query, output is a result from the database.
If the query is not correct, an error message will be returned.
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
"""
@ -49,7 +49,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
description = """
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Be sure that the tables actually exist by calling list_tables_sql_db first!
Example Input: "table1, table2, table3"
"""
@ -69,7 +69,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
def _run(self, tool_input: str = "") -> str:
"""Get the schema for a specific table."""
return ", ".join(self.db.get_table_names())
return ", ".join(self.db.get_usable_table_names())
async def _arun(self, tool_input: str = "") -> str:
raise NotImplementedError("ListTablesSqlDbTool does not support async")

Loading…
Cancel
Save