You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/cookbook/LLaMA2_sql_chat.ipynb

377 lines
252 KiB
Plaintext

{
"cells": [
{
"attachments": {
"02bd3e70-6db4-40e2-9f6b-c83127eb2339.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABngAAALTCAYAAAA1lQyVAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkJAQIICAlNCbIFIDSAmhBZBeBBshCRBKjIGgYkcXFVy7iIANXRVR7IDYETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMuTP33gGAfpwnkeSimgDkiQukcaGBzNEpqUzSU0AEdEAFVkCLx8+XsGNiIgEsA+3fy7vrAJG3VxzlWv/s/69FSyDM5wOAxECcLsjn50G8HwC8mi+RFgBAlPMWkwskcgwr0JHCACFeIMeZSlwtx+lKvFthkxDHgbgVADUqjyfNBEDjEuSZhfxMqKHRC7GzWCASA0BnQuyXlzdRAHEaxLbQRgKxXJ+V/oNO5t800wc1ebzMQayci6KoBYnyJbm8qf9nOv53ycuVDfiwhpWaJQ2Lk88Z5u1mzsQIOaZC3CNOj4qGWBviDyKBwh5ilJIlC0tU2qNG/HwOzBnQg9hZwAuKgNgI4hBxblSkik/PEIVwIYYrBJ0iKuAmQKwP8QJhfnC8ymaDdGKcyhfakCHlsFX8WZ5U4Vfu674sJ5Gt0n+dJeSq9DGNoqyEZIgpEFsWipKiINaA2Ck/Jz5CZTOyKIsTNWAjlcXJ47eEOE4oDg1U6mOFGdKQOJV9aV7+wHyxDVkibpQK7y3ISghT5gdr5fMU8cO5YJeEYnbigI4wf3TkwFwEwqBg5dyxZ0JxYrxK54OkIDBOORanSHJjVPa4uTA3VM6bQ+yWXxivGosnFcAFqdTHMyQFMQnKOPGibF54jDIefCmIBBwQBJhABms6mAiygai9p7EH3il7QgAPSEEmEAJHFTMwIlnRI4bXeFAE/oRICPIHxwUqeoWgEPJfB1nl1RFkKHoLFSNywBOI80AEyIX3MsUo8aC3JPAYMqJ/eOfByofx5sIq7//3/AD7nWFDJlLFyAY8MukDlsRgYhAxjBhCtMMNcT/cB4+E1wBYXXAW7jUwj+/2hCeEDsJDwjVCJ+HWBFGx9KcoR4FOqB+iykX6j7nAraGmOx6I+0J1qIzr4YbAEXeDfti4P/TsDlmOKm55Vpg/af9tBj88DZUd2ZmMkoeQA8i2P4/UsNdwH1SR5/rH/ChjTR/MN2ew52f/nB+yL4BtxM+W2AJsH3YGO4Gdww5jjYCJHcOasDbsiBwPrq7HitU14C1OEU8O1BH9w9/Ak5VnMt+5zrnb+Yuyr0A4Rf6OBpyJkqlSUWZWAZMNvwhCJlfMdxrGdHF2cQVA/n1Rvr7exCq+G4he23du7h8A+B7r7+8/9J0LPwbAHk+4/Q9+52xZ8NOhDsDZg3yZtFDJ4fILAb4l6HCnGQATYAFs4XxcgAfwAQEgGISDaJAAUsB4GH0WXOdSMBlMB3NACSgDS8EqUAnWg01gG9gJ9oJGcBicAKfBBXAJXAN34OrpAi9AL3gHPiMIQkJoCAMxQEwRK8QBcUFYiB8SjEQicUgKkoZkImJEhkxH5iJlyHKkEtmI1CJ7kIPICeQc0oHcQh4g3chr5BOKoVRUBzVGrdHhKAtloxFoAjoOzUQnoUXoPHQxWoHWoDvQBvQEegG9hnaiL9A+DGDqmB5mhjliLIyDRWOpWAYmxWZipVg5VoPVY83wOV/BOrEe7CNOxBk4E3eEKzgMT8T5+CR8Jr4Ir8S34Q14K34Ff4D34t8INIIRwYHgTeASRhMyCZMJJYRywhbCAcIpuJe6CO+IRKIe0YboCfdiCjGbOI24iLiWuIt4nNhBfETsI5FIBiQHki8pmsQjFZBKSGtIO0jHSJdJXaQPaupqpmouaiFqqWpitWK1crXtakfVLqs9VftM1iRbkb3J0WQBeSp5CXkzuZl8kdxF/kzRothQfCkJlGzKHEoFpZ5yinKX8kZdXd1c3Us9Vl2kPlu9Qn23+ln1B+ofqdpUeyqHOpYqoy6mbqUep96ivqHRaNa0AF
oqrYC2mFZLO0m7T/ugwdBw0uBqCDRmaVRpNGhc1nhJJ9Ot6Gz6eHoRvZy+j36R3qNJ1rTW5GjyNGdqVmke1Lyh2afF0BqhFa2Vp7VIa7vWOa1n2iRta+1gbYH2PO1N2ie1HzEwhgWDw+Az5jI2M04xunSIOjY6XJ1snTKdnTrtOr262rpuukm6U3SrdI/oduphetZ6XL1cvSV6e/Wu630aYjyEPUQ4ZOGQ+iGXh7zXH6ofoC/UL9XfpX9N/5MB0yDYIMdgmUGjwT1D3NDeMNZwsuE6w1OGPUN1hvoM5Q8tHbp36G0j1MjeKM5omtEmozajPmMT41BjifEa45PGPSZ6JgEm2SYrTY6adJsyTP1MRaYrTY+ZPmfqMtnMXGYFs5XZa2ZkFmYmM9to1m722dzGPNG82HyX+T0LigXLIsNipUWLRa+lqeUoy+mWdZa3rchWLKssq9VWZ6zeW9tYJ1vPt260fmajb8O1KbKps7lrS7P1t51kW2N71Y5ox7LLsVtrd8ketXe3z7Kvsr/ogDp4OIgc1jp0DCMM8xomHlYz7IYj1ZHtWOhY5/jASc8p0qnYqdHp5XDL4anDlw0/M/ybs7tzrvNm5zsjtEeEjyge0TzitYu9C9+lyuWqK801xHWWa5PrKzcHN6HbOreb7gz3Ue7z3Vvcv3p4ekg96j26PS090zyrPW+wdFgxrEWss14Er0CvWV6HvT56e3gXeO/1/svH0SfHZ7vPs5E2I4UjN4985Gvuy/Pd6Nvpx/RL89vg1+lv5s/zr/F/GGARIAjYEvCUbcfOZu9gvwx0DpQGHgh8z/HmzOAcD8KCQoNKg9qDtYMTgyuD74eYh2SG1IX0hrqHTgs9HkYIiwhbFnaDa8zlc2u5veGe4TPCWyOoEfERlREPI+0jpZHNo9BR4aNWjLobZRUljmqMBtHc6BXR92JsYibFHIolxsbEVsU+iRsRNz3uTDwjfkL89vh3CYEJSxLuJNomyhJbkuhJY5Nqk94nByUvT+4cPXz0jNEXUgxTRClNqaTUpNQtqX1jgsesGtM11n1sydjr42zGTRl3brzh+NzxRybQJ/Am7EsjpCWnbU/7wovm1fD60rnp1em9fA5/Nf+FIECwUtAt9BUuFz7N8M1YnvEs0zdzRWZ3ln9WeVaPiCOqFL3KDsten/0+Jzpna05/bnLurjy1vLS8g2JtcY64daLJxCkTOyQOkhJJ5yTvSasm9UojpFvykfxx+U0FOvBHvk1mK/tF9qDQr7Cq8MPkpMn7pmhNEU9pm2o/deHUp0UhRb9Nw6fxp7VMN5s+Z/qDGewZG2ciM9NntsyymDVvVtfs0Nnb5lDm5Mz5vdi5eHnx27nJc5vnGc+bPe/RL6G/1JVolEhLbsz3mb9+Ab5AtKB9oevCNQu/lQpKz5c5l5WXfVnEX3T+1xG/VvzavzhjcfsSjyXrlhKXipdeX+a/bNtyreVFyx+tGLWiYSVzZenKt6smrDpX7la+fjVltWx1Z0VkRdMayzVL13ypzKq8VhVYtavaqHph9fu1grWX1wWsq19vvL5s/acNog03N4ZubKixrinfRNxUuOnJ5qTNZ35j/Va7xXBL2ZavW8VbO7fFbWut9ayt3W60fUkdWier694xdselnUE7m+od6zfu0ttVthvslu1+vidtz/W9EXtb9rH21e+32l99gHGgtAFpmNrQ25jV2NmU0tRxMPxgS7NP84FDToe2HjY7XHVE98iSo5Sj8472Hys61ndccrznROaJRy0TWu6cHH3yamtsa/upiFNnT4ecPnmGfebYWd+zh895nzt4nnW+8YLHhYY297YDv7v/fqDdo73houfFpktel5o7RnYcvex/+cSVoCunr3KvXrgWda3jeuL1mzfG3ui8Kbj57FburVe3C29/vjP7LuFu6T3Ne+X3je7X/GH3x65Oj84jD4IetD2Mf3jnEf/Ri8f5j790zXtCe1L+1PRp7TOXZ4e7Q7ovPR/zvOuF5MXnnp
I/tf6sfmn7cv9fAX+19Y7u7XolfdX/etEbgzdb37q9bemL6bv/Lu/d5/elHww+bPvI+njmU/Knp58nfyF9qfhq97X5
}
},
"cell_type": "markdown",
"id": "fc935871-7640-41c6-b798-58514d860fe0",
"metadata": {},
"source": [
"## LLaMA2 chat with SQL\n",
"\n",
"LangChain makes it easy to build diverse chains from building blocks.\n",
"\n",
"![cake.png](attachment:02bd3e70-6db4-40e2-9f6b-c83127eb2339.png)\n",
"\n",
"This cookbook shows how to combine three ideas:\n",
"\n",
"* Chat\n",
"* SQL\n",
"* An OSS model (LLaMA2)\n",
"\n",
"It demonstrates chat over a SQLite DB containing 2023-24 NBA rosters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81adcf8b-395a-4f02-8749-ac976942b446",
"metadata": {},
"outputs": [],
"source": [
"%pip install langchain replicate"
]
},
{
"cell_type": "markdown",
"id": "8e13ed66-300b-4a23-b8ac-44df68ee4733",
"metadata": {},
"source": [
"## LLM\n",
"\n",
"Use the Replicate API to run `llama-2-13b-chat`.\n",
"\n",
"Note that LLaMA2 can also be run locally (see [here](https://python.langchain.com/docs/guides/local_llms))."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "416ecce7-8aec-4145-b3f1-587a9b8a4fe9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Init param `input` is deprecated, please use `model_kwargs` instead.\n"
]
}
],
"source": [
"import os\n",
"from getpass import getpass\n",
"\n",
"from langchain.llms import Replicate\n",
"\n",
"# If REPLICATE_API_TOKEN is not already set in the environment, uncomment the\n",
"# two lines below to enter it interactively (never hardcode the token here).\n",
"# REPLICATE_API_TOKEN = getpass()\n",
"# os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN\n",
"\n",
"# Replicate API: model name followed by the version identifier after the colon.\n",
"llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
"\n",
"# No system prompt is configured on the LLM itself. The \"No pre-amble\" system messages in the\n",
"# prompt templates further down are what ask LLaMA to return only the SQL statement instead of\n",
"# adding something like \"Sure! Here's the SQL query for the given input question: \" before it;\n",
"# otherwise custom parsing would be needed.\n",
"llm = Replicate(\n",
"    model=llama2_13b_chat,\n",
"    # `input` is a deprecated init param (\"please use `model_kwargs` instead\"),\n",
"    # so the generation settings are passed through `model_kwargs`.\n",
"    model_kwargs={\"temperature\": 0.01, \"max_length\": 500, \"top_p\": 1},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "80222165-f353-4e35-a123-5f70fd70c6c8",
"metadata": {},
"source": [
"## DB\n",
"\n",
"Connect to a SQLite DB."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "025bdd82-3bb1-4948-bc7c-c3ccd94fd05c",
"metadata": {},
"outputs": [],
"source": [
"from langchain.utilities import SQLDatabase\n",
"\n",
"# sample_rows_in_table_info=0: presumably keeps sample rows out of the table info\n",
"# returned by get_table_info() (inferred from the parameter name; confirm in the docs).\n",
"db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info=0)\n",
"\n",
"\n",
"def get_schema(_):\n",
"    \"\"\"Return the table info reported by `db`; the positional argument is ignored.\"\"\"\n",
"    table_info = db.get_table_info()\n",
"    return table_info\n",
"\n",
"\n",
"def run_query(query):\n",
"    \"\"\"Run `query` through `db.run` and return its result unchanged.\"\"\"\n",
"    result = db.run(query)\n",
"    return result"
]
},
{
"cell_type": "markdown",
"id": "654b3577-baa2-4e12-a393-f40e5db49ac7",
"metadata": {},
"source": [
"## Query a SQL DB \n",
"\n",
"Follow the runnables workflow [here](https://python.langchain.com/docs/expression_language/cookbook/sql_db)."
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "5a4933ea-d9c0-4b0a-8177-ba4490c6532b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" SELECT * FROM nba_roster WHERE NAME = 'Klay Thompson';\""
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Imports for the prompt, the output parser, and the runnable plumbing.\n",
"from langchain.chat_models import ChatOpenAI  # NOTE(review): not used by any chain in this notebook\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough\n",
"\n",
"# Prompt: the human turn has {schema} and {question} placeholders filled in at run time.\n",
"template = (\n",
"    \"Based on the table schema below, write a SQL query that would answer the user's question:\\n\"\n",
"    \"{schema}\\n\"\n",
"    \"\\n\"\n",
"    \"Question: {question}\\n\"\n",
"    \"SQL Query:\"\n",
")\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\"system\", \"Given an input question, convert it to a SQL query. No pre-amble.\"),\n",
"        (\"human\", template),\n",
"    ]\n",
")\n",
"\n",
"# Chain to query: add the schema to the input, render the prompt, call the LLM\n",
"# (bound with a stop sequence), and parse the result with StrOutputParser.\n",
"sql_response = (\n",
"    RunnablePassthrough.assign(schema=get_schema)\n",
"    | prompt\n",
"    | llm.bind(stop=[\"\\nSQLResult:\"])\n",
"    | StrOutputParser()\n",
")\n",
"\n",
"sql_response.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
},
{
"cell_type": "markdown",
"id": "a0e9e2c8-9b88-4853-ac86-001bc6cc6695",
"metadata": {},
"source": [
"The [LangSmith trace](https://smith.langchain.com/public/afa56a06-b4e2-469a-a60f-c1746e75e42b/r) gives us visibility into the chain! "
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "2a2825e3-c1b6-4f7d-b9c9-d9835de323bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" Sure! Here's the natural language response based on the given SQL query and response:\\n\\nThere are 30 unique teams in the NBA roster.\""
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chain to answer\n",
"# The human turn combines the schema, the question, the generated SQL, and the SQL result.\n",
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_messages([\n",
"    (\"system\", \"Given an input question and SQL response, convert it to a natural language answer. No pre-amble.\"),\n",
"    (\"human\", template)\n",
"])\n",
"\n",
"# Steps: (1) `sql_response` writes the SQL and stores it under \"query\";\n",
"# (2) the schema is added and the SQL is executed, stored under \"response\";\n",
"# (3) the answer prompt is rendered and sent to the LLM.\n",
"# NOTE(review): the model-generated SQL is executed as-is against `db`. That is fine for this\n",
"# local demo DB; use a read-only connection or validate the query before pointing this at real data.\n",
"full_chain = (\n",
"    RunnablePassthrough.assign(query=sql_response)\n",
"    | RunnablePassthrough.assign(\n",
"        schema=get_schema,\n",
"        response=lambda x: run_query(x[\"query\"]),\n",
"    )\n",
"    | prompt_response\n",
"    | llm\n",
")\n",
"\n",
"full_chain.invoke({\"question\": \"How many unique teams are there?\"})"
]
},
{
"cell_type": "markdown",
"id": "ec17b3ee-6618-4681-b6df-089bbb5ffcd7",
"metadata": {},
"source": [
"Again, the [LangSmith trace](https://smith.langchain.com/public/10420721-746a-4806-8ecf-d6dc6399d739/r) gives us visibility into the chain."
]
},
{
"cell_type": "markdown",
"id": "1e85381b-1edc-4bb3-a7bd-2ab23f81e54d",
"metadata": {},
"source": [
"## Chat with a SQL DB \n",
"\n",
"Next, we can add memory, so that follow-up questions can refer back to earlier turns of the conversation."
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "1985aa1c-eb8f-4fb1-a54f-c8aa10744687",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"SELECT Team \\nFROM nba_roster \\nWHERE NAME = 'Klay Thompson'\""
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chain to query with memory\n",
"# Imports come first so the cell runs on a fresh kernel.\n",
"from langchain.memory import ConversationBufferMemory\n",
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain.schema.runnable import RunnableLambda, GetLocalVar, PutLocalVar  # NOTE(review): GetLocalVar/PutLocalVar are unused below\n",
"\n",
"# Prompt: same as the stateless SQL prompt, plus a placeholder where prior turns are inserted.\n",
"template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query:\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages([\n",
"    (\"system\", \"Given an input question, convert it to a SQL query. No pre-amble.\"),\n",
"    MessagesPlaceholder(variable_name=\"history\"),\n",
"    (\"human\", template)\n",
"])\n",
"\n",
"memory = ConversationBufferMemory(return_messages=True)\n",
"\n",
"# Uses `llm` (the Replicate LLaMA2 model defined above); `model` was never defined in this notebook.\n",
"sql_chain = (\n",
"    RunnablePassthrough.assign(\n",
"        schema=get_schema,\n",
"        history=RunnableLambda(lambda x: memory.load_memory_variables(x)[\"history\"])\n",
"    )\n",
"    | prompt\n",
"    | llm.bind(stop=[\"\\nSQLResult:\"])\n",
"    | StrOutputParser()\n",
")\n",
"\n",
"def save(input_output):\n",
"    \"\"\"Store one exchange in `memory` and return the chain output.\n",
"\n",
"    `input_output` is the chain input dict with the generated text added under\n",
"    the \"output\" key. Everything except \"output\" is saved as the input side of\n",
"    the exchange. The dict passed in is left unmodified.\n",
"    \"\"\"\n",
"    output = {\"output\": input_output[\"output\"]}\n",
"    inputs = {key: value for key, value in input_output.items() if key != \"output\"}\n",
"    memory.save_context(inputs, output)\n",
"    return output[\"output\"]\n",
"\n",
"sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save\n",
"sql_response_memory.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "0b45818a-1498-441d-b82d-23c29428c2bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"SELECT SALARY \\nFROM nba_roster \\nWHERE NAME = 'Klay Thompson'\""
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Follow-up turn: the question itself does not say who \"his\" refers to;\n",
"# the chain loads the earlier exchange from `memory` into the prompt history.\n",
"sql_response_memory.invoke(\n",
"    {\"question\": \"What is his salary?\"}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "800a7a3b-f411-478b-af51-2310cd6e0425",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" Sure thing! Here's the natural language response based on the given SQL query and response:\\n\\nKlay Thompson plays for the Golden State Warriors.\""
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chain to answer\n",
"# The human turn combines the schema, the question, the generated SQL, and the SQL result.\n",
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_messages([\n",
"    (\"system\", \"Given an input question and SQL response, convert it to a natural language answer. No pre-amble.\"),\n",
"    (\"human\", template)\n",
"])\n",
"\n",
"# Same shape as the stateless answer chain, but the SQL comes from `sql_response_memory`,\n",
"# which also saves each question/SQL exchange into `memory` as a side effect.\n",
"# NOTE(review): the model-generated SQL is executed as-is against `db`. That is fine for this\n",
"# local demo DB; use a read-only connection or validate the query before pointing this at real data.\n",
"full_chain = (\n",
"    RunnablePassthrough.assign(query=sql_response_memory)\n",
"    | RunnablePassthrough.assign(\n",
"        schema=get_schema,\n",
"        response=lambda x: run_query(x[\"query\"]),\n",
"    )\n",
"    | prompt_response\n",
"    | llm\n",
")\n",
"\n",
"full_chain.invoke({\"question\": \"What team is Klay Thompson on?\"})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}