mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Harrison/sql alchemy (#2216)
Co-authored-by: Jason B. Hart <jasonbhart@users.noreply.github.com>
This commit is contained in:
parent
1ddd6dbf0b
commit
609b14a570
@ -162,7 +162,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
|
|||||||
return [self.output_key, "intermediate_steps"]
|
return [self.output_key, "intermediate_steps"]
|
||||||
|
|
||||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
_table_names = self.sql_chain.database.get_table_names()
|
_table_names = self.sql_chain.database.get_usable_table_names()
|
||||||
table_names = ", ".join(_table_names)
|
table_names = ", ".join(_table_names)
|
||||||
llm_inputs = {
|
llm_inputs = {
|
||||||
"query": inputs[self.input_key],
|
"query": inputs[self.input_key],
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""SQLAlchemy wrapper around a database."""
|
"""SQLAlchemy wrapper around a database."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import warnings
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import MetaData, create_engine, inspect, select, text
|
from sqlalchemy import MetaData, create_engine, inspect, select, text
|
||||||
@ -44,6 +45,8 @@ class SQLDatabase:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"ignore_tables {missing_tables} not found in database"
|
f"ignore_tables {missing_tables} not found in database"
|
||||||
)
|
)
|
||||||
|
usable_tables = self.get_usable_table_names()
|
||||||
|
self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
|
||||||
|
|
||||||
if not isinstance(sample_rows_in_table_info, int):
|
if not isinstance(sample_rows_in_table_info, int):
|
||||||
raise TypeError("sample_rows_in_table_info must be an integer")
|
raise TypeError("sample_rows_in_table_info must be an integer")
|
||||||
@ -66,7 +69,9 @@ class SQLDatabase:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._metadata = metadata or MetaData()
|
self._metadata = metadata or MetaData()
|
||||||
self._metadata.reflect(bind=self._engine)
|
self._metadata.reflect(
|
||||||
|
bind=self._engine, only=self._usable_tables, schema=self._schema
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri(
|
def from_uri(
|
||||||
@ -81,12 +86,19 @@ class SQLDatabase:
|
|||||||
"""Return string representation of dialect to use."""
|
"""Return string representation of dialect to use."""
|
||||||
return self._engine.dialect.name
|
return self._engine.dialect.name
|
||||||
|
|
||||||
def get_table_names(self) -> Iterable[str]:
|
def get_usable_table_names(self) -> Iterable[str]:
|
||||||
"""Get names of tables available."""
|
"""Get names of tables available."""
|
||||||
if self._include_tables:
|
if self._include_tables:
|
||||||
return self._include_tables
|
return self._include_tables
|
||||||
return self._all_tables - self._ignore_tables
|
return self._all_tables - self._ignore_tables
|
||||||
|
|
||||||
|
def get_table_names(self) -> Iterable[str]:
|
||||||
|
"""Get names of tables available."""
|
||||||
|
warnings.warn(
|
||||||
|
"This method is deprecated - please use `get_usable_table_names`."
|
||||||
|
)
|
||||||
|
return self.get_usable_table_names()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def table_info(self) -> str:
|
def table_info(self) -> str:
|
||||||
"""Information about all tables in the database."""
|
"""Information about all tables in the database."""
|
||||||
@ -102,7 +114,7 @@ class SQLDatabase:
|
|||||||
appended to each table description. This can increase performance as
|
appended to each table description. This can increase performance as
|
||||||
demonstrated in the paper.
|
demonstrated in the paper.
|
||||||
"""
|
"""
|
||||||
all_table_names = self.get_table_names()
|
all_table_names = self.get_usable_table_names()
|
||||||
if table_names is not None:
|
if table_names is not None:
|
||||||
missing_tables = set(table_names).difference(all_table_names)
|
missing_tables = set(table_names).difference(all_table_names)
|
||||||
if missing_tables:
|
if missing_tables:
|
||||||
|
@ -30,7 +30,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
name = "query_sql_db"
|
name = "query_sql_db"
|
||||||
description = """
|
description = """
|
||||||
Input to this tool is a detailed and correct SQL query, output is a result from the database.
|
Input to this tool is a detailed and correct SQL query, output is a result from the database.
|
||||||
If the query is not correct, an error message will be returned.
|
If the query is not correct, an error message will be returned.
|
||||||
If an error is returned, rewrite the query, check the query, and try again.
|
If an error is returned, rewrite the query, check the query, and try again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -49,7 +49,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
description = """
|
description = """
|
||||||
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
|
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
|
||||||
Be sure that the tables actually exist by calling list_tables_sql_db first!
|
Be sure that the tables actually exist by calling list_tables_sql_db first!
|
||||||
|
|
||||||
Example Input: "table1, table2, table3"
|
Example Input: "table1, table2, table3"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
|
|
||||||
def _run(self, tool_input: str = "") -> str:
|
def _run(self, tool_input: str = "") -> str:
|
||||||
"""Get the schema for a specific table."""
|
"""Get the schema for a specific table."""
|
||||||
return ", ".join(self.db.get_table_names())
|
return ", ".join(self.db.get_usable_table_names())
|
||||||
|
|
||||||
async def _arun(self, tool_input: str = "") -> str:
|
async def _arun(self, tool_input: str = "") -> str:
|
||||||
raise NotImplementedError("ListTablesSqlDbTool does not support async")
|
raise NotImplementedError("ListTablesSqlDbTool does not support async")
|
||||||
|
Loading…
Reference in New Issue
Block a user