"[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/langchain-ai/langchain/blob/master/docs/extras/use_cases/qa_structured/sql.ipynb)\n",
"The below example will use a SQLite connection with Chinook database. \n",
" \n",
"Follow [installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) to the directory as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
"Now, `Chinhook.db` is in our directory.\n",
"\n",
"Let's create a `SQLDatabaseChain` to create and execute SQL queries."
"TEMPLATE = \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
"Use the following format:\n",
"\n",
"Question: \"Question here\"\n",
"SQLQuery: \"SQL Query to run\"\n",
"SQLResult: \"Result of the SQLQuery\"\n",
"Answer: \"Final answer here\"\n",
"\n",
"Only use the following tables:\n",
"\n",
"{table_info}.\n",
"\n",
"Some examples of SQL queries that corrsespond to questions are:\n",
"Answer:\u001b[32;1m\u001b[1;3mThere are 8 employees.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'There are 8 employees.'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_chain.run(\"How many employees are there?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, we get the same result as the previous case.\n",
"\n",
"Here, the chain **also handles the query execution** and provides a final answer based on the user question and the query result.\n",
"\n",
"**Be careful** while using this approach as it is susceptible to `SQL Injection`:\n",
"\n",
"* The chain is executing queries that are created by an LLM, and weren't validated\n",
"* e.g. records may be created, modified or deleted unintentionally_\n",
"\n",
"This is why we see the `SQLDatabaseChain` is inside `langchain_experimental`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Go deeper\n",
"\n",
"**Looking under the hood**\n",
"\n",
"We can use the [LangSmith trace](https://smith.langchain.com/public/7f202a0c-1e35-42d6-a84a-6c2a58f697ef/r) to see what is happening under the hood:\n",
"\n",
"* As discussed above, first we create the query:\n",
"\n",
"```\n",
"text: ' SELECT COUNT(*) FROM \"Employee\";'\n",
"```\n",
"\n",
"* Then, it executes the query and passes the results to an LLM for synthesis.\n",
"- [Using Query Checker](/docs/integrations/tools/sqlite#use-query-checker) self-correct invalid SQL using parameter `use_query_checker=True`\n",
"- [Customizing the LLM Prompt](/docs/integrations/tools/sqlite#customize-prompt) include specific instructions or relevant information, using parameter `prompt=CUSTOM_PROMPT`\n",
"- [Get intermediate steps](/docs/integrations/tools/sqlite#return-intermediate-steps) access the SQL statement as well as the final result using parameter `return_intermediate_steps=True`\n",
"- [Limit the number of rows](/docs/integrations/tools/sqlite#choosing-how-to-limit-the-number-of-rows-returned) a query will return using parameter `top_k=5`\n",
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
"Thought:\u001b[32;1m\u001b[1;3m I should query the total sales per country.\n",
"Action: sql_db_query\n",
"Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10\u001b[0m\n",
"Thought: I should query the schema of the Invoice and Customer tables.\n",
"Action: sql_db_schema\n",
"Action Input: Invoice, Customer\n",
"```\n",
"\n",
"* It then formulates the query using the schema from tool `sql_db_schema`\n",
"\n",
"```\n",
"Thought: I should query the total sales per country.\n",
"Action: sql_db_query\n",
"Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10\n",
"```\n",
"\n",
"* It finally executes the generated query using tool `sql_db_query`\n",
"\n",
"![sql_usecase.png](/img/SQLDatabaseToolkit.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Agent task example #2 - Describing a Table"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to the Playlist table.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to the Playlist table.'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"Describe the playlisttrack table\")"
"Although the out-of-the-box SQL Toolkit contains the necessary tools to start working on a database, it is often the case that some extra tools may be useful for extending the agent's capabilities. This is particularly useful when trying to use **domain specific knowledge** in the solution, in order to improve its overall performance.\n",
"\n",
"Some examples include:\n",
"\n",
"- Including dynamic few shot examples\n",
"- Finding misspellings in proper nouns to use as column filters\n",
"\n",
"We can create separate tools which tackle these specific use cases and include them as a complement to the standard SQL Toolkit. Let's see how to include these two custom tools.\n",
"\n",
"#### Including dynamic few-shot examples\n",
"\n",
"In order to include dynamic few-shot examples, we need a custom **Retriever Tool** that handles the vector database in order to retrieve the examples that are semantically similar to the user’s question.\n",
"\n",
"Let's start by creating a dictionary with some examples: "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# few_shots = {'List all artists.': 'SELECT * FROM artists;',\n",
"# \"Find all albums for the artist 'AC/DC'.\": \"SELECT * FROM albums WHERE ArtistId = (SELECT ArtistId FROM artists WHERE Name = 'AC/DC');\",\n",
"# \"List all tracks in the 'Rock' genre.\": \"SELECT * FROM tracks WHERE GenreId = (SELECT GenreId FROM genres WHERE Name = 'Rock');\",\n",
"# 'Find the total duration of all tracks.': 'SELECT SUM(Milliseconds) FROM tracks;',\n",
"# 'List all customers from Canada.': \"SELECT * FROM customers WHERE Country = 'Canada';\",\n",
"# 'How many tracks are there in the album with ID 5?': 'SELECT COUNT(*) FROM tracks WHERE AlbumId = 5;',\n",
"# 'Find the total number of invoices.': 'SELECT COUNT(*) FROM invoices;',\n",
"# 'List all tracks that are longer than 5 minutes.': 'SELECT * FROM tracks WHERE Milliseconds > 300000;',\n",
"# 'Who are the top 5 customers by total purchase?': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;',\n",
"# 'Which albums are from the year 2000?': \"SELECT * FROM albums WHERE strftime('%Y', ReleaseDate) = '2000';\",\n",
"# 'How many employees are there': 'SELECT COUNT(*) FROM \"employee\"'\n",
"# }"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can then create a retriever using the list of questions, assigning the target SQL query as metadata:"
"This tool will help you understand similar examples to adapt them to the user question.\n",
"Input to this tool should be the user question.\n",
"\"\"\"\n",
"\n",
"retriever_tool = create_retriever_tool(\n",
" retriever,\n",
" name='sql_get_similar_examples',\n",
" description=tool_description\n",
" )\n",
"custom_tool_list = [retriever_tool]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create the agent, adjusting the standard SQL Agent suffix to consider our use case. Although the most straightforward way to handle this would be to include it just in the tool description, this is often not enough and we need to specify it in the agent prompt using the `suffix` argument in the constructor."
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_get_similar_examples` with `How many employees do we have?`\n",
"\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m[Document(page_content='How many employees are there', metadata={'sql_query': 'SELECT COUNT(*) FROM \"employee\"'}), Document(page_content='Find the total number of invoices.', metadata={'sql_query': 'SELECT COUNT(*) FROM invoices;'})]\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query_checker` with `SELECT COUNT(*) FROM employee`\n",
"responded: {content}\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mSELECT COUNT(*) FROM employee\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `SELECT COUNT(*) FROM employee`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[(8,)]\u001b[0m\u001b[32;1m\u001b[1;3mWe have 8 employees.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'We have 8 employees.'"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"How many employees do we have?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the agent first used the `sql_get_similar_examples` tool in order to retrieve similar examples. As the question was very similar to other few shot examples, the agent **didn't need to use any other tool** from the standard Toolkit, thus **saving time and tokens**."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Finding and correcting misspellings for proper nouns\n",
"\n",
"In order to filter columns that contain proper nouns such as addresses, song names or artists, we first need to double-check the spelling in order to filter the data correctly. \n",
"\n",
"We can achieve this by creating a vector store using all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.\n",
"\n",
"Let's follow a similar approach to the few shots, but without metadata: just embedding the proper nouns and then querying to get the most similar one to the misspelled user question.\n",
"\n",
"First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"import ast\n",
"import re\n",
"\n",
"def run_query_save_results(db, query):\n",
" res = db.run(query)\n",
" res = [el for sub in ast.literal_eval(res) for el in sub if el]\n",
" res = [re.sub(r'\\b\\d+\\b', '', string).strip() for string in res]\n",
" return res\n",
"\n",
"artists = run_query_save_results(db, \"SELECT Name FROM Artist\")\n",
"albums = run_query_save_results(db, \"SELECT Title FROM Album\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can proceed with creating the custom **retreiver tool** and the final agent:"
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `name_search` with `alis in pains`\n",
"\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m[Document(page_content='House of Pain', metadata={}), Document(page_content='Alice In Chains', metadata={}), Document(page_content='Aisha Duo', metadata={}), Document(page_content='House Of Pain', metadata={})]\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query_checker` with `SELECT COUNT(*) FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alice In Chains'`\n",
"responded: {content}\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mSELECT COUNT(*) FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alice In Chains'\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `SELECT COUNT(*) FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alice In Chains'`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[(1,)]\u001b[0m\u001b[32;1m\u001b[1;3mAlice In Chains has 1 album in the database.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Alice In Chains has 1 album in the database.'"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"How many albums does alis in pains have?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the agent used the `name_search` tool in order to check how to correctly query the database for this specific artist."
"PROMPT_TEMPLATE = \"\"\"Given an input question, create a syntactically correct Elasticsearch query to run. Unless the user specifies in their question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n",
"\n",
"Unless told to do not query for all the columns from a specific index, only ask for a the few relevant columns given the question.\n",
"\n",
"Pay attention to use only the column names that you can see in the mapping description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which index. Return the query as valid json.\n",
"\n",
"Use the following format:\n",
"\n",
"Question: Question here\n",
"ESQuery: Elasticsearch Query formatted as json\n",