From 45b5640fe5f3ac929338492aa7bbe5a24b4aa352 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 18 Feb 2023 11:49:08 -0800 Subject: [PATCH] fix sql (#1141) --- docs/modules/chains/examples/sqlite.ipynb | 117 +++++++++++----------- langchain/sql_database.py | 15 ++- 2 files changed, 68 insertions(+), 64 deletions(-) diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index 909b7027..b4b621ae 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "d0e27d88", "metadata": { "pycharm": { @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "72ede462", "metadata": { "pycharm": { @@ -92,7 +92,22 @@ "\n", "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", "How many employees are there? \n", - "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n", + "SQLQuery:" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/harrisonchase/workplace/langchain/langchain/sql_database.py:120: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n", + " sample_rows = connection.execute(command)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m There are 8 employees.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -285,16 +300,16 @@ "\n", "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", "What are some example tracks by composer Johann Sebastian Bach? \n", - "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3;\u001b[0m\n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer LIKE '%Johann Sebastian Bach%' LIMIT 3;\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach')]\u001b[0m\n", - "Answer:\u001b[32;1m\u001b[1;3m Examples of tracks by composer Johann Sebastian Bach are 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'.\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m Some example tracks by composer Johann Sebastian Bach are 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "' Examples of tracks by composer Johann Sebastian Bach are \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', and \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\'.'" + "' Some example tracks by composer Johann Sebastian Bach are \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', and \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\'.'" ] }, "execution_count": 11, @@ -317,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "9a22ee47", "metadata": {}, "outputs": [], @@ -338,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "9de86267", "metadata": {}, "outputs": [ @@ -346,43 +361,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "CREATE TABLE [Album]\n", - "(\n", - " [AlbumId] INTEGER NOT NULL,\n", - " [Title] NVARCHAR(160) NOT NULL,\n", - " [ArtistId] INTEGER NOT NULL,\n", - " CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]),\n", - " FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) \n", - "\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n", - ")\n", - "\n", - "SELECT * FROM 'Album' LIMIT 2\n", - "AlbumId Title ArtistId\n", - "1 For Those About To Rock We Salute You 1\n", - "2 Balls to the Wall 2\n", "\n", - "\n", - "CREATE TABLE [Track]\n", - "(\n", - " [TrackId] INTEGER NOT NULL,\n", - " [Name] NVARCHAR(200) NOT NULL,\n", - " [AlbumId] INTEGER,\n", - " [MediaTypeId] INTEGER NOT NULL,\n", - " [GenreId] INTEGER,\n", - " [Composer] NVARCHAR(220),\n", - " [Milliseconds] INTEGER NOT NULL,\n", - " [Bytes] INTEGER,\n", - " [UnitPrice] NUMERIC(10,2) NOT NULL,\n", - " CONSTRAINT [PK_Track] PRIMARY KEY ([TrackId]),\n", - " FOREIGN KEY ([AlbumId]) REFERENCES [Album] ([AlbumId]) \n", - "\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n", - " FOREIGN KEY ([GenreId]) REFERENCES [Genre] ([GenreId]) \n", - "\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n", - " FOREIGN KEY ([MediaTypeId]) REFERENCES [MediaType] ([MediaTypeId]) \n", - "\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n", + "CREATE TABLE \"Track\" (\n", + "\t\"TrackId\" INTEGER NOT NULL, \n", + "\t\"Name\" NVARCHAR(200) NOT NULL, \n", + "\t\"AlbumId\" INTEGER, \n", + "\t\"MediaTypeId\" INTEGER NOT NULL, \n", + "\t\"GenreId\" INTEGER, \n", + "\t\"Composer\" NVARCHAR(220), \n", + "\t\"Milliseconds\" INTEGER NOT NULL, \n", + "\t\"Bytes\" INTEGER, \n", + "\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n", + "\tPRIMARY KEY (\"TrackId\"), \n", + "\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \n", + "\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n", + "\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n", ")\n", "\n", - "SELECT * FROM 'Track' LIMIT 2\n", + "SELECT * FROM 'Track' LIMIT 2;\n", "TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice\n", "1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n", "2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n" @@ -395,7 +391,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "bcb7a489", "metadata": {}, "outputs": [], @@ -405,7 +401,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "id": "81e05d82", "metadata": {}, "outputs": [ @@ -429,7 +425,7 @@ "' Some example tracks by Bach are \\'American Woman\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\', and \\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\\'.'" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -457,17 +453,18 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 20, "id": "e59a4740", "metadata": {}, "outputs": [], "source": [ - "from langchain.chains import SQLDatabaseSequentialChain" + "from langchain.chains import SQLDatabaseSequentialChain\n", + "db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 21, "id": "58bb49b6", "metadata": {}, "outputs": [], @@ -477,7 +474,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 22, "id": "95017b1a", "metadata": {}, "outputs": [ @@ -493,9 +490,9 @@ "\n", "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", "How many employees are also customers? \n", - "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Customer c INNER JOIN Employee e ON c.SupportRepId = e.EmployeeId;\u001b[0m\n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee INNER JOIN Customer ON Employee.EmployeeId = Customer.SupportRepId;\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[(59,)]\u001b[0m\n", - "Answer:\u001b[32;1m\u001b[1;3m There are 59 employees who are also customers.\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m 59 employees are also customers.\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -504,10 +501,10 @@ { "data": { "text/plain": [ - "' There are 59 employees who are also customers.'" + "' 59 employees are also customers.'" ] }, - "execution_count": 5, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -515,6 +512,14 @@ "source": [ "chain.run(\"How many employees are also customers?\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5eb39db6", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -523,9 +528,9 @@ "lastKernelId": null }, "kernelspec": { - "display_name": "langchain", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "langchain" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -537,7 +542,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/langchain/sql_database.py b/langchain/sql_database.py index aedde352..2abca30c 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -114,15 +114,14 @@ class SQLDatabase: # save the columns in string format columns_str = " ".join([col.name for col in table.columns]) - # get the sample rows - with self._engine.connect() as connection: - sample_rows = connection.execute(command) - try: - # shorten values in the smaple rows - sample_rows = list( - map(lambda ls: [str(i)[:100] for i in ls], sample_rows) - ) + # get the sample rows + with self._engine.connect() as connection: + sample_rows = connection.execute(command) + # shorten values in the sample rows + sample_rows = list( + map(lambda ls: [str(i)[:100] for i in ls], sample_rows) + ) # save the sample rows in string format sample_rows_str = "\n".join([" ".join(row) for row in sample_rows])