diff --git a/cookbook/LLaMA2_sql_chat.ipynb b/cookbook/LLaMA2_sql_chat.ipynb new file mode 100644 index 0000000000..a75f498684 --- /dev/null +++ b/cookbook/LLaMA2_sql_chat.ipynb @@ -0,0 +1,363 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fc935871-7640-41c6-b798-58514d860fe0", + "metadata": {}, + "source": [ + "## LLaMA2 chat with SQL\n", + " \n", + "This Cookbook shows how to use LangChain's `SQLDatabaseChain` with LLaMA2 to chat about structured data stored in a SQL DB. \n", + "\n", + "* As the 2023-24 NBA season is around the corner, we use the NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players. \n", + "\n", + "* Because the SQLDatabaseChain API implementation is still in the langchain_experimental package, you'll see more issues that come with using the cutting edge experimental features, and how we succeed resolving some of the issues but fail on some others." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81adcf8b-395a-4f02-8749-ac976942b446", + "metadata": {}, + "outputs": [], + "source": [ + "! pip install langchain replicate langchain_experimental" + ] + }, + { + "cell_type": "markdown", + "id": "8e13ed66-300b-4a23-b8ac-44df68ee4733", + "metadata": {}, + "source": [ + "## LLM\n", + "\n", + "Use Replicate API for llama-2-13b-chat." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "416ecce7-8aec-4145-b3f1-587a9b8a4fe9", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Init param `input` is deprecated, please use `model_kwargs` instead.\n" + ] + } + ], + "source": [ + "from getpass import getpass\n", + "from langchain.llms import Replicate\n", + "# REPLICATE_API_TOKEN = getpass()\n", + "# os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN\n", + "\n", + "# Replicate API\n", + "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n", + "\n", + "# Set the system_prompt so that LLaMA will generate only the SQL statement, instead of being wordy and adding something like\n", + "# \"Sure! Here's the SQL query for the given input question: \" before the SQL query; otherwise custom parsing will be needed.\n", + "llm = Replicate(\n", + " model=llama2_13b_chat,\n", + " input={\"temperature\": 0.01, \n", + " \"max_length\": 500, \n", + " \"top_p\": 1}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "80222165-f353-4e35-a123-5f70fd70c6c8", + "metadata": {}, + "source": [ + "## DB\n", + "\n", + "Connect to a SQL DB, which in this case is in this same directory." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "025bdd82-3bb1-4948-bc7c-c3ccd94fd05c", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.utilities import SQLDatabase\n", + "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info= 0)\n", + "\n", + "def get_schema(_):\n", + " return db.get_table_info()\n", + "\n", + "def run_query(query):\n", + " return db.run(query)" + ] + }, + { + "cell_type": "markdown", + "id": "654b3577-baa2-4e12-a393-f40e5db49ac7", + "metadata": {}, + "source": [ + "## Query a SQL DB \n", + "\n", + "Follow the workflow [here](https://python.langchain.com/docs/expression_language/cookbook/sql_db)." + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "id": "5a4933ea-d9c0-4b0a-8177-ba4490c6532b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" SELECT * FROM nba_roster WHERE NAME = 'Klay Thompson';\"" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Prompt\n", + "from langchain.prompts import ChatPromptTemplate\n", + "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n", + "{schema}\n", + "\n", + "Question: {question}\n", + "SQL Query:\"\"\"\n", + "prompt = ChatPromptTemplate.from_messages([\n", + " (\"system\", \"Given an input question, convert it to a SQL query. No pre-amble.\"),\n", + " (\"human\", template)\n", + "])\n", + "\n", + "# Chain to query\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.schema.output_parser import StrOutputParser\n", + "from langchain.schema.runnable import RunnablePassthrough\n", + "\n", + "sql_response = (\n", + " RunnablePassthrough.assign(schema=get_schema)\n", + " | prompt\n", + " | llm.bind(stop=[\"\\nSQLResult:\"])\n", + " | StrOutputParser()\n", + " )\n", + "\n", + "sql_response.invoke({\"question\": \"What team is Klay Thompson on?\"})" + ] + }, + { + "cell_type": "markdown", + "id": "a0e9e2c8-9b88-4853-ac86-001bc6cc6695", + "metadata": {}, + "source": [ + "The [LangSmith trace](https://smith.langchain.com/public/afa56a06-b4e2-469a-a60f-c1746e75e42b/r) gives us visibility into the chain! " + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "2a2825e3-c1b6-4f7d-b9c9-d9835de323bb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" Sure! Here's the natural language response based on the given SQL query and response:\\n\\nThere are 30 unique teams in the NBA roster.\"" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Chain to answer\n", + "template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n", + "{schema}\n", + "\n", + "Question: {question}\n", + "SQL Query: {query}\n", + "SQL Response: {response}\"\"\"\n", + "prompt_response = ChatPromptTemplate.from_messages([\n", + " (\"system\", \"Given an input question and SQL response, convert it to a natural langugae answer. No pre-amble.\"),\n", + " (\"human\", template)\n", + "])\n", + "\n", + "full_chain = (\n", + " RunnablePassthrough.assign(query=sql_response) \n", + " | RunnablePassthrough.assign(\n", + " schema=get_schema,\n", + " response=lambda x: db.run(x[\"query\"]),\n", + " )\n", + " | prompt_response \n", + " | llm\n", + ")\n", + "\n", + "full_chain.invoke({\"question\": \"How many unique teams are there?\"})" + ] + }, + { + "cell_type": "markdown", + "id": "ec17b3ee-6618-4681-b6df-089bbb5ffcd7", + "metadata": {}, + "source": [ + "Again, the [LangSmith trace](https://smith.langchain.com/public/10420721-746a-4806-8ecf-d6dc6399d739/r) gives us visibility into the chain! " + ] + }, + { + "cell_type": "markdown", + "id": "1e85381b-1edc-4bb3-a7bd-2ab23f81e54d", + "metadata": {}, + "source": [ + "## Chat with a SQL DB \n", + "\n", + "Add memory!" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "1985aa1c-eb8f-4fb1-a54f-c8aa10744687", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"SELECT Team \\nFROM nba_roster \\nWHERE NAME = 'Klay Thompson'\"" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Prompt\n", + "from langchain.prompts import ChatPromptTemplate\n", + "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n", + "{schema}\n", + "\n", + "Question: {question}\n", + "SQL Query:\"\"\"\n", + "prompt = ChatPromptTemplate.from_messages([\n", + " (\"system\", \"Given an input question, convert it to a SQL query. No pre-amble.\"),\n", + " MessagesPlaceholder(variable_name=\"history\"),\n", + " (\"human\", template)\n", + "])\n", + "\n", + "memory = ConversationBufferMemory(return_messages=True)\n", + "\n", + "# Chain to query with memory \n", + "from langchain.memory import ConversationBufferMemory\n", + "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n", + "from langchain.schema.runnable import RunnableLambda, GetLocalVar, PutLocalVar\n", + "\n", + "sql_chain = (\n", + " RunnablePassthrough.assign(\n", + " schema=get_schema,\n", + " history=RunnableLambda(lambda x: memory.load_memory_variables(x)[\"history\"])\n", + " )| prompt\n", + " | model.bind(stop=[\"\\nSQLResult:\"])\n", + " | StrOutputParser()\n", + ")\n", + "\n", + "def save(input_output):\n", + " output = {\"output\": input_output.pop(\"output\")}\n", + " memory.save_context(input_output, output)\n", + " return output['output']\n", + " \n", + "sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save\n", + "sql_response_memory.invoke({\"question\": \"What team is Klay Thompson on?\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "0b45818a-1498-441d-b82d-23c29428c2bb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"SELECT SALARY \\nFROM nba_roster \\nWHERE NAME = 'Klay Thompson'\"" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sql_response_memory.invoke({\"question\": \"What is his salary?\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "800a7a3b-f411-478b-af51-2310cd6e0425", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" Sure thing! Here's the natural language response based on the given SQL query and response:\\n\\nKlay Thompson plays for the Golden State Warriors.\"" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Chain to answer\n", + "template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n", + "{schema}\n", + "\n", + "Question: {question}\n", + "SQL Query: {query}\n", + "SQL Response: {response}\"\"\"\n", + "prompt_response = ChatPromptTemplate.from_messages([\n", + " (\"system\", \"Given an input question and SQL response, convert it to a natural langugae answer. No pre-amble.\"),\n", + " (\"human\", template)\n", + "])\n", + "\n", + "full_chain = (\n", + " RunnablePassthrough.assign(query=sql_response_memory) \n", + " | RunnablePassthrough.assign(\n", + " schema=get_schema,\n", + " response=lambda x: db.run(x[\"query\"]),\n", + " )\n", + " | prompt_response \n", + " | llm\n", + ")\n", + "\n", + "full_chain.invoke({\"question\": \"What team is Klay Thompson on?\"})" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}