langchain/cookbook/LLaMA2_sql_chat.ipynb
2023-10-18 15:01:37 -07:00

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "fc935871-7640-41c6-b798-58514d860fe0",
"metadata": {},
"source": [
"## LLaMA2 chat with SQL\n",
"\n",
"Open source, local LLMs are great to consider for any application that demands data privacy.\n",
"\n",
"SQL is one good example. \n",
"\n",
"This cookbook shows how to perform text-to-SQL using various local versions of LLaMA2 run locally.\n",
"\n",
"## Packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81adcf8b-395a-4f02-8749-ac976942b446",
"metadata": {},
"outputs": [],
"source": [
"! pip install langchain replicate"
]
},
{
"cell_type": "markdown",
"id": "8e13ed66-300b-4a23-b8ac-44df68ee4733",
"metadata": {},
"source": [
"## LLM\n",
"\n",
"There are a few ways to access LLaMA2.\n",
"\n",
"To run locally, we use Ollama.ai. \n",
"\n",
"See [here](https://python.langchain.com/docs/integrations/chat/ollama) for details on installation and setup.\n",
"\n",
"Also, see [here](https://python.langchain.com/docs/guides/local_llms) for our full guide on local LLMs.\n",
" \n",
"To use an external API, which is not private, we can use Replicate."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6a75a5c6-34ee-4ab9-a664-d9b432d812ee",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Init param `input` is deprecated, please use `model_kwargs` instead.\n"
]
}
],
"source": [
"# Local \n",
"from langchain.chat_models import ChatOllama\n",
"llama2_chat = ChatOllama(model=\"llama2:13b-chat\")\n",
"llama2_code = ChatOllama(model=\"codellama:7b-instruct\")\n",
"\n",
"# API\n",
"from getpass import getpass\n",
"from langchain.llms import Replicate\n",
"# REPLICATE_API_TOKEN = getpass()\n",
"# os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN\n",
"replicate_id = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
"llama2_chat_replicate = Replicate(\n",
" model=replicate_id,\n",
" input={\"temperature\": 0.01, \n",
" \"max_length\": 500, \n",
" \"top_p\": 1}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ce96f7ea-b3d5-44e1-9fa5-a79e04a9e1fb",
"metadata": {},
"outputs": [],
"source": [
"# Simply set the LLM we want to use\n",
"llm = llama2_chat"
]
},
{
"cell_type": "markdown",
"id": "80222165-f353-4e35-a123-5f70fd70c6c8",
"metadata": {},
"source": [
"## DB\n",
"\n",
"Connect to a SQLite DB.\n",
"\n",
"To create this particular DB, you can use the code and follow the steps shown [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/StructuredLlama.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "025bdd82-3bb1-4948-bc7c-c3ccd94fd05c",
"metadata": {},
"outputs": [],
"source": [
"from langchain.utilities import SQLDatabase\n",
"db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info= 0)\n",
"\n",
"def get_schema(_):\n",
" return db.get_table_info()\n",
"\n",
"def run_query(query):\n",
" return db.run(query)"
]
},
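{
"cell_type": "markdown",
"id": "sketch-toy-sqlite-table",
"metadata": {},
"source": [
"If `nba_roster.db` is not built yet, a throwaway in-memory table can be enough to sanity-check the kind of SQL the model generates. The sketch below uses only the standard-library `sqlite3` module. It assumes three text columns (`NAME`, `Team`, `SALARY`) taken from the generated queries shown in this notebook; the single row is a placeholder, and the real table may have more columns.\n",
"\n",
"```python\n",
"import sqlite3\n",
"\n",
"# In-memory database: nothing is written to disk.\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.execute('CREATE TABLE nba_roster (NAME TEXT, Team TEXT, SALARY TEXT)')\n",
"\n",
"# Placeholder row, not real roster data.\n",
"conn.execute(\n",
"    'INSERT INTO nba_roster VALUES (?, ?, ?)',\n",
"    ('Player A', 'Team X', 'n/a'),\n",
")\n",
"\n",
"# Parameterized query with the same shape as the generated SQL.\n",
"rows = conn.execute(\n",
"    'SELECT Team FROM nba_roster WHERE NAME = ?', ('Player A',)\n",
").fetchall()\n",
"print(rows)\n",
"conn.close()\n",
"```"
]
},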
{
"cell_type": "markdown",
"id": "654b3577-baa2-4e12-a393-f40e5db49ac7",
"metadata": {},
"source": [
"## Query a SQL DB \n",
"\n",
"Follow the runnables workflow [here](https://python.langchain.com/docs/expression_language/cookbook/sql_db)."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5a4933ea-d9c0-4b0a-8177-ba4490c6532b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' SELECT \"Team\" FROM nba_roster WHERE \"NAME\" = \\'Klay Thompson\\';'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Prompt\n",
"from langchain.prompts import ChatPromptTemplate\n",
"template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query:\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages([\n",
" (\"system\", \"Given an input question, convert it to a SQL query. No pre-amble.\"),\n",
" (\"human\", template)\n",
"])\n",
"\n",
"# Chain to query\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough\n",
"\n",
"sql_response = (\n",
" RunnablePassthrough.assign(schema=get_schema)\n",
" | prompt\n",
" | llm.bind(stop=[\"\\nSQLResult:\"])\n",
" | StrOutputParser()\n",
" )\n",
"\n",
"sql_response.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
},
{
"cell_type": "markdown",
"id": "a0e9e2c8-9b88-4853-ac86-001bc6cc6695",
"metadata": {},
"source": [
"We can review the results:\n",
"\n",
"* [LangSmith trace](https://smith.langchain.com/public/afa56a06-b4e2-469a-a60f-c1746e75e42b/r) LLaMA2-13 Replicate API\n",
"* [LangSmith trace](https://smith.langchain.com/public/2d4ecc72-6b8f-4523-8f0b-ea95c6b54a1d/r) LLaMA2-13 local \n"
]
},
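{
"cell_type": "markdown",
"id": "sketch-passthrough-assign",
"metadata": {},
"source": [
"The query chain above leans on `RunnablePassthrough.assign`, which passes the input dict through and adds keys computed from it. A minimal plain-Python sketch of that data flow follows; `assign_sketch` and `fake_schema` are hypothetical stand-ins for illustration, not LangChain APIs.\n",
"\n",
"```python\n",
"def assign_sketch(**funcs):\n",
"    # Build a step that copies the input dict and adds one key per function.\n",
"    def step(inputs):\n",
"        added = {key: func(inputs) for key, func in funcs.items()}\n",
"        return {**inputs, **added}\n",
"    return step\n",
"\n",
"def fake_schema(_):\n",
"    # Placeholder for get_schema; the real one reads the schema from the DB.\n",
"    return 'schema text goes here'\n",
"\n",
"step = assign_sketch(schema=fake_schema)\n",
"print(step({'question': 'What team is Klay Thompson on?'}))\n",
"```\n",
"\n",
"The prompt then receives both `question` and `schema`, which is why the template can reference each by name."
]
},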
{
"cell_type": "code",
"execution_count": 15,
"id": "2a2825e3-c1b6-4f7d-b9c9-d9835de323bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' Based on the table schema and SQL query, there are 30 unique teams in the NBA.')"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chain to answer\n",
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_messages([\n",
" (\"system\", \"Given an input question and SQL response, convert it to a natural langugae answer. No pre-amble.\"),\n",
" (\"human\", template)\n",
"])\n",
"\n",
"full_chain = (\n",
" RunnablePassthrough.assign(query=sql_response) \n",
" | RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" response=lambda x: db.run(x[\"query\"]),\n",
" )\n",
" | prompt_response \n",
" | llm\n",
")\n",
"\n",
"full_chain.invoke({\"question\": \"How many unique teams are there?\"})"
]
},
{
"cell_type": "markdown",
"id": "ec17b3ee-6618-4681-b6df-089bbb5ffcd7",
"metadata": {},
"source": [
"We can review the results:\n",
"\n",
"* [LangSmith trace](https://smith.langchain.com/public/10420721-746a-4806-8ecf-d6dc6399d739/r) LLaMA2-13 Replicate API\n",
"* [LangSmith trace](https://smith.langchain.com/public/5265ebab-0a22-4f37-936b-3300f2dfa1c1/r) LLaMA2-13 local "
]
},
{
"cell_type": "markdown",
"id": "1e85381b-1edc-4bb3-a7bd-2ab23f81e54d",
"metadata": {},
"source": [
"## Chat with a SQL DB \n",
"\n",
"Next, we can add memory."
]
},
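{
"cell_type": "markdown",
"id": "sketch-buffer-memory",
"metadata": {},
"source": [
"Conceptually, buffer memory keeps earlier turns and replays them ahead of each new question, which is what lets the model resolve `his` in a follow-up. The sketch below illustrates that idea in plain Python; `history`, `build_messages`, and `save_turn` are hypothetical names, and this is not how `ConversationBufferMemory` is actually implemented.\n",
"\n",
"```python\n",
"history = []  # (role, text) pairs, oldest first\n",
"\n",
"def build_messages(question):\n",
"    # Earlier turns come first, then the new question.\n",
"    return history + [('human', question)]\n",
"\n",
"def save_turn(question, answer):\n",
"    # Record one completed exchange for later prompts.\n",
"    history.append(('human', question))\n",
"    history.append(('ai', answer))\n",
"\n",
"save_turn('What team is Klay Thompson on?', '<generated SQL>')\n",
"print(build_messages('What is his salary?'))\n",
"```"
]
},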
{
"cell_type": "code",
"execution_count": 19,
"id": "1985aa1c-eb8f-4fb1-a54f-c8aa10744687",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' SELECT \"Team\" FROM nba_roster WHERE \"NAME\" = \\'Klay Thompson\\';'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Prompt\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query:\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages([\n",
" (\"system\", \"Given an input question, convert it to a SQL query. No pre-amble.\"),\n",
" MessagesPlaceholder(variable_name=\"history\"),\n",
" (\"human\", template)\n",
"])\n",
"\n",
"memory = ConversationBufferMemory(return_messages=True)\n",
"\n",
"# Chain to query with memory \n",
"from langchain.schema.runnable import RunnableLambda\n",
"\n",
"sql_chain = (\n",
" RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" history=RunnableLambda(lambda x: memory.load_memory_variables(x)[\"history\"])\n",
" )| prompt\n",
" | llm.bind(stop=[\"\\nSQLResult:\"])\n",
" | StrOutputParser()\n",
")\n",
"\n",
"def save(input_output):\n",
" output = {\"output\": input_output.pop(\"output\")}\n",
" memory.save_context(input_output, output)\n",
" return output['output']\n",
" \n",
"sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save\n",
"sql_response_memory.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "0b45818a-1498-441d-b82d-23c29428c2bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' SELECT \"SALARY\" FROM nba_roster WHERE \"NAME\" = \\'Klay Thompson\\';'"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sql_response_memory.invoke({\"question\": \"What is his salary?\"})"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "800a7a3b-f411-478b-af51-2310cd6e0425",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' Sure! Here\\'s the natural language response based on the given input:\\n\\n\"Klay Thompson\\'s salary is $43,219,440.\"')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chain to answer\n",
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_messages([\n",
" (\"system\", \"Given an input question and SQL response, convert it to a natural langugae answer. No pre-amble.\"),\n",
" (\"human\", template)\n",
"])\n",
"\n",
"full_chain = (\n",
" RunnablePassthrough.assign(query=sql_response_memory) \n",
" | RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" response=lambda x: db.run(x[\"query\"]),\n",
" )\n",
" | prompt_response \n",
" | llm\n",
")\n",
"\n",
"full_chain.invoke({\"question\": \"What is his salary?\"})"
]
},
{
"cell_type": "markdown",
"id": "b77fee61-f4da-4bb1-8285-14101e505518",
"metadata": {},
"source": [
"Here is the [trace](https://smith.langchain.com/public/54794d18-2337-4ce2-8b9f-3d8a2df89e51/r)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}