mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Harrison/align table (#1081)
Co-authored-by: Francisco Ingham <fpingham@gmail.com>
This commit is contained in:
parent
c60954d0f8
commit
5e10e19bfe
@ -29,7 +29,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 3,
|
||||
"id": "d0e27d88",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
@ -43,7 +43,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 1,
|
||||
"id": "72ede462",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
@ -346,15 +346,46 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CREATE TABLE [Album]\n",
|
||||
"(\n",
|
||||
" [AlbumId] INTEGER NOT NULL,\n",
|
||||
" [Title] NVARCHAR(160) NOT NULL,\n",
|
||||
" [ArtistId] INTEGER NOT NULL,\n",
|
||||
" CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]),\n",
|
||||
" FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) \n",
|
||||
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n",
|
||||
")\n",
|
||||
"\n",
|
||||
" Table data will be described in the following format:\n",
|
||||
"SELECT * FROM 'Album' LIMIT 2\n",
|
||||
"AlbumId Title ArtistId\n",
|
||||
"1 For Those About To Rock We Salute You 1\n",
|
||||
"2 Balls to the Wall 2\n",
|
||||
"\n",
|
||||
" Table 'table name' has columns: {column1 name: (column1 type, [list of example values for column1]),\n",
|
||||
" column2 name: (column2 type, [list of example values for column2], ...)\n",
|
||||
"\n",
|
||||
" These are the tables you can use, together with their column information:\n",
|
||||
"CREATE TABLE [Track]\n",
|
||||
"(\n",
|
||||
" [TrackId] INTEGER NOT NULL,\n",
|
||||
" [Name] NVARCHAR(200) NOT NULL,\n",
|
||||
" [AlbumId] INTEGER,\n",
|
||||
" [MediaTypeId] INTEGER NOT NULL,\n",
|
||||
" [GenreId] INTEGER,\n",
|
||||
" [Composer] NVARCHAR(220),\n",
|
||||
" [Milliseconds] INTEGER NOT NULL,\n",
|
||||
" [Bytes] INTEGER,\n",
|
||||
" [UnitPrice] NUMERIC(10,2) NOT NULL,\n",
|
||||
" CONSTRAINT [PK_Track] PRIMARY KEY ([TrackId]),\n",
|
||||
" FOREIGN KEY ([AlbumId]) REFERENCES [Album] ([AlbumId]) \n",
|
||||
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n",
|
||||
" FOREIGN KEY ([GenreId]) REFERENCES [Genre] ([GenreId]) \n",
|
||||
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n",
|
||||
" FOREIGN KEY ([MediaTypeId]) REFERENCES [MediaType] ([MediaTypeId]) \n",
|
||||
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n",
|
||||
")\n",
|
||||
"\n",
|
||||
" Table 'Track' has columns: {'TrackId': ['INTEGER', ['1', '2']], 'Name': ['NVARCHAR(200)', ['For Those About To Rock (We Salute You)', 'Balls to the Wall']], 'AlbumId': ['INTEGER', ['1', '2']], 'MediaTypeId': ['INTEGER', ['1', '2']], 'GenreId': ['INTEGER', ['1', '1']], 'Composer': ['NVARCHAR(220)', ['Angus Young, Malcolm Young, Brian Johnson', 'None']], 'Milliseconds': ['INTEGER', ['343719', '342562']], 'Bytes': ['INTEGER', ['11170334', '5510424']], 'UnitPrice': ['NUMERIC(10, 2)', ['0.99', '0.99']]}\n"
|
||||
"SELECT * FROM 'Track' LIMIT 2\n",
|
||||
"TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice\n",
|
||||
"1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n",
|
||||
"2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -492,9 +523,9 @@
|
||||
"lastKernelId": null
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "langchain",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
"name": "langchain"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
@ -506,7 +537,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
"version": "3.8.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -2,7 +2,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from sqlalchemy import create_engine, inspect
|
||||
@ -30,7 +29,7 @@ class SQLDatabase:
|
||||
schema: Optional[str] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 0,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
):
|
||||
"""Create engine from database URI."""
|
||||
self._engine = engine
|
||||
@ -80,9 +79,12 @@ class SQLDatabase:
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
"""Get information about specified tables.
|
||||
|
||||
Follows best practices as specified in: Rajkumar et al, 2022
|
||||
(https://arxiv.org/abs/2204.00498)
|
||||
|
||||
If `sample_rows_in_table_info`, the specified number of sample rows will be
|
||||
appended to each table description. This can increase performance as
|
||||
demonstrated by Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498).
|
||||
demonstrated in the paper.
|
||||
"""
|
||||
all_table_names = self.get_table_names()
|
||||
if table_names is not None:
|
||||
@ -93,33 +95,51 @@ class SQLDatabase:
|
||||
|
||||
tables = []
|
||||
for table_name in all_table_names:
|
||||
columns = defaultdict(list)
|
||||
columns = []
|
||||
create_table = self.run(
|
||||
(
|
||||
"SELECT sql FROM sqlite_master WHERE "
|
||||
f"type='table' AND name='{table_name}'"
|
||||
),
|
||||
fetch="one",
|
||||
)
|
||||
|
||||
for column in self._inspector.get_columns(table_name, schema=self._schema):
|
||||
columns[f"{column['name']}"].append(str(column["type"]))
|
||||
columns.append(column["name"])
|
||||
|
||||
if self._sample_rows_in_table_info:
|
||||
sample_rows = self.run(
|
||||
select_star = (
|
||||
f"SELECT * FROM '{table_name}' LIMIT "
|
||||
f"{self._sample_rows_in_table_info}"
|
||||
)
|
||||
|
||||
sample_rows = self.run(select_star)
|
||||
|
||||
sample_rows_ls = ast.literal_eval(sample_rows)
|
||||
sample_rows_ls = list(
|
||||
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
|
||||
)
|
||||
|
||||
for e, col in enumerate(columns):
|
||||
columns[col].append(
|
||||
[row[e] for row in sample_rows_ls] # type: ignore
|
||||
columns_str = " ".join(columns)
|
||||
sample_rows_str = "\n".join([" ".join(row) for row in sample_rows_ls])
|
||||
|
||||
tables.append(
|
||||
create_table
|
||||
+ "\n\n"
|
||||
+ select_star
|
||||
+ "\n"
|
||||
+ columns_str
|
||||
+ "\n"
|
||||
+ sample_rows_str
|
||||
)
|
||||
|
||||
table_str = f"Table '{table_name}' has columns: " + str(dict(columns))
|
||||
tables.append(table_str)
|
||||
else:
|
||||
tables.append(create_table)
|
||||
|
||||
final_str = _TEMPLATE_PREFIX + "\n".join(tables)
|
||||
final_str = "\n\n\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def run(self, command: str) -> str:
|
||||
def run(self, command: str, fetch: str = "all") -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
@ -130,6 +150,11 @@ class SQLDatabase:
|
||||
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
|
||||
cursor = connection.exec_driver_sql(command)
|
||||
if cursor.returns_rows:
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone()[0]
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
return str(result)
|
||||
return ""
|
||||
|
@ -28,12 +28,28 @@ def test_table_info() -> None:
|
||||
metadata_obj.create_all(engine)
|
||||
db = SQLDatabase(engine)
|
||||
output = db.table_info
|
||||
output = output[len(_TEMPLATE_PREFIX) :]
|
||||
expected_output = (
|
||||
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR(16)']}",
|
||||
"Table 'company' has columns: {'company_id': ['INTEGER'], 'company_location': ['VARCHAR']}",
|
||||
expected_output = """
|
||||
CREATE TABLE user (
|
||||
user_id INTEGER NOT NULL,
|
||||
user_name VARCHAR(16) NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
assert sorted(output.split("\n")) == sorted(expected_output)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 3
|
||||
user_id user_name
|
||||
|
||||
|
||||
CREATE TABLE company (
|
||||
company_id INTEGER NOT NULL,
|
||||
company_location VARCHAR NOT NULL,
|
||||
PRIMARY KEY (company_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'company' LIMIT 3
|
||||
company_id company_location
|
||||
"""
|
||||
|
||||
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
|
||||
|
||||
|
||||
def test_table_info_w_sample_rows() -> None:
|
||||
@ -51,12 +67,31 @@ def test_table_info_w_sample_rows() -> None:
|
||||
db = SQLDatabase(engine, sample_rows_in_table_info=2)
|
||||
|
||||
output = db.table_info
|
||||
output = output[len(_TEMPLATE_PREFIX) :]
|
||||
expected_output = (
|
||||
"Table 'user' has columns: {'user_id': ['INTEGER', ['13', '14']], 'user_name': ['VARCHAR(16)', ['Harrison', 'Chase']]}",
|
||||
"Table 'company' has columns: {'company_id': ['INTEGER', []], 'company_location': ['VARCHAR', []]}",
|
||||
|
||||
expected_output = """
|
||||
CREATE TABLE company (
|
||||
company_id INTEGER NOT NULL,
|
||||
company_location VARCHAR NOT NULL,
|
||||
PRIMARY KEY (company_id)
|
||||
)
|
||||
|
||||
SELECT * FROM 'company' LIMIT 2
|
||||
company_id company_location
|
||||
|
||||
|
||||
CREATE TABLE user (
|
||||
user_id INTEGER NOT NULL,
|
||||
user_name VARCHAR(16) NOT NULL,
|
||||
PRIMARY KEY (user_id)
|
||||
)
|
||||
assert sorted(output.split("\n")) == sorted(expected_output)
|
||||
|
||||
SELECT * FROM 'user' LIMIT 2
|
||||
user_id user_name
|
||||
13 Harrison
|
||||
14 Chase
|
||||
"""
|
||||
|
||||
assert sorted(output.split()) == sorted(expected_output.split())
|
||||
|
||||
|
||||
def test_sql_database_run() -> None:
|
||||
|
@ -1,3 +1,4 @@
|
||||
# flake8: noqa
|
||||
"""Test SQL database wrapper with schema support.
|
||||
|
||||
Using DuckDB as SQLite does not support schemas.
|
||||
@ -16,7 +17,7 @@ from sqlalchemy import (
|
||||
schema,
|
||||
)
|
||||
|
||||
from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
metadata_obj = MetaData()
|
||||
|
||||
@ -46,11 +47,14 @@ def test_table_info() -> None:
|
||||
metadata_obj.create_all(engine)
|
||||
db = SQLDatabase(engine, schema="schema_a")
|
||||
output = db.table_info
|
||||
output = output[len(_TEMPLATE_PREFIX) :]
|
||||
expected_output = (
|
||||
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR']}"
|
||||
)
|
||||
assert output == expected_output
|
||||
expected_output = """
|
||||
CREATE TABLE schema_a."user"(user_id INTEGER, user_name VARCHAR NOT NULL, PRIMARY KEY(user_id));
|
||||
|
||||
SELECT * FROM 'user' LIMIT 3
|
||||
user_id user_name
|
||||
"""
|
||||
|
||||
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))
|
||||
|
||||
|
||||
def test_sql_database_run() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user