mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
SQLAlchemy commands used in table info (#1135)
This approach has several advantages:

* it improves the readability of the code
* it removes incompatibilities between SQL dialects
* it fixes a bug with `datetime` values in rows and `ast.literal_eval`

Huge thanks and credit to @jzluo for finding the weaknesses in the current approach and for the thoughtful discussion on the best way to implement this.

---------

Co-authored-by: Francisco Ingham <>
Co-authored-by: Jon Luo <20971593+jzluo@users.noreply.github.com>
This commit is contained in:
parent
483821ea3b
commit
3f29742adc
@ -1,23 +1,12 @@
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from sqlalchemy import create_engine, inspect
|
||||
from sqlalchemy import MetaData, create_engine, inspect, select
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
_TEMPLATE_PREFIX = """Table data will be described in the following format:
|
||||
|
||||
Table 'table name' has columns: {
|
||||
column1 name: (column1 type, [list of example values for column1]),
|
||||
column2 name: (column2 type, [list of example values for column2]),
|
||||
...
|
||||
}
|
||||
|
||||
These are the tables you can use, together with their column information:
|
||||
|
||||
"""
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from sqlalchemy.schema import CreateTable
|
||||
|
||||
|
||||
class SQLDatabase:
|
||||
@ -27,6 +16,7 @@ class SQLDatabase:
|
||||
self,
|
||||
engine: Engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
@ -53,8 +43,15 @@ class SQLDatabase:
|
||||
raise ValueError(
|
||||
f"ignore_tables {missing_tables} not found in database"
|
||||
)
|
||||
|
||||
if not isinstance(sample_rows_in_table_info, int):
|
||||
raise TypeError("sample_rows_in_table_info must be an integer")
|
||||
|
||||
self._sample_rows_in_table_info = sample_rows_in_table_info
|
||||
|
||||
self._metadata = metadata or MetaData()
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
@ -93,52 +90,53 @@ class SQLDatabase:
|
||||
raise ValueError(f"table_names {missing_tables} not found in database")
|
||||
all_table_names = table_names
|
||||
|
||||
tables = []
|
||||
for table_name in all_table_names:
|
||||
columns = []
|
||||
if self.dialect in ("sqlite", "duckdb"):
|
||||
create_table = self.run(
|
||||
(
|
||||
"SELECT sql FROM sqlite_master WHERE "
|
||||
f"type='table' AND name='{table_name}'"
|
||||
),
|
||||
fetch="one",
|
||||
)
|
||||
else:
|
||||
create_table = self.run(
|
||||
f"SHOW CREATE TABLE `{table_name}`;",
|
||||
)
|
||||
meta_tables = [
|
||||
tbl
|
||||
for tbl in self._metadata.sorted_tables
|
||||
if tbl.name in set(all_table_names)
|
||||
]
|
||||
|
||||
for column in self._inspector.get_columns(table_name, schema=self._schema):
|
||||
columns.append(column["name"])
|
||||
tables = []
|
||||
for table in meta_tables:
|
||||
# add create table command
|
||||
create_table = str(CreateTable(table).compile(self._engine))
|
||||
|
||||
if self._sample_rows_in_table_info:
|
||||
if self.dialect in ("sqlite", "duckdb"):
|
||||
select_star = (
|
||||
f"SELECT * FROM '{table_name}' LIMIT "
|
||||
f"{self._sample_rows_in_table_info}"
|
||||
)
|
||||
else:
|
||||
select_star = (
|
||||
f"SELECT * FROM `{table_name}` LIMIT "
|
||||
f"{self._sample_rows_in_table_info}"
|
||||
)
|
||||
# build the select command
|
||||
command = select(table).limit(self._sample_rows_in_table_info)
|
||||
|
||||
sample_rows = self.run(select_star)
|
||||
|
||||
sample_rows_ls = ast.literal_eval(sample_rows)
|
||||
sample_rows_ls = list(
|
||||
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
|
||||
# save the command in string format
|
||||
select_star = (
|
||||
f"SELECT * FROM '{table.name}' LIMIT "
|
||||
f"{self._sample_rows_in_table_info}"
|
||||
)
|
||||
|
||||
columns_str = " ".join(columns)
|
||||
sample_rows_str = "\n".join([" ".join(row) for row in sample_rows_ls])
|
||||
# save the columns in string format
|
||||
columns_str = " ".join([col.name for col in table.columns])
|
||||
|
||||
# get the sample rows
|
||||
with self._engine.connect() as connection:
|
||||
sample_rows = connection.execute(command)
|
||||
|
||||
try:
|
||||
# shorten values in the sample rows
|
||||
sample_rows = list(
|
||||
map(lambda ls: [str(i)[:100] for i in ls], sample_rows)
|
||||
)
|
||||
|
||||
# save the sample rows in string format
|
||||
sample_rows_str = "\n".join([" ".join(row) for row in sample_rows])
|
||||
|
||||
# in some dialects when there are no rows in the table a
|
||||
# 'ProgrammingError' is raised
|
||||
except ProgrammingError:
|
||||
sample_rows_str = ""
|
||||
|
||||
# build final info for table
|
||||
tables.append(
|
||||
create_table
|
||||
+ "\n\n"
|
||||
+ select_star
|
||||
+ "\n"
|
||||
+ ";\n"
|
||||
+ columns_str
|
||||
+ "\n"
|
||||
+ sample_rows_str
|
||||
@ -147,7 +145,7 @@ class SQLDatabase:
|
||||
else:
|
||||
tables.append(create_table)
|
||||
|
||||
final_str = "\n\n\n".join(tables)
|
||||
final_str = "\n\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> str:
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
|
||||
|
||||
from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
metadata_obj = MetaData()
|
||||
|
||||
@ -35,7 +35,7 @@ def test_table_info() -> None:
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 3
|
||||
SELECT * FROM 'user' LIMIT 3;
|
||||
user_id user_name
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ def test_table_info() -> None:
|
||||
PRIMARY KEY (company_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'company' LIMIT 3
|
||||
SELECT * FROM 'company' LIMIT 3;
|
||||
company_id company_location
|
||||
"""
|
||||
|
||||
@ -75,7 +75,7 @@ def test_table_info_w_sample_rows() -> None:
|
||||
PRIMARY KEY (company_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'company' LIMIT 2
|
||||
SELECT * FROM 'company' LIMIT 2;
|
||||
company_id company_location
|
||||
|
||||
|
||||
@ -85,7 +85,7 @@ def test_table_info_w_sample_rows() -> None:
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 2
|
||||
SELECT * FROM 'user' LIMIT 2;
|
||||
user_id user_name
|
||||
13 Harrison
|
||||
14 Chase
|
||||
|
@ -45,12 +45,17 @@ def test_table_info() -> None:
|
||||
"""Test that table info is constructed properly."""
|
||||
engine = create_engine("duckdb:///:memory:")
|
||||
metadata_obj.create_all(engine)
|
||||
db = SQLDatabase(engine, schema="schema_a")
|
||||
|
||||
db = SQLDatabase(engine, schema="schema_a", metadata=metadata_obj)
|
||||
output = db.table_info
|
||||
expected_output = """
|
||||
CREATE TABLE schema_a."user"(user_id INTEGER, user_name VARCHAR NOT NULL, PRIMARY KEY(user_id));
|
||||
CREATE TABLE schema_a."user" (
|
||||
user_id INTEGER NOT NULL,
|
||||
user_name VARCHAR NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 3
|
||||
SELECT * FROM 'user' LIMIT 3;
|
||||
user_id user_name
|
||||
"""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user