Harrison/align table (#1081)

Co-authored-by: Francisco Ingham <fpingham@gmail.com>
searx-api
Harrison Chase 1 year ago committed by GitHub
parent c60954d0f8
commit 5e10e19bfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,7 +29,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 3,
"id": "d0e27d88", "id": "d0e27d88",
"metadata": { "metadata": {
"pycharm": { "pycharm": {
@ -43,7 +43,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "72ede462", "id": "72ede462",
"metadata": { "metadata": {
"pycharm": { "pycharm": {
@ -346,15 +346,46 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"CREATE TABLE [Album]\n",
"(\n",
" [AlbumId] INTEGER NOT NULL,\n",
" [Title] NVARCHAR(160) NOT NULL,\n",
" [ArtistId] INTEGER NOT NULL,\n",
" CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]),\n",
" FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n",
")\n",
"\n", "\n",
" Table data will be described in the following format:\n", "SELECT * FROM 'Album' LIMIT 2\n",
"AlbumId Title ArtistId\n",
"1 For Those About To Rock We Salute You 1\n",
"2 Balls to the Wall 2\n",
"\n", "\n",
" Table 'table name' has columns: {column1 name: (column1 type, [list of example values for column1]),\n",
" column2 name: (column2 type, [list of example values for column2], ...)\n",
"\n", "\n",
" These are the tables you can use, together with their column information:\n", "CREATE TABLE [Track]\n",
"(\n",
" [TrackId] INTEGER NOT NULL,\n",
" [Name] NVARCHAR(200) NOT NULL,\n",
" [AlbumId] INTEGER,\n",
" [MediaTypeId] INTEGER NOT NULL,\n",
" [GenreId] INTEGER,\n",
" [Composer] NVARCHAR(220),\n",
" [Milliseconds] INTEGER NOT NULL,\n",
" [Bytes] INTEGER,\n",
" [UnitPrice] NUMERIC(10,2) NOT NULL,\n",
" CONSTRAINT [PK_Track] PRIMARY KEY ([TrackId]),\n",
" FOREIGN KEY ([AlbumId]) REFERENCES [Album] ([AlbumId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n",
" FOREIGN KEY ([GenreId]) REFERENCES [Genre] ([GenreId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n",
" FOREIGN KEY ([MediaTypeId]) REFERENCES [MediaType] ([MediaTypeId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n",
")\n",
"\n", "\n",
" Table 'Track' has columns: {'TrackId': ['INTEGER', ['1', '2']], 'Name': ['NVARCHAR(200)', ['For Those About To Rock (We Salute You)', 'Balls to the Wall']], 'AlbumId': ['INTEGER', ['1', '2']], 'MediaTypeId': ['INTEGER', ['1', '2']], 'GenreId': ['INTEGER', ['1', '1']], 'Composer': ['NVARCHAR(220)', ['Angus Young, Malcolm Young, Brian Johnson', 'None']], 'Milliseconds': ['INTEGER', ['343719', '342562']], 'Bytes': ['INTEGER', ['11170334', '5510424']], 'UnitPrice': ['NUMERIC(10, 2)', ['0.99', '0.99']]}\n" "SELECT * FROM 'Track' LIMIT 2\n",
"TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice\n",
"1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n",
"2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n"
] ]
} }
], ],
@ -492,9 +523,9 @@
"lastKernelId": null "lastKernelId": null
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "langchain",
"language": "python", "language": "python",
"name": "python3" "name": "langchain"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -506,7 +537,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.1" "version": "3.8.16"
} }
}, },
"nbformat": 4, "nbformat": 4,

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import ast import ast
from collections import defaultdict
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
from sqlalchemy import create_engine, inspect from sqlalchemy import create_engine, inspect
@ -30,7 +29,7 @@ class SQLDatabase:
schema: Optional[str] = None, schema: Optional[str] = None,
ignore_tables: Optional[List[str]] = None, ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 0, sample_rows_in_table_info: int = 3,
): ):
"""Create engine from database URI.""" """Create engine from database URI."""
self._engine = engine self._engine = engine
@ -80,9 +79,12 @@ class SQLDatabase:
def get_table_info(self, table_names: Optional[List[str]] = None) -> str: def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables. """Get information about specified tables.
Follows best practices as specified in: Rajkumar et al, 2022
(https://arxiv.org/abs/2204.00498)
If `sample_rows_in_table_info`, the specified number of sample rows will be If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as appended to each table description. This can increase performance as
demonstrated by Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498). demonstrated in the paper.
""" """
all_table_names = self.get_table_names() all_table_names = self.get_table_names()
if table_names is not None: if table_names is not None:
@ -93,33 +95,51 @@ class SQLDatabase:
tables = [] tables = []
for table_name in all_table_names: for table_name in all_table_names:
columns = defaultdict(list) columns = []
create_table = self.run(
(
"SELECT sql FROM sqlite_master WHERE "
f"type='table' AND name='{table_name}'"
),
fetch="one",
)
for column in self._inspector.get_columns(table_name, schema=self._schema): for column in self._inspector.get_columns(table_name, schema=self._schema):
columns[f"{column['name']}"].append(str(column["type"])) columns.append(column["name"])
if self._sample_rows_in_table_info: if self._sample_rows_in_table_info:
sample_rows = self.run( select_star = (
f"SELECT * FROM '{table_name}' LIMIT " f"SELECT * FROM '{table_name}' LIMIT "
f"{self._sample_rows_in_table_info}" f"{self._sample_rows_in_table_info}"
) )
sample_rows = self.run(select_star)
sample_rows_ls = ast.literal_eval(sample_rows) sample_rows_ls = ast.literal_eval(sample_rows)
sample_rows_ls = list( sample_rows_ls = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls) map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
) )
for e, col in enumerate(columns): columns_str = " ".join(columns)
columns[col].append( sample_rows_str = "\n".join([" ".join(row) for row in sample_rows_ls])
[row[e] for row in sample_rows_ls] # type: ignore
) tables.append(
create_table
+ "\n\n"
+ select_star
+ "\n"
+ columns_str
+ "\n"
+ sample_rows_str
)
table_str = f"Table '{table_name}' has columns: " + str(dict(columns)) else:
tables.append(table_str) tables.append(create_table)
final_str = _TEMPLATE_PREFIX + "\n".join(tables) final_str = "\n\n\n".join(tables)
return final_str return final_str
def run(self, command: str) -> str: def run(self, command: str, fetch: str = "all") -> str:
"""Execute a SQL command and return a string representing the results. """Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned. If the statement returns rows, a string of the results is returned.
@ -130,6 +150,11 @@ class SQLDatabase:
connection.exec_driver_sql(f"SET search_path TO {self._schema}") connection.exec_driver_sql(f"SET search_path TO {self._schema}")
cursor = connection.exec_driver_sql(command) cursor = connection.exec_driver_sql(command)
if cursor.returns_rows: if cursor.returns_rows:
result = cursor.fetchall() if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0]
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
return str(result) return str(result)
return "" return ""

@ -28,12 +28,28 @@ def test_table_info() -> None:
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
db = SQLDatabase(engine) db = SQLDatabase(engine)
output = db.table_info output = db.table_info
output = output[len(_TEMPLATE_PREFIX) :] expected_output = """
expected_output = ( CREATE TABLE user (
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR(16)']}", user_id INTEGER NOT NULL,
"Table 'company' has columns: {'company_id': ['INTEGER'], 'company_location': ['VARCHAR']}", user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
) )
assert sorted(output.split("\n")) == sorted(expected_output)
SELECT * FROM 'user' LIMIT 3
user_id user_name
CREATE TABLE company (
company_id INTEGER NOT NULL,
company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 3
company_id company_location
"""
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
def test_table_info_w_sample_rows() -> None: def test_table_info_w_sample_rows() -> None:
@ -51,12 +67,31 @@ def test_table_info_w_sample_rows() -> None:
db = SQLDatabase(engine, sample_rows_in_table_info=2) db = SQLDatabase(engine, sample_rows_in_table_info=2)
output = db.table_info output = db.table_info
output = output[len(_TEMPLATE_PREFIX) :]
expected_output = ( expected_output = """
"Table 'user' has columns: {'user_id': ['INTEGER', ['13', '14']], 'user_name': ['VARCHAR(16)', ['Harrison', 'Chase']]}", CREATE TABLE company (
"Table 'company' has columns: {'company_id': ['INTEGER', []], 'company_location': ['VARCHAR', []]}", company_id INTEGER NOT NULL,
) company_location VARCHAR NOT NULL,
assert sorted(output.split("\n")) == sorted(expected_output) PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 2
company_id company_location
CREATE TABLE user (
user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 2
user_id user_name
13 Harrison
14 Chase
"""
assert sorted(output.split()) == sorted(expected_output.split())
def test_sql_database_run() -> None: def test_sql_database_run() -> None:

@ -1,3 +1,4 @@
# flake8: noqa
"""Test SQL database wrapper with schema support. """Test SQL database wrapper with schema support.
Using DuckDB as SQLite does not support schemas. Using DuckDB as SQLite does not support schemas.
@ -16,7 +17,7 @@ from sqlalchemy import (
schema, schema,
) )
from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase from langchain.sql_database import SQLDatabase
metadata_obj = MetaData() metadata_obj = MetaData()
@ -46,11 +47,14 @@ def test_table_info() -> None:
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
db = SQLDatabase(engine, schema="schema_a") db = SQLDatabase(engine, schema="schema_a")
output = db.table_info output = db.table_info
output = output[len(_TEMPLATE_PREFIX) :] expected_output = """
expected_output = ( CREATE TABLE schema_a."user"(user_id INTEGER, user_name VARCHAR NOT NULL, PRIMARY KEY(user_id));
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR']}"
) SELECT * FROM 'user' LIMIT 3
assert output == expected_output user_id user_name
"""
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
def test_sql_database_run() -> None: def test_sql_database_run() -> None:

Loading…
Cancel
Save