mirror of
https://github.com/hwchase17/langchain
synced 2024-11-11 19:11:02 +00:00
412fa4e1db
Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
933 lines
22 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9a9acd2e",
   "metadata": {},
   "source": [
    "# Cookbook\n",
    "\n",
    "In this notebook we'll take a look at a few common types of sequences to create."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93aa2c87",
   "metadata": {},
   "source": [
    "## PromptTemplate + LLM\n",
    "\n",
    "A PromptTemplate -> LLM is a core chain that is used in most other larger chains/systems."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "466b65b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.chat_models import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3c634ef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ChatOpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d1850a1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = ChatPromptTemplate.from_template(\"tell me a joke about {foo}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "56d0669f",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = prompt | model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e3d0a6cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='Why don\\'t bears use cell phones? \\n\\nBecause they always get terrible \"grizzly\" reception!', additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"foo\": \"bears\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7eb9ef50",
   "metadata": {},
   "source": [
    "Often we want to attach keyword arguments to the model that's passed in. Here are a few examples of that:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b1d8f88",
   "metadata": {},
   "source": [
    "### Attaching Stop Sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "562a06bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = prompt | model.bind(stop=[\"\\n\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "43f5d04c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"Why don't bears use cell phones?\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"foo\": \"bears\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3eaf88a",
   "metadata": {},
   "source": [
    "### Attaching Function Call information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f94b71b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "functions = [\n",
    "    {\n",
    "        \"name\": \"joke\",\n",
    "        \"description\": \"A joke\",\n",
    "        \"parameters\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {\n",
    "                \"setup\": {\n",
    "                    \"type\": \"string\",\n",
    "                    \"description\": \"The setup for the joke\"\n",
    "                },\n",
    "                \"punchline\": {\n",
    "                    \"type\": \"string\",\n",
    "                    \"description\": \"The punchline for the joke\"\n",
    "                }\n",
    "            },\n",
    "            \"required\": [\"setup\", \"punchline\"]\n",
    "        }\n",
    "    }\n",
    "]\n",
    "chain = prompt | model.bind(function_call={\"name\": \"joke\"}, functions=functions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "decf7710",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'joke', 'arguments': '{\\n \"setup\": \"Why don\\'t bears wear shoes?\",\\n \"punchline\": \"Because they have bear feet!\"\\n}'}}, example=False)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"foo\": \"bears\"}, config={})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9098c5ed",
   "metadata": {},
   "source": [
    "## PromptTemplate + LLM + OutputParser\n",
    "\n",
    "We can also add in an output parser to easily transform the raw LLM/ChatModel output into a more workable format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f799664d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.schema.output_parser import StrOutputParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cc194c78",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = prompt | model | StrOutputParser()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77acf448",
   "metadata": {},
   "source": [
    "Notice that this now returns a string - a much more workable format for downstream tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e3d69a18",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Why don't bears wear shoes?\\n\\nBecause they have bear feet!\""
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"foo\": \"bears\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c01864e5",
   "metadata": {},
   "source": [
    "### Functions Output Parser\n",
    "\n",
    "When you specify the function to return, you may just want to parse that output directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ad0dd88e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser\n",
    "\n",
    "chain = (\n",
    "    prompt\n",
    "    | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
    "    | JsonOutputFunctionsParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1e7aa8eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'setup': \"Why don't bears wear shoes?\",\n",
       " 'punchline': 'Because they have bear feet!'}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"foo\": \"bears\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d4aa1a01",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser\n",
    "\n",
    "chain = (\n",
    "    prompt\n",
    "    | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
    "    | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8b6df9ba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Why don't bears like fast food?\""
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"foo\": \"bears\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ed58136",
   "metadata": {},
   "source": [
    "## Passthroughs and itemgetter\n",
    "\n",
    "Often when constructing a chain you may want to pass along original input variables to future steps in the chain. How you do this depends on what exactly the input is:\n",
    "\n",
    "- If the original input was a string, then you likely just want to pass along the string. This can be done with `RunnablePassthrough`. For an example of this, see `LLMChain + Retriever`.\n",
    "- If the original input was a dictionary, then you likely want to pass along specific keys. This can be done with `itemgetter`. For an example of this, see `Multiple LLM Chains`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5d3d8ffe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from operator import itemgetter\n",
    "\n",
    "from langchain.schema.runnable import RunnablePassthrough"
   ]
  },
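  {
   "cell_type": "markdown",
   "id": "b7a4c1d2",
   "metadata": {},
   "source": [
    "Before wiring these into a chain, here's a minimal standalone sketch of what the two helpers do (no LLM call needed): `RunnablePassthrough` forwards its input unchanged, while `itemgetter(\"key\")` plucks a single key out of a dict input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f2a3b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# RunnablePassthrough is the identity runnable: input in, same input out.\n",
    "RunnablePassthrough().invoke(\"hello\")  # returns 'hello'\n",
    "\n",
    "# itemgetter returns a plain callable that pulls one key from a dict;\n",
    "# used inside a chain, it is coerced into a runnable step.\n",
    "itemgetter(\"question\")({\"question\": \"where did harrison work\"})  # returns 'where did harrison work'"
   ]
  },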
  {
   "cell_type": "markdown",
   "id": "91c5ef3d",
   "metadata": {},
   "source": [
    "## LLMChain + Retriever\n",
    "\n",
    "Let's now look at adding in a retrieval step, which gives us a \"retrieval-augmented generation\" chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "33be32af",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.vectorstores import Chroma\n",
    "from langchain.embeddings import OpenAIEmbeddings\n",
    "from langchain.schema.runnable import RunnablePassthrough"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "df3f3fa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the retriever\n",
    "vectorstore = Chroma.from_texts([\"harrison worked at kensho\"], embedding=OpenAIEmbeddings())\n",
    "retriever = vectorstore.as_retriever()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "bfc47ec1",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Answer the question based only on the following context:\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "\"\"\"\n",
    "prompt = ChatPromptTemplate.from_template(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "eae31755",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = (\n",
    "    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "f3040b0c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Number of requested results 4 is greater than number of elements in index 1, updating n_results = 1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Harrison worked at Kensho.'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke(\"where did harrison work?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e1d20c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Answer the question based only on the following context:\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "\n",
    "Answer in the following language: {language}\n",
    "\"\"\"\n",
    "prompt = ChatPromptTemplate.from_template(template)\n",
    "\n",
    "chain = {\n",
    "    \"context\": itemgetter(\"question\") | retriever,\n",
    "    \"question\": itemgetter(\"question\"),\n",
    "    \"language\": itemgetter(\"language\")\n",
    "} | prompt | model | StrOutputParser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7ee8b2d4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Number of requested results 4 is greater than number of elements in index 1, updating n_results = 1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Harrison ha lavorato a Kensho.'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f2bf8d3",
   "metadata": {},
   "source": [
    "## Multiple LLM Chains\n",
    "\n",
    "This pattern can also be used to string together multiple LLMChains."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "d65d4e9e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'El país en el que nació la ciudad de Honolulu, Hawái, donde nació Barack Obama, el 44º presidente de los Estados Unidos, es Estados Unidos.'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from operator import itemgetter\n",
    "\n",
    "prompt1 = ChatPromptTemplate.from_template(\"what is the city {person} is from?\")\n",
    "prompt2 = ChatPromptTemplate.from_template(\"what country is the city {city} in? respond in {language}\")\n",
    "\n",
    "chain1 = prompt1 | model | StrOutputParser()\n",
    "\n",
    "chain2 = {\"city\": chain1, \"language\": itemgetter(\"language\")} | prompt2 | model | StrOutputParser()\n",
    "\n",
    "chain2.invoke({\"person\": \"obama\", \"language\": \"spanish\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "878f8176",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.schema.runnable import RunnableMap\n",
    "\n",
    "prompt1 = ChatPromptTemplate.from_template(\"generate a random color\")\n",
    "prompt2 = ChatPromptTemplate.from_template(\"what is a fruit of color: {color}\")\n",
    "prompt3 = ChatPromptTemplate.from_template(\"what country's flag has the color: {color}\")\n",
    "prompt4 = ChatPromptTemplate.from_template(\"What is the color of {fruit} and {country}\")\n",
    "\n",
    "chain1 = prompt1 | model | StrOutputParser()\n",
    "chain2 = RunnableMap(steps={\"color\": chain1}) | {\n",
    "    \"fruit\": prompt2 | model | StrOutputParser(),\n",
    "    \"country\": prompt3 | model | StrOutputParser(),\n",
    "} | prompt4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "d621a870",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ChatPromptValue(messages=[HumanMessage(content=\"What is the color of A fruit that has a color similar to #7E7DE6 is the Peruvian Apple Cactus (Cereus repandus). It is a tropical fruit with a vibrant purple or violet exterior. and The country's flag that has the color #7E7DE6 is North Macedonia.\", additional_kwargs={}, example=False)])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain2.invoke({})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbc4bf6e",
   "metadata": {},
   "source": [
    "## Arbitrary Functions\n",
    "\n",
    "You can use arbitrary functions in the pipeline.\n",
    "\n",
    "Note that all inputs to these functions need to be a SINGLE argument. If you have a function that accepts multiple arguments, you should write a wrapper that accepts a single input and unpacks it into multiple arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "6bb221b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.schema.runnable import RunnableLambda\n",
    "\n",
    "def length_function(text):\n",
    "    return len(text)\n",
    "\n",
    "def _multiple_length_function(text1, text2):\n",
    "    return len(text1) * len(text2)\n",
    "\n",
    "def multiple_length_function(_dict):\n",
    "    return _multiple_length_function(_dict[\"text1\"], _dict[\"text2\"])\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(\"what is {a} + {b}\")\n",
    "\n",
    "chain1 = prompt | model\n",
    "\n",
    "chain = {\n",
    "    \"a\": itemgetter(\"foo\") | RunnableLambda(length_function),\n",
    "    \"b\": {\"text1\": itemgetter(\"foo\"), \"text2\": itemgetter(\"bar\")} | RunnableLambda(multiple_length_function)\n",
    "} | prompt | model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "5488ec85",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='3 + 9 is equal to 12.', additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"foo\": \"bar\", \"bar\": \"gah\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "506e9636",
   "metadata": {},
   "source": [
    "## SQL Database\n",
    "\n",
    "We can also try to replicate our SQLDatabaseChain using this style."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "7a927516",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
    "{schema}\n",
    "\n",
    "Question: {question}\"\"\"\n",
    "prompt = ChatPromptTemplate.from_template(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "3f51f386",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.utilities import SQLDatabase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "2ccca6fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "05ba88ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_schema(_):\n",
    "    return db.get_table_info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "a4eda902",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_query(query):\n",
    "    return db.run(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "5046cb17",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = {\n",
    "    \"schema\": RunnableLambda(get_schema),\n",
    "    \"question\": itemgetter(\"question\")\n",
    "}\n",
    "sql_response = (\n",
    "    RunnableMap(inputs)\n",
    "    | prompt\n",
    "    | model.bind(stop=[\"\\nSQLResult:\"])\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "a5552039",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'SELECT COUNT(*) \\nFROM Employee;'"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_response.invoke({\"question\": \"How many employees are there?\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "d6fee130",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
    "{schema}\n",
    "\n",
    "Question: {question}\n",
    "SQL Query: {query}\n",
    "SQL Response: {response}\"\"\"\n",
    "prompt_response = ChatPromptTemplate.from_template(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "923aa634",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_chain = (\n",
    "    RunnableMap({\n",
    "        \"question\": itemgetter(\"question\"),\n",
    "        \"query\": sql_response,\n",
    "    })\n",
    "    | {\n",
    "        \"schema\": RunnableLambda(get_schema),\n",
    "        \"question\": itemgetter(\"question\"),\n",
    "        \"query\": itemgetter(\"query\"),\n",
    "        \"response\": lambda x: db.run(x[\"query\"])\n",
    "    }\n",
    "    | prompt_response\n",
    "    | model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "e94963d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='There are 8 employees.', additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_chain.invoke({\"question\": \"How many employees are there?\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f09fd305",
   "metadata": {},
   "source": [
    "## Code Writing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "bd7c259a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.utilities import PythonREPL\n",
    "from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "73795d2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Write some python code to solve the user's problem.\n",
    "\n",
    "Return only python code in Markdown format, e.g.:\n",
    "\n",
    "```python\n",
    "....\n",
    "```\"\"\"\n",
    "prompt = ChatPromptTemplate(messages=[\n",
    "    SystemMessagePromptTemplate.from_template(template),\n",
    "    HumanMessagePromptTemplate.from_template(\"{input}\")\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "42859e8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _sanitize_output(text: str):\n",
    "    _, after = text.split(\"```python\")\n",
    "    return after.split(\"```\")[0]"
   ]
  },
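  {
   "cell_type": "markdown",
   "id": "f1a2b3c4",
   "metadata": {},
   "source": [
    "As a quick sanity check on a hand-written string (a minimal sketch, no LLM involved), `_sanitize_output` keeps only what sits inside the first fenced `python` block:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5b6c7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything before the opening fence and after the closing fence is dropped.\n",
    "_sanitize_output(\"Sure!\\n```python\\nprint(2 + 2)\\n```\\nDone.\")  # returns '\\nprint(2 + 2)\\n'"
   ]
  },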
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "5ded1a86",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = prompt | model | StrOutputParser() | _sanitize_output | PythonREPL().run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "208c2b75",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Python REPL can execute arbitrary code. Use with caution.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'4\\n'"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"input\": \"whats 2 plus 2\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9be88499",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}