mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
228 lines
5.1 KiB
Plaintext
228 lines
5.1 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "raw",
|
|
"id": "c14da114-1a4a-487d-9cff-e0e8c30ba366",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"sidebar_position: 3\n",
|
|
"title: Querying a SQL DB\n",
|
|
"---"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "506e9636",
|
|
"metadata": {},
|
|
"source": [
|
|
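"We can replicate the `SQLDatabaseChain` using Runnables: one chain turns a question into a SQL query, and a second chain runs that query and turns the result into a natural-language answer."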
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "7a927516",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.prompts import ChatPromptTemplate\n",
|
|
"\n",
|
|
"# Prompt that asks the model to translate a question into SQL, given the DB schema.\n",
|
|
"# `{schema}` and `{question}` are filled in by the chain that uses this prompt.\n",
|
|
"template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
|
|
"{schema}\n",
|
|
"\n",
|
|
"Question: {question}\n",
|
|
"SQL Query:\"\"\"\n",
|
|
"prompt = ChatPromptTemplate.from_template(template=template)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "3f51f386",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.utilities import SQLDatabase"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7c3449d6-684b-416e-ba16-90a035835a88",
|
|
"metadata": {},
|
|
"source": [
|
|
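"We'll need the Chinook sample database for this example. It can be downloaded from many places, e.g. [database.guide](https://database.guide/2-sample-databases-sqlite/); save it as `Chinook.db` in the current working directory, which is where the next cell looks for it."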
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "2ccca6fc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from pathlib import Path\n",
|
|
"\n",
|
|
"# The database file is resolved relative to the current working directory.\n",
|
|
"# Fail loudly if it is missing: SQLite would otherwise silently create an empty\n",
|
|
"# database file, and every generated query would run against a schema with no tables.\n",
|
|
"CHINOOK_DB_PATH = Path(\"Chinook.db\")\n",
|
|
"if not CHINOOK_DB_PATH.is_file():\n",
|
|
"    raise FileNotFoundError(\n",
|
|
"        f\"{CHINOOK_DB_PATH.resolve()} not found; download the Chinook sample DB first.\"\n",
|
|
"    )\n",
|
|
"db = SQLDatabase.from_uri(f\"sqlite:///{CHINOOK_DB_PATH}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "05ba88ee",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_schema(_):\n",
|
|
"    \"\"\"Return the schema description produced by `db.get_table_info()`.\n",
|
|
"\n",
|
|
"    The single positional argument is the chain input; it is ignored because\n",
|
|
"    the schema does not depend on the question.\n",
|
|
"    \"\"\"\n",
|
|
"    table_info = db.get_table_info()\n",
|
|
"    return table_info"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "a4eda902",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def run_query(query):\n",
|
|
"    \"\"\"Execute the SQL string `query` against `db` and return what `db.run` produces.\n",
|
|
"\n",
|
|
"    NOTE: `query` is executed as-is with no validation; if it comes from the model,\n",
|
|
"    point `db` at a sample / read-only database.\n",
|
|
"    \"\"\"\n",
|
|
"    result = db.run(query)\n",
|
|
"    return result"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "5046cb17",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from operator import itemgetter\n",
|
|
"\n",
|
|
"from langchain.chat_models import ChatOpenAI\n",
|
|
"from langchain.schema.output_parser import StrOutputParser\n",
|
|
"from langchain.schema.runnable import RunnableLambda, RunnableMap\n",
|
|
"\n",
|
|
"model = ChatOpenAI()\n",
|
|
"\n",
|
|
"# Prompt variables: the schema is looked up from the database on every call,\n",
|
|
"# and the question is taken from the input dict.\n",
|
|
"inputs = dict(\n",
|
|
"    schema=RunnableLambda(get_schema),\n",
|
|
"    question=itemgetter(\"question\"),\n",
|
|
")\n",
|
|
"\n",
|
|
"# Stop generation if the model starts an \"\\nSQLResult:\" section, then parse the\n",
|
|
"# chat message into a plain string containing the SQL.\n",
|
|
"sql_model = model.bind(stop=[\"\\nSQLResult:\"])\n",
|
|
"sql_response = RunnableMap(inputs) | prompt | sql_model | StrOutputParser()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "a5552039",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'SELECT COUNT(*) FROM Employee'"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Generate -- but do not execute -- the SQL for a sample question.\n",
|
|
"sql_response.invoke(\n",
|
|
"    {\"question\": \"How many employees are there?\"}\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "d6fee130",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Prompt for the answer step: the model sees the schema, the question, the SQL\n",
|
|
"# that was generated, and the raw result of running that SQL.\n",
|
|
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
|
|
"{schema}\n",
|
|
"\n",
|
|
"Question: {question}\n",
|
|
"SQL Query: {query}\n",
|
|
"SQL Response: {response}\"\"\"\n",
|
|
"prompt_response = ChatPromptTemplate.from_template(template=template)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "923aa634",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Stage 1: generate the SQL for the question, carrying the question along.\n",
|
|
"# Stage 2: run that SQL (as-is, unvalidated) and gather every variable\n",
|
|
"#          `prompt_response` needs.\n",
|
|
"# Stage 3: have the model phrase the result as an answer. The chain returns the\n",
|
|
"#          model's chat message, not a plain string.\n",
|
|
"full_chain = (\n",
|
|
"    RunnableMap({\n",
|
|
"        \"question\": itemgetter(\"question\"),\n",
|
|
"        \"query\": sql_response,\n",
|
|
"    })\n",
|
|
"    | {\n",
|
|
"        \"schema\": RunnableLambda(get_schema),\n",
|
|
"        \"question\": itemgetter(\"question\"),\n",
|
|
"        \"query\": itemgetter(\"query\"),\n",
|
|
"        # Reuse run_query rather than duplicating db.run inline.\n",
|
|
"        \"response\": lambda x: run_query(x[\"query\"]),\n",
|
|
"    }\n",
|
|
"    | prompt_response\n",
|
|
"    | model\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "e94963d8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"AIMessage(content='There are 8 employees.', additional_kwargs={}, example=False)"
|
|
]
|
|
},
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Full round trip: question -> SQL -> query result -> natural-language answer.\n",
|
|
"full_chain.invoke(\n",
|
|
"    {\"question\": \"How many employees are there?\"}\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4f358d7b-a721-4db3-9f92-f06913428afc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.1"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|