langchain/docs/extras/use_cases/sql.ipynb
2023-08-08 03:30:18 -07:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SQL\n",
"\n",
"[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/langchain-ai/langchain/blob/master/docs/extras/use_cases/sql.ipynb)\n",
"\n",
"## Use case\n",
"\n",
"Enterprise data is often stored in SQL databases.\n",
"\n",
"LLMs make it possible to interact with SQL databases using natural langugae.\n",
"\n",
"LangChain offers SQL Chains and Agents to build and run SQL queries based on natural language prompts. \n",
"\n",
"These are compatible with any SQL dialect supported by SQLAlchemy (e.g., MySQL, PostgreSQL, Oracle SQL, Databricks, SQLite).\n",
"\n",
"They enable use cases such as:\n",
"\n",
"- Generating queries that will be run based on natural language questions\n",
"- Creating chatbots that can answer questions based on database data\n",
"- Building custom dashboards based on insights a user wants to analyze\n",
"\n",
"## Overview\n",
"\n",
"LangChain provides tools to interact with SQL Databases:\n",
"\n",
"1. `Build SQL queries` based on natural language user questions\n",
"2. `Query a SQL database` using chains for query creation and execution\n",
"3. `Interact with a SQL database` using agents for robust and flexible querying \n",
"\n",
"![sql_usecase.png](/img/sql_usecase.png)\n",
"\n",
"## Quickstart\n",
"\n",
"First, get required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"! pip install langchain langchain-experimental openai\n",
"\n",
"# Set env var OPENAI_API_KEY or load from a .env file\n",
"# import dotenv\n",
"\n",
"# dotenv.load_env()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The below example will use a SQLite connection with Chinook database. \n",
" \n",
"Follow [installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) to the directory as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
"Now, `Chinhook.db` is in our directory.\n",
"\n",
"Let's create a `SQLDatabaseChain` to create and execute SQL queries."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from langchain.utilities import SQLDatabase\n",
"from langchain.llms import OpenAI\n",
"from langchain_experimental.sql import SQLDatabaseChain\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"llm = OpenAI(temperature=0, verbose=True)\n",
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there?\n",
"SQLQuery:\u001b[32;1m\u001b[1;3mSELECT COUNT(*) FROM \"Employee\";\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3mThere are 8 employees.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'There are 8 employees.'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_chain.run(\"How many employees are there?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this both creates and executes the query. \n",
"\n",
"In the following sections, we will cover the 3 different use cases mentioned in the overview.\n",
"\n",
"### Go deeper\n",
"\n",
"You can load tabular data from other sources other than SQL Databases.\n",
"For example:\n",
"- [Loading a CSV file](/docs/integrations/document_loaders/csv)\n",
"- [Loading a Pandas DataFrame](/docs/integrations/document_loaders/pandas_dataframe)\n",
"Here you can [check full list of Document Loaders](/docs/integrations/document_loaders/)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Case 1: Text-to-SQL query\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chains import create_sql_query_chain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create the chain that will build the SQL Query:\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT COUNT(*) FROM Employee\n"
]
}
],
"source": [
"chain = create_sql_query_chain(ChatOpenAI(temperature=0), db)\n",
"response = chain.invoke({\"question\":\"How many employees are there\"})\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After building the SQL query based on a user question, we can execute the query:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[(8,)]'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the SQL Query Builder chain **only created** the query, and we handled the **query execution separately**."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Go deeper\n",
"\n",
"**Looking under the hood**\n",
"\n",
"We can look at the [LangSmith trace](https://smith.langchain.com/public/c8fa52ea-be46-4829-bde2-52894970b830/r) to unpack this:\n",
"\n",
"[Some papers](https://arxiv.org/pdf/2204.00498.pdf) have reported good performance when prompting with:\n",
" \n",
"* A `CREATE TABLE` description for each table, which include column names, their types, etc\n",
"* Followed by three example rows in a `SELECT` statement\n",
"\n",
"`create_sql_query_chain` adopts this the best practice (see more in this [blog](https://blog.langchain.dev/llms-and-sql/)). \n",
"![sql_usecase.png](/img/create_sql_query_chain.png)\n",
"\n",
"**Improvements**\n",
"\n",
"The query builder can be improved in several ways, such as (but not limited to):\n",
"\n",
"- Customizing database description to your specific use case\n",
"- Hardcoding a few examples of questions and their corresponding SQL query in the prompt\n",
"- Using a vector database to include dynamic examples that are relevant to the specific user question\n",
"\n",
"All these examples involve customizing the chain's prompt. \n",
"\n",
"For example, we can include a few examples in our prompt like so:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"\n",
"TEMPLATE = \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
"Use the following format:\n",
"\n",
"Question: \"Question here\"\n",
"SQLQuery: \"SQL Query to run\"\n",
"SQLResult: \"Result of the SQLQuery\"\n",
"Answer: \"Final answer here\"\n",
"\n",
"Only use the following tables:\n",
"\n",
"{table_info}.\n",
"\n",
"Some examples of SQL queries that corrsespond to questions are:\n",
"\n",
"{few_shot_examples}\n",
"\n",
"Question: {input}\"\"\"\n",
"\n",
"CUSTOM_PROMPT = PromptTemplate(\n",
" input_variables=[\"input\", \"few_shot_examples\", \"table_info\", \"dialect\"], template=TEMPLATE\n",
")"
]
},
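{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `{few_shot_examples}` placeholder above still has to be filled with text. As a minimal sketch, assuming the examples are kept as plain (question, query) pairs, they could be rendered in the same `Question:` / `SQLQuery:` layout the template uses. The pairs and the helper name `format_few_shot_examples` below are hypothetical and not part of LangChain:\n",
"\n",
"```python\n",
"# Hypothetical question/query pairs written against the Chinook schema.\n",
"EXAMPLES = [\n",
"    ('How many employees are there?', 'SELECT COUNT(*) FROM \"Employee\";'),\n",
"    ('How many tracks are there?', 'SELECT COUNT(*) FROM \"Track\";'),\n",
"]\n",
"\n",
"\n",
"def format_few_shot_examples(examples):\n",
"    # Render each pair in the Question/SQLQuery layout used by TEMPLATE.\n",
"    blocks = [\n",
"        'Question: ' + question + '\\nSQLQuery: ' + query\n",
"        for question, query in examples\n",
"    ]\n",
"    # Separate the examples with a blank line.\n",
"    return '\\n\\n'.join(blocks)\n",
"\n",
"\n",
"few_shot_examples = format_few_shot_examples(EXAMPLES)\n",
"print(few_shot_examples)\n",
"```\n",
"\n",
"The resulting string can then be supplied as the `few_shot_examples` variable when the prompt is formatted."
]
},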
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Case 2: Text-to-SQL query and execution\n",
"\n",
"We can use `SQLDatabaseChain` from `langchain_experimental` to create and run SQL queries."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain_experimental.sql import SQLDatabaseChain\n",
"\n",
"llm = OpenAI(temperature=0, verbose=True)\n",
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there?\n",
"SQLQuery:\u001b[32;1m\u001b[1;3mSELECT COUNT(*) FROM \"Employee\";\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3mThere are 8 employees.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'There are 8 employees.'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_chain.run(\"How many employees are there?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, we get the same result as the previous case.\n",
"\n",
"Here, the chain **also handles the query execution** and provides a final answer based on the user question and the query result.\n",
"\n",
"**Be careful** while using this approach as it is susceptible to `SQL Injection`:\n",
"\n",
"* The chain is executing queries that are created by an LLM, and weren't validated\n",
"* e.g. records may be created, modified or deleted unintentionally_\n",
"\n",
"This is why we see the `SQLDatabaseChain` is inside `langchain_experimental`."
]
},
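{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to limit the damage an unvalidated query can do is to run it on a connection that refuses writes. The sketch below is an illustration only: it uses the standard-library `sqlite3` module rather than LangChain, relies on SQLite's `PRAGMA query_only`, and the helper name `run_read_only` and the in-memory demo table are assumptions made for this example:\n",
"\n",
"```python\n",
"import sqlite3\n",
"\n",
"\n",
"def run_read_only(conn, query):\n",
"    # Ask SQLite to reject any statement that would change data.\n",
"    conn.execute('PRAGMA query_only = ON')\n",
"    return conn.execute(query).fetchall()\n",
"\n",
"\n",
"# Demo on an in-memory database standing in for Chinook.db.\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.execute('CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)')\n",
"conn.executemany('INSERT INTO Employee (LastName) VALUES (?)', [('Adams',), ('Edwards',)])\n",
"conn.commit()\n",
"\n",
"# Reading still works.\n",
"print(run_read_only(conn, 'SELECT COUNT(*) FROM Employee'))\n",
"\n",
"# A destructive statement is refused by SQLite instead of being executed.\n",
"try:\n",
"    run_read_only(conn, 'DELETE FROM Employee')\n",
"    write_blocked = False\n",
"except sqlite3.Error:\n",
"    write_blocked = True\n",
"print(write_blocked)\n",
"```\n",
"\n",
"This does not replace database-level permissions; for other dialects, a database user with read-only privileges plays a similar role."
]
},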
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Go deeper\n",
"\n",
"**Looking under the hood**\n",
"\n",
"We can use the [LangSmith trace](https://smith.langchain.com/public/7f202a0c-1e35-42d6-a84a-6c2a58f697ef/r) to see what is happening under the hood:\n",
"\n",
"* As discussed above, first we create the query:\n",
"\n",
"```\n",
"text: ' SELECT COUNT(*) FROM \"Employee\";'\n",
"```\n",
"\n",
"* Then, it executes the query and passes the results to an LLM for synthesis.\n",
"\n",
"![sql_usecase.png](/img/sqldbchain_trace.png)\n",
"\n",
"**Improvements**\n",
"\n",
"The performance of the `SQLDatabaseChain` can be enhanced in several ways:\n",
"\n",
"- [Adding sample rows](#adding-sample-rows)\n",
"- [Specifying custom table information](/docs/integrations/tools/sqlite#custom-table-info)\n",
"- [Using Query Checker](/docs/integrations/tools/sqlite#use-query-checker) self-correct invalid SQL using parameter `use_query_checker=True`\n",
"- [Customizing the LLM Prompt](/docs/integrations/tools/sqlite#customize-prompt) include specific instructions or relevant information, using parameter `prompt=CUSTOM_PROMPT`\n",
"- [Get intermediate steps](/docs/integrations/tools/sqlite#return-intermediate-steps) access the SQL statement as well as the final result using parameter `return_intermediate_steps=True`\n",
"- [Limit the number of rows](/docs/integrations/tools/sqlite#choosing-how-to-limit-the-number-of-rows-returned) a query will return using parameter `top_k=5`\n",
"\n",
"You might find [SQLDatabaseSequentialChain](/docs/integrations/tools/sqlite#sqldatabasesequentialchain)\n",
"useful for cases in which the number of tables in the database is large.\n",
"\n",
"This `Sequential Chain` handles the process of:\n",
"\n",
"1. Determining which tables to use based on the user question\n",
"2. Calling the normal SQL database chain using only relevant tables\n",
"\n",
"**Adding Sample Rows**\n",
"\n",
"Providing sample data can help the LLM construct correct queries when the data format is not obvious. \n",
"\n",
"For example, we can tell LLM that artists are saved with their full names by providing two rows from the Track table.\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\n",
" \"sqlite:///Chinook.db\",\n",
" include_tables=['Track'], # we include only one table to save tokens in the prompt :)\n",
" sample_rows_in_table_info=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sample rows are added to the prompt after each corresponding table's column information.\n",
"\n",
"We can use `db.table_info` and check which sample rows are included:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"CREATE TABLE \"Track\" (\n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(200) NOT NULL, \n",
"\t\"AlbumId\" INTEGER, \n",
"\t\"MediaTypeId\" INTEGER NOT NULL, \n",
"\t\"GenreId\" INTEGER, \n",
"\t\"Composer\" NVARCHAR(220), \n",
"\t\"Milliseconds\" INTEGER NOT NULL, \n",
"\t\"Bytes\" INTEGER, \n",
"\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"TrackId\"), \n",
"\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \n",
"\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n",
"\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n",
")\n",
"\n",
"/*\n",
"2 rows from Track table:\n",
"TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n",
"1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n",
"2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n",
"*/\n"
]
}
],
"source": [
"print(db.table_info)"
]
},
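{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the structure of this output concrete, here is a rough sketch that assembles a similar description with the standard-library `sqlite3` module. It is not LangChain's implementation; the helper name `describe_table` and the small in-memory table are assumptions made for illustration:\n",
"\n",
"```python\n",
"import sqlite3\n",
"\n",
"\n",
"def describe_table(conn, table, sample_rows=2):\n",
"    # Fetch the CREATE TABLE statement that SQLite stored for the table.\n",
"    ddl = conn.execute(\n",
"        'SELECT sql FROM sqlite_master WHERE type = ? AND name = ?',\n",
"        ('table', table),\n",
"    ).fetchone()[0]\n",
"    # Fetch a few sample rows; the table name is treated as trusted input here.\n",
"    cursor = conn.execute('SELECT * FROM \"' + table + '\" LIMIT ?', (sample_rows,))\n",
"    header = '\\t'.join(column[0] for column in cursor.description)\n",
"    rows = ['\\t'.join(str(value) for value in row) for row in cursor.fetchall()]\n",
"    # Append the rows as an SQL comment below the table definition.\n",
"    comment = ['/*', str(len(rows)) + ' rows from ' + table + ' table:', header, *rows, '*/']\n",
"    return ddl + '\\n\\n' + '\\n'.join(comment)\n",
"\n",
"\n",
"# Small stand-in for the Track table; the third row is made up.\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.execute('CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT, Composer TEXT)')\n",
"conn.executemany(\n",
"    'INSERT INTO Track (Name, Composer) VALUES (?, ?)',\n",
"    [\n",
"        ('For Those About To Rock (We Salute You)', 'Angus Young, Malcolm Young, Brian Johnson'),\n",
"        ('Balls to the Wall', None),\n",
"        ('Demo Track', 'Demo Composer'),\n",
"    ],\n",
")\n",
"print(describe_table(conn, 'Track'))\n",
"```\n",
"\n",
"Only the first `sample_rows` rows are included, which keeps the prompt short while still showing the LLM what the stored values look like."
]
},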
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Case 3: SQL agents\n",
"\n",
"LangChain has an SQL Agent which provides a more flexible way of interacting with SQL Databases than the `SQLDatabaseChain`.\n",
"\n",
"The main advantages of using the SQL Agent are:\n",
"\n",
"- It can answer questions based on the databases' schema as well as on the databases' content (like describing a specific table)\n",
"- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly\n",
"\n",
"To initialize the agent, we use `create_sql_agent` function. \n",
"\n",
"This agent contains the `SQLDatabaseToolkit` which contains tools to: \n",
"\n",
"* Create and execute queries\n",
"* Check query syntax\n",
"* Retrieve table descriptions\n",
"* ... and more"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import create_sql_agent\n",
"from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
"# from langchain.agents import AgentExecutor\n",
"from langchain.agents.agent_types import AgentType\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"llm = OpenAI(temperature=0, verbose=True)\n",
"\n",
"agent_executor = create_sql_agent(\n",
" llm=OpenAI(temperature=0),\n",
" toolkit=SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=0)),\n",
" verbose=True,\n",
" agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Agent task example #1 - Running queries\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction: sql_db_list_tables\n",
"Action Input: \u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should query the schema of the Invoice and Customer tables.\n",
"Action: sql_db_schema\n",
"Action Input: Invoice, Customer\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"Customer\" (\n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
"\t\"Company\" NVARCHAR(80), \n",
"\t\"Address\" NVARCHAR(70), \n",
"\t\"City\" NVARCHAR(40), \n",
"\t\"State\" NVARCHAR(40), \n",
"\t\"Country\" NVARCHAR(40), \n",
"\t\"PostalCode\" NVARCHAR(10), \n",
"\t\"Phone\" NVARCHAR(24), \n",
"\t\"Fax\" NVARCHAR(24), \n",
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
"\t\"SupportRepId\" INTEGER, \n",
"\tPRIMARY KEY (\"CustomerId\"), \n",
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Customer table:\n",
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
"2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
"3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Invoice\" (\n",
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
"\t\"BillingAddress\" NVARCHAR(70), \n",
"\t\"BillingCity\" NVARCHAR(40), \n",
"\t\"BillingState\" NVARCHAR(40), \n",
"\t\"BillingCountry\" NVARCHAR(40), \n",
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"InvoiceId\"), \n",
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Invoice table:\n",
"InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
"1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
"2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
"3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
"*/\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I should query the total sales per country.\n",
"Action: sql_db_query\n",
"Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The country with the highest total sales is the USA, with a total of $523.06.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The country with the highest total sales is the USA, with a total of $523.06.'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\n",
" \"List the total sales per country. Which country's customers spent the most?\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the [LangSmith trace](https://smith.langchain.com/public/a86dbe17-5782-4020-bce6-2de85343168a/r), we can see:\n",
"\n",
"* The agent is using a ReAct style prompt\n",
"* First, it will look at the tables: `Action: sql_db_list_tables` using tool `sql_db_list_tables`\n",
"* Given the tables as an observation, it `thinks` and then determinates the next `action`:\n",
"\n",
"```\n",
"Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
"Thought: I should query the schema of the Invoice and Customer tables.\n",
"Action: sql_db_schema\n",
"Action Input: Invoice, Customer\n",
"```\n",
"\n",
"* It then formulates the query using the schema from tool `sql_db_schema`\n",
"\n",
"```\n",
"Thought: I should query the total sales per country.\n",
"Action: sql_db_query\n",
"Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10\n",
"```\n",
"\n",
"* It finally executes the generated query using tool `sql_db_query`\n",
"\n",
"![sql_usecase.png](/img/SQLDatabaseToolkit.png)"
]
},
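{
"cell_type": "markdown",
"metadata": {},
"source": [
"The error recovery listed among the agent's advantages can be pictured as a retry loop: run the generated query and, if the database rejects it, pass the error message back to the query generator. The sketch below is not how LangChain implements the agent; `fake_llm` is a hypothetical stand-in for a model call, `answer_with_retries` is an illustrative name, and only the standard-library `sqlite3` module is used:\n",
"\n",
"```python\n",
"import sqlite3\n",
"\n",
"\n",
"def fake_llm(question, error=None):\n",
"    # Hypothetical stand-in for an LLM: the first attempt misspells the table.\n",
"    if error is None:\n",
"        return 'SELECT COUNT(*) FROM Employe'\n",
"    # Once it has seen an error message, it returns a corrected query.\n",
"    return 'SELECT COUNT(*) FROM Employee'\n",
"\n",
"\n",
"def answer_with_retries(conn, question, max_attempts=3):\n",
"    error = None\n",
"    for attempt in range(1, max_attempts + 1):\n",
"        query = fake_llm(question, error)\n",
"        try:\n",
"            return conn.execute(query).fetchall(), attempt\n",
"        except sqlite3.Error as exc:\n",
"            # Keep the error text so the next attempt can react to it.\n",
"            error = str(exc)\n",
"    raise RuntimeError('no valid query after ' + str(max_attempts) + ' attempts: ' + error)\n",
"\n",
"\n",
"# In-memory stand-in for the Employee table with eight rows.\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.execute('CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY)')\n",
"conn.executemany('INSERT INTO Employee (EmployeeId) VALUES (?)', [(i,) for i in range(1, 9)])\n",
"\n",
"result, attempts = answer_with_retries(conn, 'How many employees are there?')\n",
"print(result, attempts)\n",
"```\n",
"\n",
"Capping the number of attempts keeps a persistently failing query from looping forever."
]
},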
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Agent task example #2 - Describing a Table"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mAction: sql_db_list_tables\n",
"Action Input: \u001b[0m\n",
"Observation: \u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m The PlaylistTrack table is the most relevant to the question.\n",
"Action: sql_db_schema\n",
"Action Input: PlaylistTrack\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from PlaylistTrack table:\n",
"PlaylistId\tTrackId\n",
"1\t3402\n",
"1\t3389\n",
"1\t3390\n",
"*/\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to the Playlist table.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to the Playlist table.'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"Describe the playlisttrack table\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Go deeper\n",
"\n",
"To learn more about the SQL Agent and how it works we refer to the [SQL Agent Toolkit](/docs/integrations/toolkits/sql_database) documentation.\n",
"\n",
"You can also check Agents for other document types:\n",
"- [Pandas Agent](/docs/integrations/toolkits/pandas.html)\n",
"- [CSV Agent](/docs/integrations/toolkits/csv.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Elastic Search\n",
"\n",
"Going beyond the above use-case, there are integrations with other databases.\n",
"\n",
"For example, we can interact with Elasticsearch analytics database. \n",
"\n",
"This chain builds search queries via the Elasticsearch DSL API (filters and aggregations).\n",
"\n",
"The Elasticsearch client must have permissions for index listing, mapping description and search queries.\n",
"\n",
"See [here](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html) for instructions on how to run Elasticsearch locally.\n",
"\n",
"Make sure to install the Elasticsearch Python client before:\n",
"\n",
"```sh\n",
"pip install elasticsearch\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from elasticsearch import Elasticsearch\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chains.elasticsearch_database import ElasticsearchDatabaseChain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize Elasticsearch python client.\n",
"# See https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#elasticsearch.Elasticsearch\n",
"ELASTIC_SEARCH_SERVER = \"https://elastic:pass@localhost:9200\"\n",
"db = Elasticsearch(ELASTIC_SEARCH_SERVER)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Uncomment the next cell to initially populate your db."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# customers = [\n",
"# {\"firstname\": \"Jennifer\", \"lastname\": \"Walters\"},\n",
"# {\"firstname\": \"Monica\",\"lastname\":\"Rambeau\"},\n",
"# {\"firstname\": \"Carol\",\"lastname\":\"Danvers\"},\n",
"# {\"firstname\": \"Wanda\",\"lastname\":\"Maximoff\"},\n",
"# {\"firstname\": \"Jennifer\",\"lastname\":\"Takeda\"},\n",
"# ]\n",
"# for i, customer in enumerate(customers):\n",
"# db.create(index=\"customers\", document=customer, id=i)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(model_name=\"gpt-4\", temperature=0)\n",
"chain = ElasticsearchDatabaseChain.from_llm(llm=llm, database=db, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = \"What are the first names of all the customers?\"\n",
"chain.run(question)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can customize the prompt."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.elasticsearch_database.prompts import DEFAULT_DSL_TEMPLATE\n",
"from langchain.prompts.prompt import PromptTemplate\n",
"\n",
"PROMPT_TEMPLATE = \"\"\"Given an input question, create a syntactically correct Elasticsearch query to run. Unless the user specifies in their question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n",
"\n",
"Unless told to do not query for all the columns from a specific index, only ask for a the few relevant columns given the question.\n",
"\n",
"Pay attention to use only the column names that you can see in the mapping description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which index. Return the query as valid json.\n",
"\n",
"Use the following format:\n",
"\n",
"Question: Question here\n",
"ESQuery: Elasticsearch Query formatted as json\n",
"\"\"\"\n",
"\n",
"PROMPT = PromptTemplate.from_template(\n",
" PROMPT_TEMPLATE,\n",
")\n",
"chain = ElasticsearchDatabaseChain.from_llm(llm=llm, database=db, query_prompt=PROMPT)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}