{
"cells": [
{
"cell_type": "raw",
"id": "b2788654-3f62-4e2a-ab00-471922cc54df",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 4\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "6751831d-9b08-434f-829b-d0052a3b119f",
"metadata": {},
"source": [
"# Large databases\n",
"\n",
"To write valid queries against a database, we need to give the model the table names, table schemas, and feature values to query over. When there are many tables, many columns, or high-cardinality columns, we cannot include everything about the database in every prompt. Instead, we must dynamically insert only the most relevant information into the prompt. Let's look at some techniques for doing this.\n",
"\n",
"\n",
"## Setup\n",
"\n",
"First, get required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9675e433-e608-469e-b04e-2847479a8310",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install --upgrade --quiet langchain langchain-community langchain-openai"
]
},
{
"cell_type": "markdown",
"id": "4f56ff5d-b2e4-49e3-a0b4-fb99466cfedc",
"metadata": {},
"source": [
"We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "06d8dd03-2d7b-4fef-b145-43c074eacb8b",
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
"\n",
"# Uncomment the below to use LangSmith. Not required.\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
]
},
{
"cell_type": "markdown",
"id": "590ee096-db88-42af-90d4-99b8149df753",
"metadata": {},
"source": [
"The example below uses a SQLite connection with the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
"Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven [SQLDatabase](https://api.python.langchain.com/en/latest/utilities/langchain_community.utilities.sql_database.SQLDatabase.html) class:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cebd3915-f58f-4e73-8459-265630ae8cd4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
]
},
{
"data": {
"text/plain": [
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"print(db.dialect)\n",
"print(db.get_usable_table_names())\n",
"db.run(\"SELECT * FROM Artist LIMIT 10;\")"
]
},
{
"cell_type": "markdown",
"id": "2e572e1f-99b5-46a2-9023-76d1e6256c0a",
"metadata": {},
"source": [
"## Many tables\n",
"\n",
"One of the main pieces of information we need to include in our prompt is the schemas of the relevant tables. When we have very many tables, we can't fit all of the schemas in a single prompt. What we can do in such cases is first extract the names of the tables related to the user input, and then include only their schemas.\n",
"\n",
"One easy and reliable way to do this is using OpenAI function-calling and Pydantic models. LangChain comes with a built-in [create_extraction_chain_pydantic](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_tools.extraction.create_extraction_chain_pydantic.html) chain that lets us do just this:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "d8236886-c54f-4bdb-ad74-2514888628fd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Table(name='Genre'), Table(name='Artist'), Table(name='Track')]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains.openai_tools import create_extraction_chain_pydantic\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0)\n",
"\n",
"\n",
"class Table(BaseModel):\n",
"    \"\"\"Table in SQL database.\"\"\"\n",
"\n",
"    name: str = Field(description=\"Name of table in SQL database.\")\n",
"\n",
"\n",
"table_names = \"\\n\".join(db.get_usable_table_names())\n",
"system = f\"\"\"Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \\\n",
"The tables are:\n",
"\n",
"{table_names}\n",
"\n",
"Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.\"\"\"\n",
"table_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)\n",
"table_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})"
]
},
{
"cell_type": "markdown",
"id": "1641dbba-d359-4cb2-ac52-82dfae99f392",
"metadata": {},
"source": [
"This works pretty well! However, as we'll see below, we actually need a few other tables as well (for example `Album`, which links `Track` to `Artist`). This would be difficult for the model to infer from the user question alone. In this case, we can simplify the model's job by grouping the tables together: we ask the model to choose between the categories \"Music\" and \"Business\", and then select all the relevant tables from there:"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "0ccb0bf5-c580-428f-9cde-a58772ae784e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Table(name='Music')]"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"system = \"\"\"Return the names of the SQL tables that are relevant to the user question. \\\n",
"The tables are:\n",
"\n",
"Music\n",
"Business\"\"\"\n",
"category_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)\n",
"category_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "ae4899fc-6f8a-4b10-983c-9e3fef4a7bb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import List\n",
"\n",
"\n",
"def get_tables(categories: List[Table]) -> List[str]:\n",
"    tables = []\n",
"    for category in categories:\n",
"        if category.name == \"Music\":\n",
"            tables.extend(\n",
"                [\n",
"                    \"Album\",\n",
"                    \"Artist\",\n",
"                    \"Genre\",\n",
"                    \"MediaType\",\n",
"                    \"Playlist\",\n",
"                    \"PlaylistTrack\",\n",
"                    \"Track\",\n",
"                ]\n",
"            )\n",
"        elif category.name == \"Business\":\n",
"            tables.extend([\"Customer\", \"Employee\", \"Invoice\", \"InvoiceLine\"])\n",
"    return tables\n",
"\n",
"\n",
"table_chain = category_chain | get_tables # noqa\n",
"table_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})"
]
},
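{
"cell_type": "markdown",
"id": "sketch-category-lookup-md",
"metadata": {},
"source": [
"The grouping step above can also be written as a plain dictionary lookup. The cell below is a minimal sketch rather than part of the chain: `CATEGORY_TABLES` and `tables_for_categories` are hypothetical names, the function takes category names as plain strings instead of `Table` objects, and it additionally skips unknown categories and removes duplicates, which `get_tables` does not do."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-category-lookup-code",
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, Iterable, List\n",
"\n",
"# Hypothetical mapping (an assumption for illustration): category -> tables.\n",
"CATEGORY_TABLES: Dict[str, List[str]] = {\n",
"    \"Music\": [\"Album\", \"Artist\", \"Genre\", \"MediaType\", \"Playlist\", \"PlaylistTrack\", \"Track\"],\n",
"    \"Business\": [\"Customer\", \"Employee\", \"Invoice\", \"InvoiceLine\"],\n",
"}\n",
"\n",
"\n",
"def tables_for_categories(category_names: Iterable[str]) -> List[str]:\n",
"    \"\"\"Expand category names into table names, skipping unknown categories.\"\"\"\n",
"    tables: List[str] = []\n",
"    for name in category_names:\n",
"        for table in CATEGORY_TABLES.get(name, []):\n",
"            # Keep first-seen order and avoid duplicates if a category repeats.\n",
"            if table not in tables:\n",
"                tables.append(table)\n",
"    return tables\n",
"\n",
"\n",
"tables_for_categories([\"Music\", \"Music\", \"Unknown\"])"
]
},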
{
"cell_type": "markdown",
"id": "04d52d01-1ccf-4753-b34a-0dcbc4921f78",
"metadata": {},
"source": [
"Now that we've got a chain that can output the relevant tables for any query, we can combine it with [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html), which accepts a list of `table_names_to_use` to determine which table schemas are included in the prompt:"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "79f2a5a2-eb99-47e3-9c2b-e5751a800174",
"metadata": {},
"outputs": [],
"source": [
"from operator import itemgetter\n",
"\n",
"from langchain.chains import create_sql_query_chain\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"query_chain = create_sql_query_chain(llm, db)\n",
"# Convert \"question\" key to the \"input\" key expected by current table_chain.\n",
"table_chain = {\"input\": itemgetter(\"question\")} | table_chain\n",
"# Set table_names_to_use using table_chain.\n",
"full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain"
]
},
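{
"cell_type": "markdown",
"id": "sketch-assign-dataflow-md",
"metadata": {},
"source": [
"To make the data flow of `full_chain` concrete, here is a plain-Python sketch of the same pattern: copy the input dict, add a `table_names_to_use` key computed from it, and hand the result to the next step. It mimics only the shape of the data, not how LangChain implements `RunnablePassthrough.assign`; `assign`, `pick_tables`, and `build_query` are hypothetical stand-ins."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-assign-dataflow-code",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Callable, Dict, List\n",
"\n",
"Step = Callable[[Dict[str, Any]], Any]\n",
"\n",
"\n",
"def assign(**steps: Step) -> Callable[[Dict[str, Any]], Dict[str, Any]]:\n",
"    \"\"\"Return a function that copies its input dict and adds one key per step.\"\"\"\n",
"\n",
"    def run(inputs: Dict[str, Any]) -> Dict[str, Any]:\n",
"        added = {key: step(inputs) for key, step in steps.items()}\n",
"        return {**inputs, **added}\n",
"\n",
"    return run\n",
"\n",
"\n",
"def pick_tables(inputs: Dict[str, Any]) -> List[str]:\n",
"    \"\"\"Hypothetical stand-in for table_chain: keyword match, no model call.\"\"\"\n",
"    return [\"Genre\", \"Track\"] if \"genre\" in inputs[\"question\"].lower() else [\"Invoice\"]\n",
"\n",
"\n",
"def build_query(inputs: Dict[str, Any]) -> str:\n",
"    \"\"\"Hypothetical stand-in for query_chain: reports what it was given.\"\"\"\n",
"    return f\"tables={inputs['table_names_to_use']} question={inputs['question']!r}\"\n",
"\n",
"\n",
"enriched = assign(table_names_to_use=pick_tables)({\"question\": \"What are all the genres?\"})\n",
"print(build_query(enriched))"
]
},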
{
"cell_type": "code",
"execution_count": 52,
"id": "424a7564-f63c-4584-b734-88021926486d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT \"Genre\".\"Name\"\n",
"FROM \"Genre\"\n",
"JOIN \"Track\" ON \"Genre\".\"GenreId\" = \"Track\".\"GenreId\"\n",
"JOIN \"Album\" ON \"Track\".\"AlbumId\" = \"Album\".\"AlbumId\"\n",
"JOIN \"Artist\" ON \"Album\".\"ArtistId\" = \"Artist\".\"ArtistId\"\n",
"WHERE \"Artist\".\"Name\" = 'Alanis Morissette'\n"
]
}
],
"source": [
"query = full_chain.invoke(\n",
"    {\"question\": \"What are all the genres of Alanis Morisette songs\"}\n",
")\n",
"print(query)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "3fb715cf-69d1-46a6-a1a7-9715ee550a0c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"[('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',)]\""
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(query)"
]
},
{
"cell_type": "markdown",
"id": "bb3d12b0-81a6-4250-8bc4-d58fe762c4cc",
"metadata": {},
"source": [
"We might rephrase our question slightly to remove redundancy in the answer:"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "010b5c3c-d55b-461a-8de5-8f1a8b2c56ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT DISTINCT g.Name\n",
"FROM Genre g\n",
"JOIN Track t ON g.GenreId = t.GenreId\n",
"JOIN Album a ON t.AlbumId = a.AlbumId\n",
"JOIN Artist ar ON a.ArtistId = ar.ArtistId\n",
"WHERE ar.Name = 'Alanis Morissette'\n"
]
}
],
"source": [
"query = full_chain.invoke(\n",
"    {\"question\": \"What is the set of all unique genres of Alanis Morisette songs\"}\n",
")\n",
"print(query)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "d21c0563-1f55-4577-8222-b0e9802f1c4b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"[('Rock',)]\""
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(query)"
]
},
{
"cell_type": "markdown",
"id": "7a717020-84c2-40f3-ba84-6624138d8e0c",
"metadata": {},
"source": [
"You can see the [LangSmith trace](https://smith.langchain.com/public/20b8ef90-1dac-4754-90f0-6bc11203c50a/r) for this run.\n",
"\n",
"We've seen how to dynamically include a subset of table schemas in a prompt within a chain. Another possible approach to this problem is to let an Agent decide for itself when to look up tables by giving it a Tool to do so. You can see an example of this in the [SQL: Agents](/docs/use_cases/sql/agents) guide."
]
},
{
"cell_type": "markdown",
"id": "cb9e54fd-64ca-4ed5-847c-afc635aae4f5",
"metadata": {},
"source": [
"## High-cardinality columns\n",
"\n",
"To filter on columns that contain proper nouns such as addresses, song names, or artists, we first need to double-check the spelling so that the filter matches the stored values.\n",
"\n",
"One naive strategy is to create a vector store with all the distinct proper nouns that exist in the database. We can then query that vector store for each user input and inject the most relevant proper nouns into the prompt.\n",
"\n",
"First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dee1b9e1-36b0-4cc1-ab78-7a872ad87e29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['AC/DC', 'Accept', 'Aerosmith', 'Alanis Morissette', 'Alice In Chains']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import ast\n",
"import re\n",
"\n",
"\n",
"def query_as_list(db, query):\n",
"    res = db.run(query)\n",
"    res = [el for sub in ast.literal_eval(res) for el in sub if el]\n",
"    res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n",
"    return res\n",
"\n",
"\n",
"proper_nouns = query_as_list(db, \"SELECT Name FROM Artist\")\n",
"proper_nouns += query_as_list(db, \"SELECT Title FROM Album\")\n",
"proper_nouns += query_as_list(db, \"SELECT Name FROM Genre\")\n",
"len(proper_nouns)\n",
"proper_nouns[:5]"
]
},
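{
"cell_type": "markdown",
"id": "sketch-parse-result-md",
"metadata": {},
"source": [
"The cleaning steps in `query_as_list` are easier to follow on a hand-written result string. This sketch assumes, as the earlier `db.run` output suggests, that results arrive as the string form of a list of row tuples; `parse_result_string` and the sample values are illustrative. Note that the regex removes only standalone numbers, so `U2` is kept while the trailing `1` in `Chronicle, Vol. 1` is dropped."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-parse-result-code",
"metadata": {},
"outputs": [],
"source": [
"import ast\n",
"import re\n",
"from typing import List\n",
"\n",
"\n",
"def parse_result_string(result: str) -> List[str]:\n",
"    \"\"\"Flatten the string form of a list of row tuples into cleaned values.\"\"\"\n",
"    rows = ast.literal_eval(result)  # safely parses the literal back into tuples\n",
"    values = [el for row in rows for el in row if el]  # drops None and empty strings\n",
"    # Same regex as query_as_list: strips standalone numbers, then whitespace.\n",
"    return [re.sub(r\"\\b\\d+\\b\", \"\", value).strip() for value in values]\n",
"\n",
"\n",
"# Hand-written sample input (an assumption for illustration, not a real query result).\n",
"raw = \"[('AC/DC',), (None,), ('U2',), ('Chronicle, Vol. 1',)]\"\n",
"parse_result_string(raw)"
]
},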
{
"cell_type": "markdown",
"id": "22efa968-1879-4d7a-858f-7899dfa57454",
"metadata": {},
"source": [
"Now we can embed and store all of our values in a vector database:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ea50abce-545a-4dc3-8795-8d364f7d142a",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores import FAISS\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())\n",
"retriever = vector_db.as_retriever(search_kwargs={\"k\": 15})"
]
},
{
"cell_type": "markdown",
"id": "a5d1d5c0-0928-40a4-b961-f1afe03cd5d3",
"metadata": {},
"source": [
"And put together a query construction chain that first retrieves values from the database and inserts them into the prompt:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "aea123ae-d809-44a0-be5d-d883c60d6a11",
"metadata": {},
"outputs": [],
"source": [
"from operator import itemgetter\n",
"\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"system = \"\"\"You are a SQLite expert. Given an input question, create a syntactically \\\n",
"correct SQLite query to run. Unless otherwise specified, do not return more than \\\n",
"{top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nHere is a non-exhaustive \\\n",
"list of possible feature values. If filtering on a feature value make sure to check its spelling \\\n",
"against this list first:\\n\\n{proper_nouns}\"\"\"\n",
"\n",
"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", \"{input}\")])\n",
"\n",
"query_chain = create_sql_query_chain(llm, db, prompt=prompt)\n",
"retriever_chain = (\n",
"    itemgetter(\"question\")\n",
"    | retriever\n",
"    | (lambda docs: \"\\n\".join(doc.page_content for doc in docs))\n",
")\n",
"chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain"
]
},
{
"cell_type": "markdown",
"id": "12b0ed60-2536-4f82-85df-e096a272072a",
"metadata": {},
"source": [
"To try out our chain, let's see what happens when we try filtering on \"elenis moriset\", a misspelling of Alanis Morissette, without and with retrieval:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "fcdd8432-07a4-4609-8214-b1591dd94950",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT DISTINCT Genre.Name\n",
"FROM Genre\n",
"JOIN Track ON Genre.GenreId = Track.GenreId\n",
"JOIN Album ON Track.AlbumId = Album.AlbumId\n",
"JOIN Artist ON Album.ArtistId = Artist.ArtistId\n",
"WHERE Artist.Name = 'Elenis Moriset'\n"
]
},
{
"data": {
"text/plain": [
"''"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Without retrieval\n",
"query = query_chain.invoke(\n",
"    {\"question\": \"What are all the genres of elenis moriset songs\", \"proper_nouns\": \"\"}\n",
")\n",
"print(query)\n",
"db.run(query)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e8a3231a-8590-46f5-a954-da06829ee6df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT DISTINCT Genre.Name\n",
"FROM Genre\n",
"JOIN Track ON Genre.GenreId = Track.GenreId\n",
"JOIN Album ON Track.AlbumId = Album.AlbumId\n",
"JOIN Artist ON Album.ArtistId = Artist.ArtistId\n",
"WHERE Artist.Name = 'Alanis Morissette'\n"
]
},
{
"data": {
"text/plain": [
"\"[('Rock',)]\""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# With retrieval\n",
"query = chain.invoke({\"question\": \"What are all the genres of elenis moriset songs\"})\n",
"print(query)\n",
"db.run(query)"
]
},
{
"cell_type": "markdown",
"id": "7f99181b-a75c-4ff3-b37b-33f99a506581",
"metadata": {},
"source": [
"We can see that with retrieval we're able to correct the spelling and get back a valid result.\n",
"\n",
"Another possible approach to this problem is to let an Agent decide for itself when to look up proper nouns. You can see an example of this in the [SQL: Agents](/docs/use_cases/sql/agents) guide."
]
}
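,
{
"cell_type": "markdown",
"id": "sketch-difflib-lookup-md",
"metadata": {},
"source": [
"As a cheaper stand-in for the vector store, plain string similarity can sometimes recover a misspelled proper noun. The sketch below swaps embeddings for `difflib.get_close_matches` from the Python standard library; it is a different technique from the retrieval chain above, `closest_known_values` is a hypothetical helper, and it assumes the candidate name has already been isolated from the question (the retriever above is queried with the whole question instead)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-difflib-lookup-code",
"metadata": {},
"outputs": [],
"source": [
"import difflib\n",
"from typing import List\n",
"\n",
"\n",
"def closest_known_values(candidate: str, known_values: List[str], n: int = 3) -> List[str]:\n",
"    \"\"\"Return up to n known values whose spelling is closest to candidate.\"\"\"\n",
"    # Compare case-insensitively, but return the spelling stored in the database.\n",
"    by_lower = {value.lower(): value for value in known_values}\n",
"    matches = difflib.get_close_matches(candidate.lower(), list(by_lower), n=n, cutoff=0.6)\n",
"    return [by_lower[match] for match in matches]\n",
"\n",
"\n",
"# The first five artist names shown earlier in this guide.\n",
"artists = [\"AC/DC\", \"Accept\", \"Aerosmith\", \"Alanis Morissette\", \"Alice In Chains\"]\n",
"closest_known_values(\"elenis moriset\", artists, n=1)"
]
}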
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv",
"language": "python",
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}