mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
SQLAlchemy commands used in table info (#1135)
This approach has several advantages:
* it improves the readability of the code
* it removes incompatibilities between SQL dialects
* it fixes a bug with `datetime` values in rows and `ast.literal_eval`

Huge thanks and credit to @jzluo for finding the weaknesses in the current approach and for the thoughtful discussion on the best way to implement this.

---------

Co-authored-by: Francisco Ingham <>
Co-authored-by: Jon Luo <20971593+jzluo@users.noreply.github.com>
This commit is contained in:
parent
483821ea3b
commit
3f29742adc
@ -1,23 +1,12 @@
|
|||||||
"""SQLAlchemy wrapper around a database."""
|
"""SQLAlchemy wrapper around a database."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import ast
|
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import create_engine, inspect
|
from sqlalchemy import MetaData, create_engine, inspect, select
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.exc import ProgrammingError
|
||||||
_TEMPLATE_PREFIX = """Table data will be described in the following format:
|
from sqlalchemy.schema import CreateTable
|
||||||
|
|
||||||
Table 'table name' has columns: {
|
|
||||||
column1 name: (column1 type, [list of example values for column1]),
|
|
||||||
column2 name: (column2 type, [list of example values for column2]),
|
|
||||||
...
|
|
||||||
}
|
|
||||||
|
|
||||||
These are the tables you can use, together with their column information:
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class SQLDatabase:
|
class SQLDatabase:
|
||||||
@ -27,6 +16,7 @@ class SQLDatabase:
|
|||||||
self,
|
self,
|
||||||
engine: Engine,
|
engine: Engine,
|
||||||
schema: Optional[str] = None,
|
schema: Optional[str] = None,
|
||||||
|
metadata: Optional[MetaData] = None,
|
||||||
ignore_tables: Optional[List[str]] = None,
|
ignore_tables: Optional[List[str]] = None,
|
||||||
include_tables: Optional[List[str]] = None,
|
include_tables: Optional[List[str]] = None,
|
||||||
sample_rows_in_table_info: int = 3,
|
sample_rows_in_table_info: int = 3,
|
||||||
@ -53,8 +43,15 @@ class SQLDatabase:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"ignore_tables {missing_tables} not found in database"
|
f"ignore_tables {missing_tables} not found in database"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not isinstance(sample_rows_in_table_info, int):
|
||||||
|
raise TypeError("sample_rows_in_table_info must be an integer")
|
||||||
|
|
||||||
self._sample_rows_in_table_info = sample_rows_in_table_info
|
self._sample_rows_in_table_info = sample_rows_in_table_info
|
||||||
|
|
||||||
|
self._metadata = metadata or MetaData()
|
||||||
|
self._metadata.reflect(bind=self._engine)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
|
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
|
||||||
"""Construct a SQLAlchemy engine from URI."""
|
"""Construct a SQLAlchemy engine from URI."""
|
||||||
@ -93,52 +90,53 @@ class SQLDatabase:
|
|||||||
raise ValueError(f"table_names {missing_tables} not found in database")
|
raise ValueError(f"table_names {missing_tables} not found in database")
|
||||||
all_table_names = table_names
|
all_table_names = table_names
|
||||||
|
|
||||||
tables = []
|
meta_tables = [
|
||||||
for table_name in all_table_names:
|
tbl
|
||||||
columns = []
|
for tbl in self._metadata.sorted_tables
|
||||||
if self.dialect in ("sqlite", "duckdb"):
|
if tbl.name in set(all_table_names)
|
||||||
create_table = self.run(
|
]
|
||||||
(
|
|
||||||
"SELECT sql FROM sqlite_master WHERE "
|
|
||||||
f"type='table' AND name='{table_name}'"
|
|
||||||
),
|
|
||||||
fetch="one",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
create_table = self.run(
|
|
||||||
f"SHOW CREATE TABLE `{table_name}`;",
|
|
||||||
)
|
|
||||||
|
|
||||||
for column in self._inspector.get_columns(table_name, schema=self._schema):
|
tables = []
|
||||||
columns.append(column["name"])
|
for table in meta_tables:
|
||||||
|
# add create table command
|
||||||
|
create_table = str(CreateTable(table).compile(self._engine))
|
||||||
|
|
||||||
if self._sample_rows_in_table_info:
|
if self._sample_rows_in_table_info:
|
||||||
if self.dialect in ("sqlite", "duckdb"):
|
# build the select command
|
||||||
|
command = select(table).limit(self._sample_rows_in_table_info)
|
||||||
|
|
||||||
|
# save the command in string format
|
||||||
select_star = (
|
select_star = (
|
||||||
f"SELECT * FROM '{table_name}' LIMIT "
|
f"SELECT * FROM '{table.name}' LIMIT "
|
||||||
f"{self._sample_rows_in_table_info}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
select_star = (
|
|
||||||
f"SELECT * FROM `{table_name}` LIMIT "
|
|
||||||
f"{self._sample_rows_in_table_info}"
|
f"{self._sample_rows_in_table_info}"
|
||||||
)
|
)
|
||||||
|
|
||||||
sample_rows = self.run(select_star)
|
# save the columns in string format
|
||||||
|
columns_str = " ".join([col.name for col in table.columns])
|
||||||
|
|
||||||
sample_rows_ls = ast.literal_eval(sample_rows)
|
# get the sample rows
|
||||||
sample_rows_ls = list(
|
with self._engine.connect() as connection:
|
||||||
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
|
sample_rows = connection.execute(command)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# shorten values in the sample rows
|
||||||
|
sample_rows = list(
|
||||||
|
map(lambda ls: [str(i)[:100] for i in ls], sample_rows)
|
||||||
)
|
)
|
||||||
|
|
||||||
columns_str = " ".join(columns)
|
# save the sample rows in string format
|
||||||
sample_rows_str = "\n".join([" ".join(row) for row in sample_rows_ls])
|
sample_rows_str = "\n".join([" ".join(row) for row in sample_rows])
|
||||||
|
|
||||||
|
# in some dialects when there are no rows in the table a
|
||||||
|
# 'ProgrammingError' is raised
|
||||||
|
except ProgrammingError:
|
||||||
|
sample_rows_str = ""
|
||||||
|
|
||||||
|
# build final info for table
|
||||||
tables.append(
|
tables.append(
|
||||||
create_table
|
create_table
|
||||||
+ "\n\n"
|
|
||||||
+ select_star
|
+ select_star
|
||||||
+ "\n"
|
+ ";\n"
|
||||||
+ columns_str
|
+ columns_str
|
||||||
+ "\n"
|
+ "\n"
|
||||||
+ sample_rows_str
|
+ sample_rows_str
|
||||||
@ -147,7 +145,7 @@ class SQLDatabase:
|
|||||||
else:
|
else:
|
||||||
tables.append(create_table)
|
tables.append(create_table)
|
||||||
|
|
||||||
final_str = "\n\n\n".join(tables)
|
final_str = "\n\n".join(tables)
|
||||||
return final_str
|
return final_str
|
||||||
|
|
||||||
def run(self, command: str, fetch: str = "all") -> str:
|
def run(self, command: str, fetch: str = "all") -> str:
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
|
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
|
||||||
|
|
||||||
from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase
|
from langchain.sql_database import SQLDatabase
|
||||||
|
|
||||||
metadata_obj = MetaData()
|
metadata_obj = MetaData()
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ def test_table_info() -> None:
|
|||||||
PRIMARY KEY (user_id)
|
PRIMARY KEY (user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
SELECT * FROM 'user' LIMIT 3
|
SELECT * FROM 'user' LIMIT 3;
|
||||||
user_id user_name
|
user_id user_name
|
||||||
|
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ def test_table_info() -> None:
|
|||||||
PRIMARY KEY (company_id)
|
PRIMARY KEY (company_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
SELECT * FROM 'company' LIMIT 3
|
SELECT * FROM 'company' LIMIT 3;
|
||||||
company_id company_location
|
company_id company_location
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -75,7 +75,7 @@ def test_table_info_w_sample_rows() -> None:
|
|||||||
PRIMARY KEY (company_id)
|
PRIMARY KEY (company_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
SELECT * FROM 'company' LIMIT 2
|
SELECT * FROM 'company' LIMIT 2;
|
||||||
company_id company_location
|
company_id company_location
|
||||||
|
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ def test_table_info_w_sample_rows() -> None:
|
|||||||
PRIMARY KEY (user_id)
|
PRIMARY KEY (user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
SELECT * FROM 'user' LIMIT 2
|
SELECT * FROM 'user' LIMIT 2;
|
||||||
user_id user_name
|
user_id user_name
|
||||||
13 Harrison
|
13 Harrison
|
||||||
14 Chase
|
14 Chase
|
||||||
|
@ -45,12 +45,17 @@ def test_table_info() -> None:
|
|||||||
"""Test that table info is constructed properly."""
|
"""Test that table info is constructed properly."""
|
||||||
engine = create_engine("duckdb:///:memory:")
|
engine = create_engine("duckdb:///:memory:")
|
||||||
metadata_obj.create_all(engine)
|
metadata_obj.create_all(engine)
|
||||||
db = SQLDatabase(engine, schema="schema_a")
|
|
||||||
|
db = SQLDatabase(engine, schema="schema_a", metadata=metadata_obj)
|
||||||
output = db.table_info
|
output = db.table_info
|
||||||
expected_output = """
|
expected_output = """
|
||||||
CREATE TABLE schema_a."user"(user_id INTEGER, user_name VARCHAR NOT NULL, PRIMARY KEY(user_id));
|
CREATE TABLE schema_a."user" (
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
user_name VARCHAR NOT NULL,
|
||||||
|
PRIMARY KEY (user_id)
|
||||||
|
)
|
||||||
|
|
||||||
SELECT * FROM 'user' LIMIT 3
|
SELECT * FROM 'user' LIMIT 3;
|
||||||
user_id user_name
|
user_id user_name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user