Harrison/align table (#1081)

Co-authored-by: Francisco Ingham <fpingham@gmail.com>
This commit is contained in:
Harrison Chase 2023-02-15 23:53:37 -08:00 committed by GitHub
parent c60954d0f8
commit 5e10e19bfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 137 additions and 42 deletions

View File

@ -29,7 +29,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "d0e27d88",
"metadata": {
"pycharm": {
@ -43,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "72ede462",
"metadata": {
"pycharm": {
@ -346,15 +346,46 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CREATE TABLE [Album]\n",
"(\n",
" [AlbumId] INTEGER NOT NULL,\n",
" [Title] NVARCHAR(160) NOT NULL,\n",
" [ArtistId] INTEGER NOT NULL,\n",
" CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]),\n",
" FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n",
")\n",
"\n",
" Table data will be described in the following format:\n",
"SELECT * FROM 'Album' LIMIT 2\n",
"AlbumId Title ArtistId\n",
"1 For Those About To Rock We Salute You 1\n",
"2 Balls to the Wall 2\n",
"\n",
" Table 'table name' has columns: {column1 name: (column1 type, [list of example values for column1]),\n",
" column2 name: (column2 type, [list of example values for column2], ...)\n",
"\n",
" These are the tables you can use, together with their column information:\n",
"CREATE TABLE [Track]\n",
"(\n",
" [TrackId] INTEGER NOT NULL,\n",
" [Name] NVARCHAR(200) NOT NULL,\n",
" [AlbumId] INTEGER,\n",
" [MediaTypeId] INTEGER NOT NULL,\n",
" [GenreId] INTEGER,\n",
" [Composer] NVARCHAR(220),\n",
" [Milliseconds] INTEGER NOT NULL,\n",
" [Bytes] INTEGER,\n",
" [UnitPrice] NUMERIC(10,2) NOT NULL,\n",
" CONSTRAINT [PK_Track] PRIMARY KEY ([TrackId]),\n",
" FOREIGN KEY ([AlbumId]) REFERENCES [Album] ([AlbumId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n",
" FOREIGN KEY ([GenreId]) REFERENCES [Genre] ([GenreId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n",
" FOREIGN KEY ([MediaTypeId]) REFERENCES [MediaType] ([MediaTypeId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n",
")\n",
"\n",
" Table 'Track' has columns: {'TrackId': ['INTEGER', ['1', '2']], 'Name': ['NVARCHAR(200)', ['For Those About To Rock (We Salute You)', 'Balls to the Wall']], 'AlbumId': ['INTEGER', ['1', '2']], 'MediaTypeId': ['INTEGER', ['1', '2']], 'GenreId': ['INTEGER', ['1', '1']], 'Composer': ['NVARCHAR(220)', ['Angus Young, Malcolm Young, Brian Johnson', 'None']], 'Milliseconds': ['INTEGER', ['343719', '342562']], 'Bytes': ['INTEGER', ['11170334', '5510424']], 'UnitPrice': ['NUMERIC(10, 2)', ['0.99', '0.99']]}\n"
"SELECT * FROM 'Track' LIMIT 2\n",
"TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice\n",
"1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n",
"2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n"
]
}
],
@ -492,9 +523,9 @@
"lastKernelId": null
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "langchain",
"language": "python",
"name": "python3"
"name": "langchain"
},
"language_info": {
"codemirror_mode": {
@ -506,7 +537,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.8.16"
}
},
"nbformat": 4,

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import ast
from collections import defaultdict
from typing import Any, Iterable, List, Optional
from sqlalchemy import create_engine, inspect
@ -30,7 +29,7 @@ class SQLDatabase:
schema: Optional[str] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 0,
sample_rows_in_table_info: int = 3,
):
"""Create engine from database URI."""
self._engine = engine
@ -80,9 +79,12 @@ class SQLDatabase:
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables.
Follows best practices as specified in: Rajkumar et al, 2022
(https://arxiv.org/abs/2204.00498)
If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as
demonstrated by Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498).
demonstrated in the paper.
"""
all_table_names = self.get_table_names()
if table_names is not None:
@ -93,33 +95,51 @@ class SQLDatabase:
tables = []
for table_name in all_table_names:
columns = defaultdict(list)
columns = []
create_table = self.run(
(
"SELECT sql FROM sqlite_master WHERE "
f"type='table' AND name='{table_name}'"
),
fetch="one",
)
for column in self._inspector.get_columns(table_name, schema=self._schema):
columns[f"{column['name']}"].append(str(column["type"]))
columns.append(column["name"])
if self._sample_rows_in_table_info:
sample_rows = self.run(
select_star = (
f"SELECT * FROM '{table_name}' LIMIT "
f"{self._sample_rows_in_table_info}"
)
sample_rows = self.run(select_star)
sample_rows_ls = ast.literal_eval(sample_rows)
sample_rows_ls = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
)
for e, col in enumerate(columns):
columns[col].append(
[row[e] for row in sample_rows_ls] # type: ignore
columns_str = " ".join(columns)
sample_rows_str = "\n".join([" ".join(row) for row in sample_rows_ls])
tables.append(
create_table
+ "\n\n"
+ select_star
+ "\n"
+ columns_str
+ "\n"
+ sample_rows_str
)
table_str = f"Table '{table_name}' has columns: " + str(dict(columns))
tables.append(table_str)
else:
tables.append(create_table)
final_str = _TEMPLATE_PREFIX + "\n".join(tables)
final_str = "\n\n\n".join(tables)
return final_str
def run(self, command: str) -> str:
def run(self, command: str, fetch: str = "all") -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
@ -130,6 +150,11 @@ class SQLDatabase:
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
cursor = connection.exec_driver_sql(command)
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0]
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
return str(result)
return ""

View File

@ -28,12 +28,28 @@ def test_table_info() -> None:
metadata_obj.create_all(engine)
db = SQLDatabase(engine)
output = db.table_info
output = output[len(_TEMPLATE_PREFIX) :]
expected_output = (
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR(16)']}",
"Table 'company' has columns: {'company_id': ['INTEGER'], 'company_location': ['VARCHAR']}",
expected_output = """
CREATE TABLE user (
user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
)
assert sorted(output.split("\n")) == sorted(expected_output)
SELECT * FROM 'user' LIMIT 3
user_id user_name
CREATE TABLE company (
company_id INTEGER NOT NULL,
company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 3
company_id company_location
"""
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
def test_table_info_w_sample_rows() -> None:
@ -51,12 +67,31 @@ def test_table_info_w_sample_rows() -> None:
db = SQLDatabase(engine, sample_rows_in_table_info=2)
output = db.table_info
output = output[len(_TEMPLATE_PREFIX) :]
expected_output = (
"Table 'user' has columns: {'user_id': ['INTEGER', ['13', '14']], 'user_name': ['VARCHAR(16)', ['Harrison', 'Chase']]}",
"Table 'company' has columns: {'company_id': ['INTEGER', []], 'company_location': ['VARCHAR', []]}",
expected_output = """
CREATE TABLE company (
company_id INTEGER NOT NULL,
company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 2
company_id company_location
CREATE TABLE user (
user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
)
assert sorted(output.split("\n")) == sorted(expected_output)
SELECT * FROM 'user' LIMIT 2
user_id user_name
13 Harrison
14 Chase
"""
assert sorted(output.split()) == sorted(expected_output.split())
def test_sql_database_run() -> None:

View File

@ -1,3 +1,4 @@
# flake8: noqa
"""Test SQL database wrapper with schema support.
Using DuckDB as SQLite does not support schemas.
@ -16,7 +17,7 @@ from sqlalchemy import (
schema,
)
from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase
from langchain.sql_database import SQLDatabase
metadata_obj = MetaData()
@ -46,11 +47,14 @@ def test_table_info() -> None:
metadata_obj.create_all(engine)
db = SQLDatabase(engine, schema="schema_a")
output = db.table_info
output = output[len(_TEMPLATE_PREFIX) :]
expected_output = (
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR']}"
)
assert output == expected_output
expected_output = """
CREATE TABLE schema_a."user"(user_id INTEGER, user_name VARCHAR NOT NULL, PRIMARY KEY(user_id));
SELECT * FROM 'user' LIMIT 3
user_id user_name
"""
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
def test_sql_database_run() -> None: