Harrison/sql alchemy (#2216)

Co-authored-by: Jason B. Hart <jasonbhart@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-04-01 12:52:08 -07:00 committed by GitHub
parent 1ddd6dbf0b
commit 609b14a570
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 7 deletions

View File

@ -162,7 +162,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
return [self.output_key, "intermediate_steps"] return [self.output_key, "intermediate_steps"]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_table_names = self.sql_chain.database.get_table_names() _table_names = self.sql_chain.database.get_usable_table_names()
table_names = ", ".join(_table_names) table_names = ", ".join(_table_names)
llm_inputs = { llm_inputs = {
"query": inputs[self.input_key], "query": inputs[self.input_key],

View File

@ -1,6 +1,7 @@
"""SQLAlchemy wrapper around a database.""" """SQLAlchemy wrapper around a database."""
from __future__ import annotations from __future__ import annotations
import warnings
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
from sqlalchemy import MetaData, create_engine, inspect, select, text from sqlalchemy import MetaData, create_engine, inspect, select, text
@ -44,6 +45,8 @@ class SQLDatabase:
raise ValueError( raise ValueError(
f"ignore_tables {missing_tables} not found in database" f"ignore_tables {missing_tables} not found in database"
) )
usable_tables = self.get_usable_table_names()
self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
if not isinstance(sample_rows_in_table_info, int): if not isinstance(sample_rows_in_table_info, int):
raise TypeError("sample_rows_in_table_info must be an integer") raise TypeError("sample_rows_in_table_info must be an integer")
@ -66,7 +69,9 @@ class SQLDatabase:
) )
self._metadata = metadata or MetaData() self._metadata = metadata or MetaData()
self._metadata.reflect(bind=self._engine) self._metadata.reflect(
bind=self._engine, only=self._usable_tables, schema=self._schema
)
@classmethod @classmethod
def from_uri( def from_uri(
@ -81,12 +86,19 @@ class SQLDatabase:
"""Return string representation of dialect to use.""" """Return string representation of dialect to use."""
return self._engine.dialect.name return self._engine.dialect.name
def get_table_names(self) -> Iterable[str]: def get_usable_table_names(self) -> Iterable[str]:
"""Get names of tables available.""" """Get names of tables available."""
if self._include_tables: if self._include_tables:
return self._include_tables return self._include_tables
return self._all_tables - self._ignore_tables return self._all_tables - self._ignore_tables
def get_table_names(self) -> Iterable[str]:
    """Get names of tables available.

    Deprecated: this method is kept only as a backward-compatible alias
    and forwards to `get_usable_table_names`.

    Returns:
        Whatever `get_usable_table_names` returns for this instance.
    """
    # Use the dedicated DeprecationWarning category so callers and test
    # runners can filter on it, and stacklevel=2 so the warning is
    # attributed to the calling line rather than to this wrapper.
    warnings.warn(
        "This method is deprecated - please use `get_usable_table_names`.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.get_usable_table_names()
@property @property
def table_info(self) -> str: def table_info(self) -> str:
"""Information about all tables in the database.""" """Information about all tables in the database."""
@ -102,7 +114,7 @@ class SQLDatabase:
appended to each table description. This can increase performance as appended to each table description. This can increase performance as
demonstrated in the paper. demonstrated in the paper.
""" """
all_table_names = self.get_table_names() all_table_names = self.get_usable_table_names()
if table_names is not None: if table_names is not None:
missing_tables = set(table_names).difference(all_table_names) missing_tables = set(table_names).difference(all_table_names)
if missing_tables: if missing_tables:

View File

@ -30,7 +30,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
name = "query_sql_db" name = "query_sql_db"
description = """ description = """
Input to this tool is a detailed and correct SQL query, output is a result from the database. Input to this tool is a detailed and correct SQL query, output is a result from the database.
If the query is not correct, an error message will be returned. If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again. If an error is returned, rewrite the query, check the query, and try again.
""" """
@ -49,7 +49,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
description = """ description = """
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Be sure that the tables actually exist by calling list_tables_sql_db first! Be sure that the tables actually exist by calling list_tables_sql_db first!
Example Input: "table1, table2, table3" Example Input: "table1, table2, table3"
""" """
@ -69,7 +69,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
def _run(self, tool_input: str = "") -> str: def _run(self, tool_input: str = "") -> str:
"""Get the schema for a specific table.""" """Get the schema for a specific table."""
return ", ".join(self.db.get_table_names()) return ", ".join(self.db.get_usable_table_names())
async def _arun(self, tool_input: str = "") -> str: async def _arun(self, tool_input: str = "") -> str:
raise NotImplementedError("ListTablesSqlDbTool does not support async") raise NotImplementedError("ListTablesSqlDbTool does not support async")