LLaMA2 SQL Chat cookbook (#11685)

pull/11832/head
Lance Martin 12 months ago committed by GitHub
parent a506302772
commit 8bf16d5275
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,363 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fc935871-7640-41c6-b798-58514d860fe0",
"metadata": {},
"source": [
"## LLaMA2 chat with SQL\n",
" \n",
"This Cookbook shows how to use LangChain's `SQLDatabaseChain` with LLaMA2 to chat about structured data stored in a SQL DB. \n",
"\n",
"* As the 2023-24 NBA season is around the corner, we use the NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players. \n",
"\n",
"* Because the SQLDatabaseChain API implementation is still in the langchain_experimental package, you'll see more issues that come with using the cutting edge experimental features, and how we succeed resolving some of the issues but fail on some others."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81adcf8b-395a-4f02-8749-ac976942b446",
"metadata": {},
"outputs": [],
"source": [
"! pip install langchain replicate langchain_experimental"
]
},
{
"cell_type": "markdown",
"id": "8e13ed66-300b-4a23-b8ac-44df68ee4733",
"metadata": {},
"source": [
"## LLM\n",
"\n",
"Use Replicate API for llama-2-13b-chat."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "416ecce7-8aec-4145-b3f1-587a9b8a4fe9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Init param `input` is deprecated, please use `model_kwargs` instead.\n"
]
}
],
"source": [
"from getpass import getpass\n",
"from langchain.llms import Replicate\n",
"# REPLICATE_API_TOKEN = getpass()\n",
"# os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN\n",
"\n",
"# Replicate API\n",
"llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
"\n",
"# Set the system_prompt so that LLaMA will generate only the SQL statement, instead of being wordy and adding something like\n",
"# \"Sure! Here's the SQL query for the given input question: \" before the SQL query; otherwise custom parsing will be needed.\n",
"llm = Replicate(\n",
" model=llama2_13b_chat,\n",
" input={\"temperature\": 0.01, \n",
" \"max_length\": 500, \n",
" \"top_p\": 1}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "80222165-f353-4e35-a123-5f70fd70c6c8",
"metadata": {},
"source": [
"## DB\n",
"\n",
"Connect to a SQL DB, which in this case is in this same directory."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "025bdd82-3bb1-4948-bc7c-c3ccd94fd05c",
"metadata": {},
"outputs": [],
"source": [
"from langchain.utilities import SQLDatabase\n",
"db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info= 0)\n",
"\n",
"def get_schema(_):\n",
" return db.get_table_info()\n",
"\n",
"def run_query(query):\n",
" return db.run(query)"
]
},
{
"cell_type": "markdown",
"id": "654b3577-baa2-4e12-a393-f40e5db49ac7",
"metadata": {},
"source": [
"## Query a SQL DB \n",
"\n",
"Follow the workflow [here](https://python.langchain.com/docs/expression_language/cookbook/sql_db)."
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "5a4933ea-d9c0-4b0a-8177-ba4490c6532b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" SELECT * FROM nba_roster WHERE NAME = 'Klay Thompson';\""
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Prompt\n",
"from langchain.prompts import ChatPromptTemplate\n",
"template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query:\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages([\n",
" (\"system\", \"Given an input question, convert it to a SQL query. No pre-amble.\"),\n",
" (\"human\", template)\n",
"])\n",
"\n",
"# Chain to query\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough\n",
"\n",
"sql_response = (\n",
" RunnablePassthrough.assign(schema=get_schema)\n",
" | prompt\n",
" | llm.bind(stop=[\"\\nSQLResult:\"])\n",
" | StrOutputParser()\n",
" )\n",
"\n",
"sql_response.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
},
{
"cell_type": "markdown",
"id": "a0e9e2c8-9b88-4853-ac86-001bc6cc6695",
"metadata": {},
"source": [
"The [LangSmith trace](https://smith.langchain.com/public/afa56a06-b4e2-469a-a60f-c1746e75e42b/r) gives us visibility into the chain! "
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "2a2825e3-c1b6-4f7d-b9c9-d9835de323bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" Sure! Here's the natural language response based on the given SQL query and response:\\n\\nThere are 30 unique teams in the NBA roster.\""
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chain to answer\n",
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_messages([\n",
" (\"system\", \"Given an input question and SQL response, convert it to a natural langugae answer. No pre-amble.\"),\n",
" (\"human\", template)\n",
"])\n",
"\n",
"full_chain = (\n",
" RunnablePassthrough.assign(query=sql_response) \n",
" | RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" response=lambda x: db.run(x[\"query\"]),\n",
" )\n",
" | prompt_response \n",
" | llm\n",
")\n",
"\n",
"full_chain.invoke({\"question\": \"How many unique teams are there?\"})"
]
},
{
"cell_type": "markdown",
"id": "ec17b3ee-6618-4681-b6df-089bbb5ffcd7",
"metadata": {},
"source": [
"Again, the [LangSmith trace](https://smith.langchain.com/public/10420721-746a-4806-8ecf-d6dc6399d739/r) gives us visibility into the chain! "
]
},
{
"cell_type": "markdown",
"id": "1e85381b-1edc-4bb3-a7bd-2ab23f81e54d",
"metadata": {},
"source": [
"## Chat with a SQL DB \n",
"\n",
"Add memory!"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "1985aa1c-eb8f-4fb1-a54f-c8aa10744687",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"SELECT Team \\nFROM nba_roster \\nWHERE NAME = 'Klay Thompson'\""
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Prompt\n",
"from langchain.prompts import ChatPromptTemplate\n",
"template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query:\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages([\n",
" (\"system\", \"Given an input question, convert it to a SQL query. No pre-amble.\"),\n",
" MessagesPlaceholder(variable_name=\"history\"),\n",
" (\"human\", template)\n",
"])\n",
"\n",
"memory = ConversationBufferMemory(return_messages=True)\n",
"\n",
"# Chain to query with memory \n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain.schema.runnable import RunnableLambda, GetLocalVar, PutLocalVar\n",
"\n",
"sql_chain = (\n",
" RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" history=RunnableLambda(lambda x: memory.load_memory_variables(x)[\"history\"])\n",
" )| prompt\n",
" | model.bind(stop=[\"\\nSQLResult:\"])\n",
" | StrOutputParser()\n",
")\n",
"\n",
"def save(input_output):\n",
" output = {\"output\": input_output.pop(\"output\")}\n",
" memory.save_context(input_output, output)\n",
" return output['output']\n",
" \n",
"sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save\n",
"sql_response_memory.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "0b45818a-1498-441d-b82d-23c29428c2bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"SELECT SALARY \\nFROM nba_roster \\nWHERE NAME = 'Klay Thompson'\""
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sql_response_memory.invoke({\"question\": \"What is his salary?\"})"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "800a7a3b-f411-478b-af51-2310cd6e0425",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" Sure thing! Here's the natural language response based on the given SQL query and response:\\n\\nKlay Thompson plays for the Golden State Warriors.\""
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chain to answer\n",
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_messages([\n",
" (\"system\", \"Given an input question and SQL response, convert it to a natural langugae answer. No pre-amble.\"),\n",
" (\"human\", template)\n",
"])\n",
"\n",
"full_chain = (\n",
" RunnablePassthrough.assign(query=sql_response_memory) \n",
" | RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" response=lambda x: db.run(x[\"query\"]),\n",
" )\n",
" | prompt_response \n",
" | llm\n",
")\n",
"\n",
"full_chain.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading…
Cancel
Save