docs: added a few suggestions for sql docs (#16508)

pull/16520/head
Francisco Ingham 8 months ago committed by GitHub
parent 6004e9706f
commit 13cf4594f4

@ -20,6 +20,7 @@
"- It can answer questions based on the databases' schema as well as on the databases' content (like describing a specific table).\n",
"- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly.\n",
"- It can query the database as many times as needed to answer the user question.\n",
"- It will save tokens by only retrieving the schema from relevant tables.\n",
"\n",
"To initialize the agent we'll use the [create_sql_agent](https://api.python.langchain.com/en/latest/agent_toolkits/langchain_community.agent_toolkits.sql.base.create_sql_agent.html) constructor. This agent uses the `SQLDatabaseToolkit` which contains tools to: \n",
"\n",
@ -35,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
@ -51,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -81,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -98,7 +99,7 @@
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
]
},
"execution_count": 1,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -121,9 +122,16 @@
"We'll use an OpenAI chat model and an `\"openai-tools\"` agent, which will use OpenAI's function-calling API to drive the agent's tool selection and invocations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the agent will first choose which tables are relevant and then add the schema for those tables and a few sample rows to the prompt."
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -136,7 +144,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 45,
"metadata": {},
"outputs": [
{
@ -203,7 +211,7 @@
"2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
"3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
"*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC`\n",
"Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10;`\n",
"responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n",
"\n",
"Here is the SQL query:\n",
@ -214,11 +222,12 @@
"JOIN Customer c ON i.CustomerId = c.CustomerId\n",
"GROUP BY c.Country\n",
"ORDER BY TotalSales DESC\n",
"LIMIT 10;\n",
"```\n",
"\n",
"Now, I will execute this query to get the results.\n",
"Now, I will execute this query to get the total sales per country.\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n",
"\n",
"1. USA: $523.06\n",
"2. Canada: $303.96\n",
@ -231,7 +240,7 @@
"9. India: $75.26\n",
"10. Chile: $46.62\n",
"\n",
"The country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n",
"To answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -240,10 +249,10 @@
"data": {
"text/plain": [
"{'input': \"List the total sales per country. Which country's customers spent the most?\",\n",
" 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nThe country whose customers spent the most is the USA, with a total sales of $523.06.'}"
" 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nTo answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.'}"
]
},
"execution_count": 3,
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
@ -256,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 46,
"metadata": {},
"outputs": [
{
@ -326,7 +335,7 @@
" 'output': 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nThe `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help with.'}"
]
},
"execution_count": 4,
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
@ -341,14 +350,14 @@
"source": [
"## Using a dynamic few-shot prompt\n",
"\n",
"To optimize agent performance, we can provide a custom prompt with domain-specific knowledge. In this case we'll create a few shot prompt with an example selector, that will dynamically build the few shot prompt based on the user input.\n",
"To optimize agent performance, we can provide a custom prompt with domain-specific knowledge. In this case we'll create a few shot prompt with an example selector, that will dynamically build the few shot prompt based on the user input. This will help the model make better queries by inserting relevant queries in the prompt that the model can use as reference.\n",
"\n",
"First we need some user input <> SQL query examples:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@ -406,7 +415,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@ -628,20 +637,20 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['For Those About To Rock We Salute You',\n",
" 'Balls to the Wall',\n",
" 'Restless and Wild',\n",
" 'Let There Be Rock',\n",
" 'Big Ones']"
"['Os Cães Ladram Mas A Caravana Não Pára',\n",
" 'War',\n",
" 'Mais Do Mesmo',\n",
" \"Up An' Atom\",\n",
" 'Riot Act']"
]
},
"execution_count": 11,
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
@ -655,7 +664,7 @@
" res = db.run(query)\n",
" res = [el for sub in ast.literal_eval(res) for el in sub if el]\n",
" res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n",
" return res\n",
" return list(set(res))\n",
"\n",
"\n",
"artists = query_as_list(db, \"SELECT Name FROM Artist\")\n",
@ -672,7 +681,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
@ -691,7 +700,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
@ -706,7 +715,7 @@
"\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
"\n",
"If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the \"search_proper_nouns\" tool!\n",
"If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the \"search_proper_nouns\" tool! \n",
"\n",
"You have access to the following tables: {table_names}\n",
"\n",
@ -725,7 +734,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 52,
"metadata": {},
"outputs": [
{
@ -736,19 +745,19 @@
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `search_proper_nouns` with `{'query': 'alice in chains'}`\n",
"Invoking: `search_proper_nouns` with `{'query': 'alis in chain'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mAlice In Chains\n",
"\n",
"Metallica\n",
"Aisha Duo\n",
"\n",
"Pearl Jam\n",
"Xis\n",
"\n",
"Pearl Jam\n",
"Da Lama Ao Caos\n",
"\n",
"Smashing Pumpkins\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `{'query': \"SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')\"}`\n",
"A-Sides\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[(1,)]\u001b[0m\u001b[32;1m\u001b[1;3mAlice In Chains has 1 album.\u001b[0m\n",
@ -759,17 +768,17 @@
{
"data": {
"text/plain": [
"{'input': 'How many albums does alice in chains have?',\n",
"{'input': 'How many albums does alis in chain have?',\n",
" 'output': 'Alice In Chains has 1 album.'}"
]
},
"execution_count": 14,
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.invoke({\"input\": \"How many albums does alice in chains have?\"})"
"agent.invoke({\"input\": \"How many albums does alis in chain have?\"})"
]
},
{
@ -793,9 +802,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv",
"display_name": "pampa-labs",
"language": "python",
"name": "poetry-venv"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -807,7 +816,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.12"
}
},
"nbformat": 4,

@ -195,7 +195,7 @@
"* Has definitions for all the available tables.\n",
"* Has three examples rows for each table.\n",
"\n",
"This technique is inspired by papers like [this](https://arxiv.org/pdf/2204.00498.pdf), which suggest showing examples rows and being explicit about tables improves performance. We can also in"
"This technique is inspired by papers like [this](https://arxiv.org/pdf/2204.00498.pdf), which suggest showing examples rows and being explicit about tables improves performance. We can also inspect the full prompt like so:"
]
},
{
@ -343,6 +343,7 @@
"- It can answer questions based on the databases' schema as well as on the databases' content (like describing a specific table).\n",
"- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly.\n",
"- It can answer questions that require multiple dependent queries.\n",
"- It will save tokens by only considering the schema from relevant tables.\n",
"\n",
"To initialize the agent, we use `create_sql_agent` function. This agent contains the `SQLDatabaseToolkit` which contains tools to: \n",
"\n",

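Note on the `query_as_list` hunk: the change from `return res` to `return list(set(res))` deduplicates the looked-up names. A minimal standalone sketch of that behavior follows, assuming `db.run` returns a stringified list of row tuples as the notebook's helper implies. `FakeDB` and its sample rows are hypothetical stand-ins, not part of the commit. Because `set()` does not preserve order, this may also be why the displayed sample of album names differs between the old and new outputs, though that is an inference rather than something the diff states.

```python
import ast
import re


class FakeDB:
    """Hypothetical stand-in for the notebook's `db` object."""

    def __init__(self, rows):
        self._rows = rows

    def run(self, query):
        # Assumption: results come back as the string form of a list of tuples.
        return str(self._rows)


def query_as_list(db, query):
    res = db.run(query)
    # Flatten the row tuples and drop empty values such as None.
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    # Remove standalone numbers and surrounding whitespace.
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    # Deduplicate; the resulting order is not guaranteed.
    return list(set(res))


db = FakeDB([("AC/DC",), ("Accept",), ("AC/DC",), (None,), ("Blink 182",)])
artists = query_as_list(db, "SELECT Name FROM Artist")
print(sorted(artists))
```

Sorting the result, as in the last line, is one way to get a stable listing if the order matters downstream.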