mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
402 lines
12 KiB
Plaintext
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fc935871-7640-41c6-b798-58514d860fe0",
   "metadata": {},
   "source": [
    "## LLaMA2 chat with SQL\n",
    "\n",
    "Open-source, local LLMs are worth considering for any application that demands data privacy.\n",
    "\n",
    "Text-to-SQL over a private database is one good example.\n",
    "\n",
    "This cookbook shows how to perform text-to-SQL using several versions of LLaMA2 running locally.\n",
    "\n",
    "## Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81adcf8b-395a-4f02-8749-ac976942b446",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install langchain replicate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e13ed66-300b-4a23-b8ac-44df68ee4733",
   "metadata": {},
   "source": [
    "## LLM\n",
    "\n",
    "There are a few ways to access LLaMA2.\n",
    "\n",
    "To run locally, we use Ollama.ai.\n",
    "\n",
    "See [here](https://python.langchain.com/docs/integrations/chat/ollama) for details on installation and setup.\n",
    "\n",
    "Also, see [here](https://python.langchain.com/docs/guides/local_llms) for our full guide on local LLMs.\n",
    "\n",
    "Alternatively, to use an external API (which is not private), we can use Replicate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6a75a5c6-34ee-4ab9-a664-d9b432d812ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local\n",
    "from langchain.chat_models import ChatOllama\n",
    "\n",
    "llama2_chat = ChatOllama(model=\"llama2:13b-chat\")\n",
    "llama2_code = ChatOllama(model=\"codellama:7b-instruct\")\n",
    "\n",
    "# API (not private): uncomment the two token lines to authenticate with Replicate\n",
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "from langchain.llms import Replicate\n",
    "\n",
    "# REPLICATE_API_TOKEN = getpass()\n",
    "# os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN\n",
    "replicate_id = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
    "llama2_chat_replicate = Replicate(\n",
    "    model=replicate_id,\n",
    "    # The `input` init param is deprecated; generation parameters go in `model_kwargs`\n",
    "    model_kwargs={\"temperature\": 0.01, \"max_length\": 500, \"top_p\": 1},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ce96f7ea-b3d5-44e1-9fa5-a79e04a9e1fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the LLM to use; swap in llama2_code or llama2_chat_replicate to compare\n",
    "llm = llama2_chat"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80222165-f353-4e35-a123-5f70fd70c6c8",
   "metadata": {},
   "source": [
    "## DB\n",
    "\n",
    "Connect to a SQLite DB.\n",
    "\n",
    "To create this particular DB, you can use the code and follow the steps shown [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/StructuredLlama.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "025bdd82-3bb1-4948-bc7c-c3ccd94fd05c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.utilities import SQLDatabase\n",
    "\n",
    "# sample_rows_in_table_info=0: the schema shown to the LLM contains no example rows\n",
    "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info=0)\n",
    "\n",
    "\n",
    "def get_schema(_):\n",
    "    # The argument is ignored; it lets this function be used as a step in a runnable chain\n",
    "    return db.get_table_info()\n",
    "\n",
    "\n",
    "def run_query(query):\n",
    "    return db.run(query)"
   ]
  },
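  {
   "cell_type": "markdown",
   "id": "toy-nba-roster-db-note",
   "metadata": {},
   "source": [
    "If `nba_roster.db` is not available, the next cell is a minimal sketch that builds a small stand-in with Python's built-in `sqlite3` module. It is an assumption and is not the database from the llama-recipes notebook. The table name `nba_roster` and the columns `Team`, `NAME` and `SALARY` are inferred from the SQL the model generates below, and the file name `toy_nba_roster.db` is hypothetical. The two rows are illustrative; the Klay Thompson salary is the value reported later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-nba-roster-db-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "\n",
    "# Hypothetical stand-in for nba_roster.db: column names are inferred from the\n",
    "# queries generated in this notebook, and the rows are illustrative only.\n",
    "TOY_DB_PATH = 'toy_nba_roster.db'\n",
    "\n",
    "conn = sqlite3.connect(TOY_DB_PATH)\n",
    "conn.execute('DROP TABLE IF EXISTS nba_roster')\n",
    "conn.execute('CREATE TABLE nba_roster (Team TEXT, NAME TEXT, SALARY TEXT)')\n",
    "conn.executemany(\n",
    "    'INSERT INTO nba_roster (Team, NAME, SALARY) VALUES (?, ?, ?)',\n",
    "    [\n",
    "        ('Golden State Warriors', 'Klay Thompson', '$43,219,440'),\n",
    "        ('Example Team', 'Example Player', '$1,000,000'),\n",
    "    ],\n",
    ")\n",
    "conn.commit()\n",
    "row_count = conn.execute('SELECT COUNT(*) FROM nba_roster').fetchone()[0]\n",
    "conn.close()\n",
    "print(row_count)\n",
    "\n",
    "# To try the chains against it, point SQLDatabase at the toy file instead:\n",
    "# db = SQLDatabase.from_uri('sqlite:///toy_nba_roster.db', sample_rows_in_table_info=0)"
   ]
  },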
  {
   "cell_type": "markdown",
   "id": "654b3577-baa2-4e12-a393-f40e5db49ac7",
   "metadata": {},
   "source": [
    "## Query a SQL DB\n",
    "\n",
    "Follow the runnables workflow [here](https://python.langchain.com/docs/expression_language/cookbook/sql_db)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5a4933ea-d9c0-4b0a-8177-ba4490c6532b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' SELECT \"Team\" FROM nba_roster WHERE \"NAME\" = \\'Klay Thompson\\';'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prompt\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "\n",
    "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
    "{schema}\n",
    "\n",
    "Question: {question}\n",
    "SQL Query:\"\"\"\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", \"Given an input question, convert it to a SQL query. No preamble.\"),\n",
    "        (\"human\", template),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Chain to query\n",
    "from langchain.schema.output_parser import StrOutputParser\n",
    "from langchain.schema.runnable import RunnablePassthrough\n",
    "\n",
    "sql_response = (\n",
    "    RunnablePassthrough.assign(schema=get_schema)\n",
    "    | prompt\n",
    "    | llm.bind(stop=[\"\\nSQLResult:\"])\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "sql_response.invoke({\"question\": \"What team is Klay Thompson on?\"})"
   ]
  },
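  {
   "cell_type": "markdown",
   "id": "clean-sql-helper-note",
   "metadata": {},
   "source": [
    "The query string comes back with a leading space, and local models do not always follow `No preamble`. Some wrap the SQL in a Markdown code fence or add commentary after it. The next cell is a hedged sketch of a hypothetical post-processing helper, `clean_sql` (not a LangChain API), that trims such output before it reaches `db.run`. It splits on `SQLResult:` to roughly mimic the stop sequence bound above, drops code-fence markers, and keeps only the first statement. The naive split on `;` would break if a string literal in the query contained a semicolon.\n",
    "\n",
    "One way to wire it in, untested here, is to append `| RunnableLambda(clean_sql)` to the end of `sql_response`. `RunnableLambda` is the wrapper imported later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "clean-sql-helper-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper (not part of LangChain): tidy raw model output before\n",
    "# passing it to db.run.\n",
    "FENCE = '`' * 3  # a Markdown code-fence marker\n",
    "\n",
    "\n",
    "def clean_sql(raw: str) -> str:\n",
    "    # Roughly mimic the stop sequence: drop anything from 'SQLResult:' onward\n",
    "    text = raw.split('SQLResult:')[0]\n",
    "    # Remove Markdown fence markers, including an opening fence tagged 'sql'\n",
    "    text = text.replace(FENCE + 'sql', '').replace(FENCE, '').strip()\n",
    "    # Keep only the first statement, plus its terminating semicolon if present\n",
    "    first, sep, _rest = text.partition(';')\n",
    "    return first.strip() + sep\n",
    "\n",
    "\n",
    "print(clean_sql(' SELECT COUNT(DISTINCT Team) FROM nba_roster;'))\n",
    "print(clean_sql(FENCE + 'sql SELECT Team FROM nba_roster; ' + FENCE + ' This lists the teams.'))"
   ]
  },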
  {
   "cell_type": "markdown",
   "id": "a0e9e2c8-9b88-4853-ac86-001bc6cc6695",
   "metadata": {},
   "source": [
    "We can review the results:\n",
    "\n",
    "* [LangSmith trace](https://smith.langchain.com/public/afa56a06-b4e2-469a-a60f-c1746e75e42b/r): LLaMA2-13B, Replicate API\n",
    "* [LangSmith trace](https://smith.langchain.com/public/2d4ecc72-6b8f-4523-8f0b-ea95c6b54a1d/r): LLaMA2-13B, local\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2a2825e3-c1b6-4f7d-b9c9-d9835de323bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=' Based on the table schema and SQL query, there are 30 unique teams in the NBA.')"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chain to answer\n",
    "template = \"\"\"Based on the table schema below, question, SQL query, and SQL response, write a natural language response:\n",
    "{schema}\n",
    "\n",
    "Question: {question}\n",
    "SQL Query: {query}\n",
    "SQL Response: {response}\"\"\"\n",
    "prompt_response = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", \"Given an input question and SQL response, convert it to a natural language answer. No preamble.\"),\n",
    "        (\"human\", template),\n",
    "    ]\n",
    ")\n",
    "\n",
    "full_chain = (\n",
    "    RunnablePassthrough.assign(query=sql_response)\n",
    "    | RunnablePassthrough.assign(\n",
    "        schema=get_schema,\n",
    "        response=lambda x: db.run(x[\"query\"]),\n",
    "    )\n",
    "    | prompt_response\n",
    "    | llm\n",
    ")\n",
    "\n",
    "full_chain.invoke({\"question\": \"How many unique teams are there?\"})"
   ]
  },
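  {
   "cell_type": "markdown",
   "id": "read-only-guard-note",
   "metadata": {},
   "source": [
    "In `full_chain`, whatever SQL the model writes is passed straight to `db.run`. For a local, private setup it may be worth adding a guard, and the next cell is a sketch of one. It relies on two hypothetical helpers, `looks_read_only` and `run_read_only` (not LangChain APIs). They reject anything that does not start with `SELECT`, then open the SQLite file through a `mode=ro` URI so that SQLite itself refuses writes. The keyword check is only a heuristic; the read-only connection is the stronger of the two. The demo uses a throwaway database in a temporary directory rather than `nba_roster.db`.\n",
    "\n",
    "A guard like this could replace the `response=` step in the chain. SQLAlchemy, which `SQLDatabase.from_uri` builds on, documents a similar read-only form for SQLite URLs (`?mode=ro&uri=true`). Check its SQLite dialect docs before relying on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "read-only-guard-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sqlite3\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def looks_read_only(sql: str) -> bool:\n",
    "    # Heuristic only: accept statements whose first keyword is SELECT\n",
    "    words = sql.split()\n",
    "    return bool(words) and words[0].upper() == 'SELECT'\n",
    "\n",
    "\n",
    "def run_read_only(db_path: str, sql: str):\n",
    "    if not looks_read_only(sql):\n",
    "        raise ValueError('Refusing to run non-SELECT statement: ' + sql)\n",
    "    # mode=ro: SQLite rejects any write made through this connection\n",
    "    uri = Path(db_path).resolve().as_uri() + '?mode=ro'\n",
    "    conn = sqlite3.connect(uri, uri=True)\n",
    "    try:\n",
    "        return conn.execute(sql).fetchall()\n",
    "    finally:\n",
    "        conn.close()\n",
    "\n",
    "\n",
    "# Throwaway demo database with hypothetical contents\n",
    "demo_path = os.path.join(tempfile.mkdtemp(), 'demo.db')\n",
    "setup = sqlite3.connect(demo_path)\n",
    "setup.execute('CREATE TABLE nba_roster (Team TEXT, NAME TEXT)')\n",
    "setup.execute('INSERT INTO nba_roster VALUES (?, ?)', ('Example Team', 'Example Player'))\n",
    "setup.commit()\n",
    "setup.close()\n",
    "\n",
    "print(run_read_only(demo_path, 'SELECT COUNT(DISTINCT Team) FROM nba_roster'))\n",
    "try:\n",
    "    run_read_only(demo_path, 'DROP TABLE nba_roster')\n",
    "except ValueError as err:\n",
    "    print(err)"
   ]
  },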
  {
   "cell_type": "markdown",
   "id": "ec17b3ee-6618-4681-b6df-089bbb5ffcd7",
   "metadata": {},
   "source": [
    "We can review the results:\n",
    "\n",
    "* [LangSmith trace](https://smith.langchain.com/public/10420721-746a-4806-8ecf-d6dc6399d739/r): LLaMA2-13B, Replicate API\n",
    "* [LangSmith trace](https://smith.langchain.com/public/5265ebab-0a22-4f37-936b-3300f2dfa1c1/r): LLaMA2-13B, local"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e85381b-1edc-4bb3-a7bd-2ab23f81e54d",
   "metadata": {},
   "source": [
    "## Chat with a SQL DB\n",
    "\n",
    "Next, we add conversation memory so that follow-up questions (such as \"What is his salary?\") can refer back to earlier turns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1985aa1c-eb8f-4fb1-a54f-c8aa10744687",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' SELECT \"Team\" FROM nba_roster WHERE \"NAME\" = \\'Klay Thompson\\';'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prompt\n",
    "from langchain.memory import ConversationBufferMemory\n",
    "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "\n",
    "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
    "{schema}\n",
    "\n",
    "Question: {question}\n",
    "SQL Query:\"\"\"\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", \"Given an input question, convert it to a SQL query. No preamble.\"),\n",
    "        MessagesPlaceholder(variable_name=\"history\"),\n",
    "        (\"human\", template),\n",
    "    ]\n",
    ")\n",
    "\n",
    "memory = ConversationBufferMemory(return_messages=True)\n",
    "\n",
    "# Chain to query with memory\n",
    "from langchain.schema.runnable import RunnableLambda\n",
    "\n",
    "sql_chain = (\n",
    "    RunnablePassthrough.assign(\n",
    "        schema=get_schema,\n",
    "        history=RunnableLambda(lambda x: memory.load_memory_variables(x)[\"history\"]),\n",
    "    )\n",
    "    | prompt\n",
    "    | llm.bind(stop=[\"\\nSQLResult:\"])\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "\n",
    "def save(input_output):\n",
    "    # Pop the generated SQL, store the (question, SQL) turn in memory, return the SQL\n",
    "    output = {\"output\": input_output.pop(\"output\")}\n",
    "    memory.save_context(input_output, output)\n",
    "    return output[\"output\"]\n",
    "\n",
    "\n",
    "sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save\n",
    "sql_response_memory.invoke({\"question\": \"What team is Klay Thompson on?\"})"
   ]
  },
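  {
   "cell_type": "markdown",
   "id": "toy-memory-note",
   "metadata": {},
   "source": [
    "The next cell makes the data flow of `save` concrete. It uses a simplified, hypothetical stand-in for `ConversationBufferMemory`. `ToyBufferMemory` is not a LangChain class, and it stores `(role, text)` tuples, whereas LangChain's memory stores `HumanMessage`/`AIMessage` objects when `return_messages=True`. The cell shows that `save` pops the generated SQL out of the dict, records the question and the SQL as one turn, and returns just the SQL string. That stored turn fills the `history` placeholder on the next call, so the model can resolve the follow-up `What is his salary?` to Klay Thompson."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-memory-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simplified, hypothetical stand-in for ConversationBufferMemory, only to\n",
    "# illustrate the data flow of save(); it stores (role, text) tuples.\n",
    "class ToyBufferMemory:\n",
    "    def __init__(self):\n",
    "        self.turns = []\n",
    "\n",
    "    def save_context(self, inputs: dict, outputs: dict) -> None:\n",
    "        # Assumes exactly one input key and one output key, as in save() above\n",
    "        (question,) = inputs.values()\n",
    "        (answer,) = outputs.values()\n",
    "        self.turns.append(('human', question))\n",
    "        self.turns.append(('ai', answer))\n",
    "\n",
    "    def load_memory_variables(self, _inputs: dict) -> dict:\n",
    "        return {'history': list(self.turns)}\n",
    "\n",
    "\n",
    "toy_memory = ToyBufferMemory()\n",
    "\n",
    "\n",
    "def toy_save(input_output: dict) -> str:\n",
    "    # Same shape as save() above, but writing to the toy memory\n",
    "    output = {'output': input_output.pop('output')}\n",
    "    toy_memory.save_context(input_output, output)\n",
    "    return output['output']\n",
    "\n",
    "\n",
    "returned_sql = toy_save({'question': 'What team is Klay Thompson on?', 'output': 'SELECT Team FROM nba_roster'})\n",
    "print(returned_sql)\n",
    "print(toy_memory.load_memory_variables({}))"
   ]
  },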
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0b45818a-1498-441d-b82d-23c29428c2bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' SELECT \"SALARY\" FROM nba_roster WHERE \"NAME\" = \\'Klay Thompson\\';'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_response_memory.invoke({\"question\": \"What is his salary?\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "800a7a3b-f411-478b-af51-2310cd6e0425",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=' Sure! Here\\'s the natural language response based on the given input:\\n\\n\"Klay Thompson\\'s salary is $43,219,440.\"')"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chain to answer\n",
    "template = \"\"\"Based on the table schema below, question, SQL query, and SQL response, write a natural language response:\n",
    "{schema}\n",
    "\n",
    "Question: {question}\n",
    "SQL Query: {query}\n",
    "SQL Response: {response}\"\"\"\n",
    "prompt_response = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", \"Given an input question and SQL response, convert it to a natural language answer. No preamble.\"),\n",
    "        (\"human\", template),\n",
    "    ]\n",
    ")\n",
    "\n",
    "full_chain = (\n",
    "    RunnablePassthrough.assign(query=sql_response_memory)\n",
    "    | RunnablePassthrough.assign(\n",
    "        schema=get_schema,\n",
    "        response=lambda x: db.run(x[\"query\"]),\n",
    "    )\n",
    "    | prompt_response\n",
    "    | llm\n",
    ")\n",
    "\n",
    "full_chain.invoke({\"question\": \"What is his salary?\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b77fee61-f4da-4bb1-8285-14101e505518",
   "metadata": {},
   "source": [
    "Here is the [trace](https://smith.langchain.com/public/54794d18-2337-4ce2-8b9f-3d8a2df89e51/r)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}