{
"cells": [
{
"cell_type": "markdown",
"id": "9a9acd2e",
"metadata": {},
"source": [
"# Cookbook\n",
"\n",
"This notebook walks through a few common types of sequences you can compose with the LangChain Expression Language."
]
},
{
"cell_type": "markdown",
"id": "93aa2c87",
"metadata": {},
"source": [
"## PromptTemplate + LLM\n",
"\n",
"Piping a PromptTemplate into an LLM is the core chain that most larger chains and systems are built on."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "466b65b3",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3c634ef0",
"metadata": {},
"outputs": [],
"source": [
"model = ChatOpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d1850a1f",
"metadata": {},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate.from_template(\"tell me a joke about {foo}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "56d0669f",
"metadata": {},
"outputs": [],
"source": [
"chain = prompt | model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e3d0a6cd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Why don\\'t bears use cell phones? \\n\\nBecause they always get terrible \"grizzly\" reception!', additional_kwargs={}, example=False)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
"cell_type": "markdown",
"id": "7eb9ef50",
"metadata": {},
"source": [
"Often we want to attach kwargs to the model that's passed in. Here are a few examples:"
]
},
{
"cell_type": "markdown",
"id": "0b1d8f88",
"metadata": {},
"source": [
"### Attaching Stop Sequences"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "562a06bf",
"metadata": {},
"outputs": [],
"source": [
"chain = prompt | model.bind(stop=[\"\\n\"])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "43f5d04c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"Why don't bears use cell phones?\", additional_kwargs={}, example=False)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
"cell_type": "markdown",
"id": "f3eaf88a",
"metadata": {},
"source": [
"### Attaching Function Call information"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f94b71b2",
"metadata": {},
"outputs": [],
"source": [
"functions = [\n",
"    {\n",
"        \"name\": \"joke\",\n",
"        \"description\": \"A joke\",\n",
"        \"parameters\": {\n",
"            \"type\": \"object\",\n",
"            \"properties\": {\n",
"                \"setup\": {\n",
"                    \"type\": \"string\",\n",
"                    \"description\": \"The setup for the joke\"\n",
"                },\n",
"                \"punchline\": {\n",
"                    \"type\": \"string\",\n",
"                    \"description\": \"The punchline for the joke\"\n",
"                }\n",
"            },\n",
"            \"required\": [\"setup\", \"punchline\"]\n",
"        }\n",
"    }\n",
"]\n",
"chain = prompt | model.bind(function_call={\"name\": \"joke\"}, functions=functions)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "decf7710",
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='', additional_kwargs={'function_call': {'name': 'joke', 'arguments': '{\\n \"setup\": \"Why don\\'t bears wear shoes?\",\\n \"punchline\": \"Because they have bear feet!\"\\n}'}}, example=False)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"}, config={})"
]
},
{
"cell_type": "markdown",
"id": "9098c5ed",
"metadata": {},
"source": [
"## PromptTemplate + LLM + OutputParser\n",
"\n",
"We can also add an output parser to easily transform the raw LLM/ChatModel output into a more workable format."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f799664d",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.output_parser import StrOutputParser"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cc194c78",
"metadata": {},
"outputs": [],
"source": [
"chain = prompt | model | StrOutputParser()"
]
},
{
"cell_type": "markdown",
"id": "77acf448",
"metadata": {},
"source": [
"Notice that this now returns a string, which is a much more workable format for downstream tasks."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e3d69a18",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"Why don't bears wear shoes?\\n\\nBecause they have bear feet!\""
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
"cell_type": "markdown",
"id": "c01864e5",
"metadata": {},
"source": [
"### Functions Output Parser\n",
"\n",
"When you force the model to call a specific function, you may just want to parse that function call's arguments directly."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "ad0dd88e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser\n",
"chain = (\n",
"    prompt\n",
"    | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
"    | JsonOutputFunctionsParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1e7aa8eb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'setup': \"Why don't bears wear shoes?\",\n",
" 'punchline': 'Because they have bear feet!'}"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d4aa1a01",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser\n",
"chain = (\n",
"    prompt\n",
"    | model.bind(function_call={\"name\": \"joke\"}, functions=functions)\n",
"    | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "8b6df9ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"Why don't bears like fast food?\""
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"foo\": \"bears\"})"
]
},
{
"cell_type": "markdown",
"id": "2ed58136",
"metadata": {},
"source": [
"## Passthroughs and itemgetter\n",
"\n",
"Often when constructing a chain you want to pass along the original input variables to later steps in the chain. How you do this depends on what the input is:\n",
"\n",
"- If the original input was a string, then you likely just want to pass along the string. This can be done with `RunnablePassthrough`. For an example, see `LLMChain + Retriever`.\n",
"- If the original input was a dictionary, then you likely want to pass along specific keys. This can be done with `itemgetter`. For an example, see `Multiple LLM Chains`."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "5d3d8ffe",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.runnable import RunnablePassthrough\n",
"from operator import itemgetter"
]
},
{
"cell_type": "markdown",
"id": "91c5ef3d",
"metadata": {},
"source": [
"## LLMChain + Retriever\n",
"\n",
"Let's now add a retrieval step, which turns this into a \"retrieval-augmented generation\" chain."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "33be32af",
"metadata": {},
"outputs": [],
"source": [
"from langchain.vectorstores import Chroma\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.schema.runnable import RunnablePassthrough"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "df3f3fa2",
"metadata": {},
"outputs": [],
"source": [
"# Create the retriever\n",
"vectorstore = Chroma.from_texts([\"harrison worked at kensho\"], embedding=OpenAIEmbeddings())\n",
"retriever = vectorstore.as_retriever()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "bfc47ec1",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Answer the question based only on the following context:\n",
"{context}\n",
"\n",
"Question: {question}\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "eae31755",
"metadata": {},
"outputs": [],
"source": [
"chain = (\n",
"    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
"    | prompt\n",
"    | model\n",
"    | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "f3040b0c",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Number of requested results 4 is greater than number of elements in index 1, updating n_results = 1\n"
]
},
{
"data": {
"text/plain": [
"'Harrison worked at Kensho.'"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke(\"where did harrison work?\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "e1d20c7c",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Answer the question based only on the following context:\n",
"{context}\n",
"\n",
"Question: {question}\n",
"\n",
"Answer in the following language: {language}\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)\n",
"\n",
"chain = {\n",
"    \"context\": itemgetter(\"question\") | retriever,\n",
"    \"question\": itemgetter(\"question\"),\n",
"    \"language\": itemgetter(\"language\")\n",
"} | prompt | model | StrOutputParser()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "7ee8b2d4",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Number of requested results 4 is greater than number of elements in index 1, updating n_results = 1\n"
]
},
{
"data": {
"text/plain": [
"'Harrison ha lavorato a Kensho.'"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})"
]
},
{
"cell_type": "markdown",
"id": "0f2bf8d3",
"metadata": {},
"source": [
"## Multiple LLM Chains\n",
"\n",
"The same approach can also be used to string together multiple LLMChains."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "d65d4e9e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'El país en el que nació la ciudad de Honolulu, Hawái, donde nació Barack Obama, el 44º presidente de los Estados Unidos, es Estados Unidos.'"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from operator import itemgetter\n",
"\n",
"prompt1 = ChatPromptTemplate.from_template(\"what is the city {person} is from?\")\n",
"prompt2 = ChatPromptTemplate.from_template(\"what country is the city {city} in? respond in {language}\")\n",
"\n",
"chain1 = prompt1 | model | StrOutputParser()\n",
"\n",
"chain2 = {\"city\": chain1, \"language\": itemgetter(\"language\")} | prompt2 | model | StrOutputParser()\n",
"\n",
"chain2.invoke({\"person\": \"obama\", \"language\": \"spanish\"})"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "878f8176",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.runnable import RunnableMap\n",
"prompt1 = ChatPromptTemplate.from_template(\"generate a random color\")\n",
"prompt2 = ChatPromptTemplate.from_template(\"what is a fruit of color: {color}\")\n",
"prompt3 = ChatPromptTemplate.from_template(\"what is a country whose flag has the color: {color}\")\n",
"prompt4 = ChatPromptTemplate.from_template(\"What is the color of {fruit} and {country}\")\n",
"chain1 = prompt1 | model | StrOutputParser()\n",
"chain2 = RunnableMap(steps={\"color\": chain1}) | {\n",
"    \"fruit\": prompt2 | model | StrOutputParser(),\n",
"    \"country\": prompt3 | model | StrOutputParser(),\n",
"} | prompt4"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "d621a870",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ChatPromptValue(messages=[HumanMessage(content=\"What is the color of A fruit that has a color similar to #7E7DE6 is the Peruvian Apple Cactus (Cereus repandus). It is a tropical fruit with a vibrant purple or violet exterior. and The country's flag that has the color #7E7DE6 is North Macedonia.\", additional_kwargs={}, example=False)])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain2.invoke({})"
]
},
{
"cell_type": "markdown",
"id": "fbc4bf6e",
"metadata": {},
"source": [
"## Arbitrary Functions\n",
"\n",
"You can use arbitrary functions in the pipeline.\n",
"\n",
"Note that each of these functions must accept a SINGLE argument. If you have a function that accepts multiple arguments, you should write a wrapper that accepts a single input and unpacks it into multiple arguments."
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "6bb221b3",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.runnable import RunnableLambda\n",
"\n",
"def length_function(text):\n",
"    return len(text)\n",
"\n",
"def _multiple_length_function(text1, text2):\n",
"    return len(text1) * len(text2)\n",
"\n",
"def multiple_length_function(_dict):\n",
"    return _multiple_length_function(_dict[\"text1\"], _dict[\"text2\"])\n",
"\n",
"prompt = ChatPromptTemplate.from_template(\"what is {a} + {b}\")\n",
"\n",
"chain1 = prompt | model\n",
"\n",
"chain = {\n",
"    \"a\": itemgetter(\"foo\") | RunnableLambda(length_function),\n",
"    \"b\": {\"text1\": itemgetter(\"foo\"), \"text2\": itemgetter(\"bar\")} | RunnableLambda(multiple_length_function)\n",
"} | prompt | model"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "5488ec85",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='3 + 9 is equal to 12.', additional_kwargs={}, example=False)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"foo\": \"bar\", \"bar\": \"gah\"})"
]
},
{
"cell_type": "markdown",
"id": "506e9636",
"metadata": {},
"source": [
"## SQL Database\n",
"\n",
"We can also try to replicate our SQLDatabaseChain using this style."
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "7a927516",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
"{schema}\n",
"\n",
"Question: {question}\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "3f51f386",
"metadata": {},
"outputs": [],
"source": [
"from langchain.utilities import SQLDatabase"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "2ccca6fc",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "05ba88ee",
"metadata": {},
"outputs": [],
"source": [
"def get_schema(_):\n",
"    return db.get_table_info()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "a4eda902",
"metadata": {},
"outputs": [],
"source": [
"def run_query(query):\n",
"    return db.run(query)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "5046cb17",
"metadata": {},
"outputs": [],
"source": [
"inputs = {\n",
"    \"schema\": RunnableLambda(get_schema),\n",
"    \"question\": itemgetter(\"question\")\n",
"}\n",
"sql_response = (\n",
"    RunnableMap(inputs)\n",
"    | prompt\n",
"    | model.bind(stop=[\"\\nSQLResult:\"])\n",
"    | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "a5552039",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'SELECT COUNT(*) \\nFROM Employee;'"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sql_response.invoke({\"question\": \"How many employees are there?\"})"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "d6fee130",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_template(template)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "923aa634",
"metadata": {},
"outputs": [],
"source": [
"full_chain = (\n",
"    RunnableMap({\n",
"        \"question\": itemgetter(\"question\"),\n",
"        \"query\": sql_response,\n",
"    })\n",
"    | {\n",
"        \"schema\": RunnableLambda(get_schema),\n",
"        \"question\": itemgetter(\"question\"),\n",
"        \"query\": itemgetter(\"query\"),\n",
"        \"response\": lambda x: db.run(x[\"query\"])\n",
"    }\n",
"    | prompt_response\n",
"    | model\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "e94963d8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='There are 8 employees.', additional_kwargs={}, example=False)"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"full_chain.invoke({\"question\": \"How many employees are there?\"})"
]
},
{
"cell_type": "markdown",
"id": "f09fd305",
"metadata": {},
"source": [
"## Code Writing"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "bd7c259a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.utilities import PythonREPL\n",
"from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "73795d2d",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Write some python code to solve the user's problem. \n",
"\n",
"Return only python code in Markdown format, eg:\n",
"\n",
"```python\n",
"....\n",
"```\"\"\"\n",
"prompt = ChatPromptTemplate(messages=[\n",
" SystemMessagePromptTemplate.from_template(template),\n",
" HumanMessagePromptTemplate.from_template(\"{input}\")\n",
"])"
]
},
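{
"cell_type": "code",
"execution_count": 64,
"id": "42859e8a",
"metadata": {},
"outputs": [],
"source": [
"def _sanitize_output(text: str) -> str:\n",
"    # Keep only the body of the first ```python fence; maxsplit=1 avoids an\n",
"    # unpacking error when the model returns more than one fenced block.\n",
"    _, after = text.split(\"```python\", 1)\n",
"    return after.split(\"```\")[0]"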
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "5ded1a86",
"metadata": {},
"outputs": [],
"source": [
"chain = prompt | model | StrOutputParser() | _sanitize_output | PythonREPL().run"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "208c2b75",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Python REPL can execute arbitrary code. Use with caution.\n"
]
},
{
"data": {
"text/plain": [
"'4\\n'"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"input\": \"whats 2 plus 2\"})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}