{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "fc935871-7640-41c6-b798-58514d860fe0",
"metadata": {},
"source": [
"## LLaMA2 chat with SQL\n",
"\n",
"Open-source LLMs that run locally are worth considering for any application that demands data privacy.\n",
"\n",
"Text-to-SQL is a good example, since both your table schema and your query results are passed to the model.\n",
"\n",
"This cookbook shows how to perform text-to-SQL with LLaMA2, run locally with Ollama or, for comparison, through the Replicate API.\n",
"\n",
"## Packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81adcf8b-395a-4f02-8749-ac976942b446",
"metadata": {},
"outputs": [],
"source": [
"! pip install langchain replicate"
]
},
{
"cell_type": "markdown",
"id": "8e13ed66-300b-4a23-b8ac-44df68ee4733",
"metadata": {},
"source": [
"## LLM\n",
"\n",
"There are a few ways to access LLaMA2.\n",
"\n",
"To run it locally, we use Ollama.ai. Pull the models used below first, e.g. `ollama pull llama2:13b-chat` and `ollama pull codellama:7b-instruct`.\n",
"\n",
"See [here](https://python.langchain.com/docs/integrations/chat/ollama) for details on installation and setup.\n",
"\n",
"Also, see [here](https://python.langchain.com/docs/guides/local_llms) for our full guide on local LLMs.\n",
"\n",
"To use an external API instead, which is not private, we can use Replicate."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6a75a5c6-34ee-4ab9-a664-d9b432d812ee",
"metadata": {},
"outputs": [],
"source": [
"# Local\n",
"from langchain.chat_models import ChatOllama\n",
"\n",
"llama2_chat = ChatOllama(model=\"llama2:13b-chat\")\n",
"llama2_code = ChatOllama(model=\"codellama:7b-instruct\")\n",
"\n",
"# API\n",
"import os\n",
"from getpass import getpass\n",
"\n",
"from langchain.llms import Replicate\n",
"\n",
"# Uncomment to set your Replicate API token\n",
"# REPLICATE_API_TOKEN = getpass()\n",
"# os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN\n",
"replicate_id = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
"llama2_chat_replicate = Replicate(\n",
"    model=replicate_id,\n",
"    model_kwargs={\"temperature\": 0.01, \"max_length\": 500, \"top_p\": 1},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ce96f7ea-b3d5-44e1-9fa5-a79e04a9e1fb",
"metadata": {},
"outputs": [],
"source": [
"# Choose the LLM: llama2_chat or llama2_code (local), or llama2_chat_replicate (API)\n",
"llm = llama2_chat"
]
},
{
"cell_type": "markdown",
"id": "80222165-f353-4e35-a123-5f70fd70c6c8",
"metadata": {},
"source": [
"## DB\n",
"\n",
"Connect to a SQLite DB.\n",
"\n",
"To create this particular DB (`nba_roster.db`), follow the steps and code shown [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/StructuredLlama.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "025bdd82-3bb1-4948-bc7c-c3ccd94fd05c",
"metadata": {},
"outputs": [],
"source": [
"from langchain.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info=0)\n",
"\n",
"\n",
"def get_schema(_):\n",
" return db.get_table_info()\n",
"\n",
"\n",
"def run_query(query):\n",
" return db.run(query)"
]
},
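{
"cell_type": "markdown",
"id": "3f1c2a7e-5b8d-4e2a-9c61-0d7a4b2e8f10",
"metadata": {},
"source": [
"If you don't have `nba_roster.db` yet, the following sketch uses only Python's built-in `sqlite3` module to build a throwaway in-memory table and mimic the two helpers above. It is an illustration under stated assumptions, not the real dataset: the `NAME` and `Team` columns are inferred from the SQL generated later in this notebook, the rows are hypothetical placeholders, and `get_schema_sqlite` / `run_query_sqlite` are hypothetical names. Also note that `db.run` returns its result as a string, whereas `sqlite3` returns a list of tuples.\n",
"\n",
"```python\n",
"import sqlite3\n",
"\n",
"# Hypothetical stand-in for nba_roster.db: an in-memory table with placeholder rows.\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.execute('CREATE TABLE nba_roster (NAME TEXT, Team TEXT)')\n",
"conn.executemany(\n",
"    'INSERT INTO nba_roster VALUES (?, ?)',\n",
"    [('Player A', 'Team X'), ('Player B', 'Team Y'), ('Player C', 'Team X')],\n",
")\n",
"\n",
"\n",
"def get_schema_sqlite(conn):\n",
"    # Roughly what get_schema() returns with sample_rows_in_table_info=0:\n",
"    # the CREATE TABLE statement(s) stored in sqlite_master.\n",
"    rows = conn.execute('SELECT sql FROM sqlite_master WHERE type = ?', ('table',))\n",
"    return '\\n\\n'.join(row[0] for row in rows)\n",
"\n",
"\n",
"def run_query_sqlite(conn, query):\n",
"    # Execute the (model-generated) SQL and return the raw rows.\n",
"    return conn.execute(query).fetchall()\n",
"\n",
"\n",
"print(get_schema_sqlite(conn))\n",
"print(run_query_sqlite(conn, 'SELECT COUNT(DISTINCT Team) FROM nba_roster'))\n",
"```"
]
},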
{
"cell_type": "markdown",
"id": "654b3577-baa2-4e12-a393-f40e5db49ac7",
"metadata": {},
"source": [
"## Query a SQL DB \n",
"\n",
"We follow the runnables (LangChain Expression Language) SQL workflow shown [here](https://python.langchain.com/docs/expression_language/cookbook/sql_db)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5a4933ea-d9c0-4b0a-8177-ba4490c6532b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' SELECT \"Team\" FROM nba_roster WHERE \"NAME\" = \\'Klay Thompson\\';'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Prompt\n",
"from langchain.prompts import ChatPromptTemplate\n",
"\n",
"template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query:\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\"system\", \"Given an input question, convert it to a SQL query. No preamble.\"),\n",
"        (\"human\", template),\n",
"    ]\n",
")\n",
"\n",
"# Chain to query\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough\n",
"\n",
"sql_response = (\n",
2023-10-29 22:50:09 +00:00
" RunnablePassthrough.assign(schema=get_schema)\n",
" | prompt\n",
" | llm.bind(stop=[\"\\nSQLResult:\"])\n",
" | StrOutputParser()\n",
")\n",
"\n",
"sql_response.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
},
{
"cell_type": "markdown",
"id": "a0e9e2c8-9b88-4853-ac86-001bc6cc6695",
"metadata": {},
"source": [
"We can review the results:\n",
"\n",
"* [LangSmith trace](https://smith.langchain.com/public/afa56a06-b4e2-469a-a60f-c1746e75e42b/r) LLaMA2-13B, Replicate API\n",
"* [LangSmith trace](https://smith.langchain.com/public/2d4ecc72-6b8f-4523-8f0b-ea95c6b54a1d/r) LLaMA2-13B, local"
]
},
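{
"cell_type": "markdown",
"id": "7b2e9d41-0c3a-4f58-a1e6-5d8c2f9b3a07",
"metadata": {},
"source": [
"Before the generated SQL is passed to `db.run`, it can help to normalize it: the query output above starts with a space, and chat models may wrap SQL in a markdown code fence despite the system prompt. A minimal sketch, assuming those are the only clean-ups needed, is shown below; `clean_sql` is a hypothetical helper, not part of LangChain.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Three backticks, built programmatically so this snippet can sit inside a markdown fence.\n",
"FENCE = '`' * 3\n",
"# A fenced block, optionally tagged sql; group 1 captures the query inside it.\n",
"FENCED_SQL = re.compile(\n",
"    re.escape(FENCE) + r'(?:sql)?\\s*(.*?)\\s*' + re.escape(FENCE),\n",
"    re.DOTALL | re.IGNORECASE,\n",
")\n",
"\n",
"\n",
"def clean_sql(text):\n",
"    # Hypothetical helper: strip surrounding whitespace and, if present,\n",
"    # a markdown code fence around the generated SQL.\n",
"    text = text.strip()\n",
"    match = FENCED_SQL.search(text)\n",
"    if match:\n",
"        text = match.group(1)\n",
"    return text.strip()\n",
"\n",
"\n",
"print(clean_sql(' SELECT Team FROM nba_roster;'))\n",
"```\n",
"\n",
"Because LCEL wraps a plain function used with `|` in a `RunnableLambda`, a helper like this could be appended to the chain as `sql_response | clean_sql`."
]
},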
{
"cell_type": "code",
"execution_count": 15,
"id": "2a2825e3-c1b6-4f7d-b9c9-d9835de323bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' Based on the table schema and SQL query, there are 30 unique teams in the NBA.')"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chain to answer\n",
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\n",
"            \"system\",\n",
"            \"Given an input question and SQL response, convert it to a natural language answer. No preamble.\",\n",
"        ),\n",
"        (\"human\", template),\n",
"    ]\n",
")\n",
"\n",
"full_chain = (\n",
" RunnablePassthrough.assign(query=sql_response)\n",
" | RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" response=lambda x: db.run(x[\"query\"]),\n",
" )\n",
" | prompt_response\n",
" | llm\n",
")\n",
"\n",
"full_chain.invoke({\"question\": \"How many unique teams are there?\"})"
]
},
{
"cell_type": "markdown",
"id": "ec17b3ee-6618-4681-b6df-089bbb5ffcd7",
"metadata": {},
"source": [
"We can review the results:\n",
"\n",
"* [LangSmith trace](https://smith.langchain.com/public/10420721-746a-4806-8ecf-d6dc6399d739/r) LLaMA2-13B, Replicate API\n",
"* [LangSmith trace](https://smith.langchain.com/public/5265ebab-0a22-4f37-936b-3300f2dfa1c1/r) LLaMA2-13B, local"
]
},
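{
"cell_type": "markdown",
"id": "c4a8f2d6-91e3-4b7a-8d05-6e2f1a9c7b34",
"metadata": {},
"source": [
"`full_chain` threads a single dict through the pipeline: each `RunnablePassthrough.assign(...)` keeps the existing keys and adds new ones computed from the current dict, which is why `response` can read `x[\"query\"]` from the previous step. The plain-Python sketch below mimics that data flow without LangChain; `assign`, `pipe`, and the fake model and database lookups are hypothetical stand-ins with placeholder values, not LangChain APIs or real outputs.\n",
"\n",
"```python\n",
"def assign(**fns):\n",
"    # Like RunnablePassthrough.assign: keep every input key and add one key per\n",
"    # function, each computed from the same input dict.\n",
"    def step(inputs):\n",
"        return {**inputs, **{key: fn(inputs) for key, fn in fns.items()}}\n",
"\n",
"    return step\n",
"\n",
"\n",
"def pipe(*steps):\n",
"    # Like the | operator: feed each step's output into the next step.\n",
"    def run(value):\n",
"        for step in steps:\n",
"            value = step(value)\n",
"        return value\n",
"\n",
"    return run\n",
"\n",
"\n",
"# Placeholder model output and database result (not real outputs).\n",
"FAKE_SQL = 'SELECT COUNT(DISTINCT Team) FROM nba_roster;'\n",
"FAKE_RESULTS = {FAKE_SQL: '[(2,)]'}\n",
"\n",
"chain = pipe(\n",
"    assign(query=lambda x: FAKE_SQL),\n",
"    assign(\n",
"        schema=lambda x: 'CREATE TABLE nba_roster (NAME TEXT, Team TEXT)',\n",
"        response=lambda x: FAKE_RESULTS[x['query']],\n",
"    ),\n",
")\n",
"print(chain({'question': 'How many unique teams are there?'}))\n",
"```\n",
"\n",
"In the real chain, the resulting dict is then formatted by `prompt_response` and passed to `llm`."
]
},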
{
"cell_type": "markdown",
"id": "1e85381b-1edc-4bb3-a7bd-2ab23f81e54d",
"metadata": {},
"source": [
"## Chat with a SQL DB \n",
"\n",
"Next, we add memory so that follow-up questions (such as \"What is his salary?\") can refer back to earlier turns of the conversation."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "022868f2-128e-42f5-8d90-d3bb2f11d994",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' SELECT \"Team\" FROM nba_roster WHERE \"NAME\" = \\'Klay Thompson\\';'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Prompt\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"\n",
"template = \"\"\"Given an input question, convert it to a SQL query. No preamble. Based on the table schema below, write a SQL query that would answer the user's question:\n",
"{schema}\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\"system\", template),\n",
"        MessagesPlaceholder(variable_name=\"history\"),\n",
"        (\"human\", \"{question}\"),\n",
"    ]\n",
")\n",
"\n",
"memory = ConversationBufferMemory(return_messages=True)\n",
"\n",
"# Chain to query with memory\n",
"from langchain.schema.runnable import RunnableLambda\n",
"\n",
"sql_chain = (\n",
" RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" history=RunnableLambda(lambda x: memory.load_memory_variables(x)[\"history\"]),\n",
" )\n",
" | prompt\n",
" | llm.bind(stop=[\"\\nSQLResult:\"])\n",
" | StrOutputParser()\n",
")\n",
"\n",
"\n",
"def save(input_output):\n",
" output = {\"output\": input_output.pop(\"output\")}\n",
" memory.save_context(input_output, output)\n",
" return output[\"output\"]\n",
"\n",
"\n",
"sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save\n",
"sql_response_memory.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
},
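{
"cell_type": "markdown",
"id": "e9d3b6a1-2f74-4c8e-b5a0-8c1d7e4f2a96",
"metadata": {},
"source": [
"To see why a follow-up question can work, trace the bookkeeping: `RunnablePassthrough.assign(output=sql_chain)` produces a dict with `question` and `output`, `save` pops `output`, and `memory.save_context` records the question and the generated SQL as one human/AI exchange, which `load_memory_variables` then returns as `history` on the next call. The sketch below reproduces that flow in plain Python. `BufferMemory`, `demo_memory`, and `fake_sql_chain` are hypothetical stand-ins, and `save` / `sql_response_memory` here are simplified local copies; the real `ConversationBufferMemory` stores `HumanMessage`/`AIMessage` objects rather than tuples.\n",
"\n",
"```python\n",
"class BufferMemory:\n",
"    # Hypothetical minimal stand-in for ConversationBufferMemory(return_messages=True).\n",
"    def __init__(self):\n",
"        self.messages = []\n",
"\n",
"    def load_memory_variables(self, _inputs):\n",
"        return {'history': list(self.messages)}\n",
"\n",
"    def save_context(self, inputs, outputs):\n",
"        # One human turn (the question) and one AI turn (the generated SQL).\n",
"        (question,) = inputs.values()\n",
"        self.messages.append(('human', question))\n",
"        self.messages.append(('ai', outputs['output']))\n",
"\n",
"\n",
"demo_memory = BufferMemory()\n",
"\n",
"\n",
"def fake_sql_chain(inputs):\n",
"    # Placeholder for sql_chain; the real chain formats this history into the prompt.\n",
"    history = demo_memory.load_memory_variables(inputs)['history']\n",
"    return 'SQL for: ' + inputs['question'] + ' (history: ' + str(len(history)) + ')'\n",
"\n",
"\n",
"def save(input_output):\n",
"    # Same logic as save() above: pop the output and store the exchange.\n",
"    output = {'output': input_output.pop('output')}\n",
"    demo_memory.save_context(input_output, output)\n",
"    return output['output']\n",
"\n",
"\n",
"def sql_response_memory(inputs):\n",
"    # Like RunnablePassthrough.assign(output=sql_chain) | save\n",
"    return save({**inputs, 'output': fake_sql_chain(inputs)})\n",
"\n",
"\n",
"print(sql_response_memory({'question': 'What team is Klay Thompson on?'}))\n",
"print(demo_memory.load_memory_variables({})['history'])\n",
"```\n",
"\n",
"Note that only the question and the generated SQL are saved to memory, not the final natural-language answer."
]
},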
{
"cell_type": "code",
"execution_count": 21,
"id": "800a7a3b-f411-478b-af51-2310cd6e0425",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' Sure! Here\\'s the natural language response based on the given input:\\n\\n\"Klay Thompson\\'s salary is $43,219,440.\"')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chain to answer\n",
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\n",
"            \"system\",\n",
"            \"Given an input question and SQL response, convert it to a natural language answer. No preamble.\",\n",
"        ),\n",
"        (\"human\", template),\n",
"    ]\n",
")\n",
"\n",
"full_chain = (\n",
" RunnablePassthrough.assign(query=sql_response_memory)\n",
" | RunnablePassthrough.assign(\n",
" schema=get_schema,\n",
" response=lambda x: db.run(x[\"query\"]),\n",
" )\n",
" | prompt_response\n",
" | llm\n",
")\n",
"\n",
"full_chain.invoke({\"question\": \"What is his salary?\"})"
]
},
{
"cell_type": "markdown",
"id": "b77fee61-f4da-4bb1-8285-14101e505518",
"metadata": {},
"source": [
"Here is the [trace](https://smith.langchain.com/public/54794d18-2337-4ce2-8b9f-3d8a2df89e51/r)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}