{ "cells": [ { "cell_type": "raw", "id": "c14da114-1a4a-487d-9cff-e0e8c30ba366", "metadata": {}, "source": [ "---\n", "sidebar_position: 3\n", "title: Querying a SQL DB\n", "---" ] }, { "cell_type": "markdown", "id": "506e9636", "metadata": {}, "source": [ "We can replicate our SQLDatabaseChain with Runnables." ] }, { "cell_type": "code", "execution_count": 1, "id": "7a927516", "metadata": {}, "outputs": [], "source": [ "from langchain.prompts import ChatPromptTemplate\n", "\n", "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n", "{schema}\n", "\n", "Question: {question}\n", "SQL Query:\"\"\"\n", "prompt = ChatPromptTemplate.from_template(template)" ] }, { "cell_type": "code", "execution_count": 2, "id": "3f51f386", "metadata": {}, "outputs": [], "source": [ "from langchain.utilities import SQLDatabase" ] }, { "cell_type": "markdown", "id": "7c3449d6-684b-416e-ba16-90a035835a88", "metadata": {}, "source": [ "We'll need the Chinook sample DB for this example. There are many places to download it from, e.g. 
https://database.guide/2-sample-databases-sqlite/" ] }, { "cell_type": "code", "execution_count": 20, "id": "2ccca6fc", "metadata": {}, "outputs": [], "source": [ "db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")" ] }, { "cell_type": "code", "execution_count": 21, "id": "05ba88ee", "metadata": {}, "outputs": [], "source": [ "def get_schema(_):\n", " return db.get_table_info()" ] }, { "cell_type": "code", "execution_count": 22, "id": "a4eda902", "metadata": {}, "outputs": [], "source": [ "def run_query(query):\n", " return db.run(query)" ] }, { "cell_type": "code", "execution_count": 23, "id": "5046cb17", "metadata": {}, "outputs": [], "source": [ "from operator import itemgetter\n", "\n", "from langchain.chat_models import ChatOpenAI\n", "from langchain.schema.output_parser import StrOutputParser\n", "from langchain.schema.runnable import RunnableLambda, RunnableMap\n", "\n", "model = ChatOpenAI()\n", "\n", "inputs = {\n", " \"schema\": RunnableLambda(get_schema),\n", " \"question\": itemgetter(\"question\")\n", "}\n", "sql_response = (\n", " RunnableMap(inputs)\n", " | prompt\n", " | model.bind(stop=[\"\\nSQLResult:\"])\n", " | StrOutputParser()\n", " )" ] }, { "cell_type": "code", "execution_count": 24, "id": "a5552039", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'SELECT COUNT(*) FROM Employee'" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sql_response.invoke({\"question\": \"How many employees are there?\"})" ] }, { "cell_type": "code", "execution_count": 25, "id": "d6fee130", "metadata": {}, "outputs": [], "source": [ "template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n", "{schema}\n", "\n", "Question: {question}\n", "SQL Query: {query}\n", "SQL Response: {response}\"\"\"\n", "prompt_response = ChatPromptTemplate.from_template(template)" ] }, { "cell_type": "code", "execution_count": 26, "id": "923aa634", "metadata": {}, 
"outputs": [], "source": [ "full_chain = (\n", " RunnableMap({\n", " \"question\": itemgetter(\"question\"),\n", " \"query\": sql_response,\n", " }) \n", " | {\n", " \"schema\": RunnableLambda(get_schema),\n", " \"question\": itemgetter(\"question\"),\n", " \"query\": itemgetter(\"query\"),\n", " \"response\": lambda x: db.run(x[\"query\"]) \n", " } \n", " | prompt_response \n", " | model\n", ")" ] }, { "cell_type": "code", "execution_count": 27, "id": "e94963d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AIMessage(content='There are 8 employees.', additional_kwargs={}, example=False)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "full_chain.invoke({\"question\": \"How many employees are there?\"})" ] }, { "cell_type": "code", "execution_count": null, "id": "4f358d7b-a721-4db3-9f92-f06913428afc", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.1" } }, "nbformat": 4, "nbformat_minor": 5 }