{ "cells": [ { "cell_type": "markdown", "id": "9a9acd2e", "metadata": {}, "source": [ "# Cookbook\n", "\n", "In this notebook we'll take a look at a few common types of sequences to create." ] }, { "cell_type": "markdown", "id": "93aa2c87", "metadata": {}, "source": [ "## PromptTemplate + LLM\n", "\n", "A PromptTemplate -> LLM is a core chain that is used in most other larger chains/systems." ] }, { "cell_type": "code", "execution_count": 19, "id": "466b65b3", "metadata": {}, "outputs": [], "source": [ "from langchain.prompts import ChatPromptTemplate\n", "from langchain.chat_models import ChatOpenAI" ] }, { "cell_type": "code", "execution_count": 2, "id": "3c634ef0", "metadata": {}, "outputs": [], "source": [ "model = ChatOpenAI()" ] }, { "cell_type": "code", "execution_count": 3, "id": "d1850a1f", "metadata": {}, "outputs": [], "source": [ "prompt = ChatPromptTemplate.from_template(\"tell me a joke about {foo}\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "56d0669f", "metadata": {}, "outputs": [], "source": [ "chain = prompt | model" ] }, { "cell_type": "code", "execution_count": 5, "id": "e3d0a6cd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AIMessage(content='Why don\\'t bears use cell phones? \\n\\nBecause they always get terrible \"grizzly\" reception!', additional_kwargs={}, example=False)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"foo\": \"bears\"})" ] }, { "cell_type": "markdown", "id": "7eb9ef50", "metadata": {}, "source": [ "Often times we want to attach kwargs to the model that's passed in. Here's a few examples of that:" ] }, { "cell_type": "markdown", "id": "0b1d8f88", "metadata": {}, "source": [ "### Attaching Stop Sequences" ] }, { "cell_type": "code", "execution_count": 6, "id": "562a06bf", "metadata": {}, "outputs": [], "source": [ "chain = prompt | model.bind(stop=[\"\\n\"])" ] }, { "cell_type": "code", "execution_count": 7, "id": "43f5d04c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AIMessage(content=\"Why don't bears use cell phones?\", additional_kwargs={}, example=False)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"foo\": \"bears\"})" ] }, { "cell_type": "markdown", "id": "f3eaf88a", "metadata": {}, "source": [ "### Attaching Function Call information" ] }, { "cell_type": "code", "execution_count": 8, "id": "f94b71b2", "metadata": {}, "outputs": [], "source": [ "functions = [\n", " {\n", " \"name\": \"joke\",\n", " \"description\": \"A joke\",\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", " \"setup\": {\n", " \"type\": \"string\",\n", " \"description\": \"The setup for the joke\"\n", " },\n", " \"punchline\": {\n", " \"type\": \"string\",\n", " \"description\": \"The punchline for the joke\"\n", " }\n", " },\n", " \"required\": [\"setup\", \"punchline\"]\n", " }\n", " }\n", " ]\n", "chain = prompt | model.bind(function_call= {\"name\": \"joke\"}, functions= functions)" ] }, { "cell_type": "code", "execution_count": 9, "id": "decf7710", "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "AIMessage(content='', additional_kwargs={'function_call': {'name': 'joke', 'arguments': '{\\n \"setup\": \"Why don\\'t bears wear shoes?\",\\n \"punchline\": \"Because they have bear feet!\"\\n}'}}, example=False)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"foo\": \"bears\"}, config={})" 
] }, { "cell_type": "markdown", "id": "9098c5ed", "metadata": {}, "source": [ "## PromptTemplate + LLM + OutputParser\n", "\n", "We can also add in an output parser to easily trasform the raw LLM/ChatModel output into a more workable format" ] }, { "cell_type": "code", "execution_count": 11, "id": "f799664d", "metadata": {}, "outputs": [], "source": [ "from langchain.schema.output_parser import StrOutputParser" ] }, { "cell_type": "code", "execution_count": 12, "id": "cc194c78", "metadata": {}, "outputs": [], "source": [ "chain = prompt | model | StrOutputParser()" ] }, { "cell_type": "markdown", "id": "77acf448", "metadata": {}, "source": [ "Notice that this now returns a string - a much more workable format for downstream tasks" ] }, { "cell_type": "code", "execution_count": 13, "id": "e3d69a18", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"Why don't bears wear shoes?\\n\\nBecause they have bear feet!\"" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"foo\": \"bears\"})" ] }, { "cell_type": "markdown", "id": "c01864e5", "metadata": {}, "source": [ "### Functions Output Parser\n", "\n", "When you specify the function to return, you may just want to parse that directly" ] }, { "cell_type": "code", "execution_count": 14, "id": "ad0dd88e", "metadata": {}, "outputs": [], "source": [ "from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser\n", "chain = (\n", " prompt \n", " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n", " | JsonOutputFunctionsParser()\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "id": "1e7aa8eb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'setup': \"Why don't bears wear shoes?\",\n", " 'punchline': 'Because they have bear feet!'}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"foo\": \"bears\"})" ] }, { "cell_type": "code", "execution_count": 17, "id": "d4aa1a01", "metadata": {}, "outputs": [], "source": [ "from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser\n", "chain = (\n", " prompt \n", " | model.bind(function_call= {\"name\": \"joke\"}, functions= functions) \n", " | JsonKeyOutputFunctionsParser(key_name=\"setup\")\n", ")" ] }, { "cell_type": "code", "execution_count": 18, "id": "8b6df9ba", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"Why don't bears like fast food?\"" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"foo\": \"bears\"})" ] }, { "cell_type": "markdown", "id": "2ed58136", "metadata": {}, "source": [ "## Passthroughs and itemgetter\n", "\n", "Often times when constructing a chain you may want to pass along original input variables to future steps in the chain. How exactly you do this depends on what exactly the input is:\n", "\n", "- If the original input was a string, then you likely just want to pass along the string. This can be done with `RunnablePassthrough`. For an example of this, see `LLMChain + Retriever`\n", "- If the original input was a dictionary, then you likely want to pass along specific keys. This can be done with `itemgetter`. 
{ "cell_type": "code", "execution_count": 19, "id": "5d3d8ffe", "metadata": {}, "outputs": [], "source": [ "from langchain.schema.runnable import RunnablePassthrough\n", "from operator import itemgetter" ] },
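{ "cell_type": "markdown", "id": "9a5e4c3d", "metadata": {}, "source": [ "As a minimal sketch of what each of these does in isolation (the sample inputs here are purely illustrative):" ] }, { "cell_type": "code", "execution_count": null, "id": "b6f5d4e2", "metadata": {}, "outputs": [], "source": [ "# RunnablePassthrough returns whatever it is given, unchanged\n", "RunnablePassthrough().invoke(\"where did harrison work?\")  # -> 'where did harrison work?'\n", "\n", "# itemgetter plucks a single key out of a dict input\n", "itemgetter(\"question\")({\"question\": \"where did harrison work?\"})  # -> 'where did harrison work?'" ] },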
{ "cell_type": "markdown", "id": "91c5ef3d", "metadata": {}, "source": [ "## LLMChain + Retriever\n", "\n", "Let's now look at adding in a retrieval step, which gives us a \"retrieval-augmented generation\" chain." ] }, { "cell_type": "code", "execution_count": 20, "id": "33be32af", "metadata": {}, "outputs": [], "source": [ "from langchain.vectorstores import Chroma\n", "from langchain.embeddings import OpenAIEmbeddings\n", "from langchain.schema.runnable import RunnablePassthrough" ] }, { "cell_type": "code", "execution_count": 21, "id": "df3f3fa2", "metadata": {}, "outputs": [], "source": [ "# Create the retriever\n", "vectorstore = Chroma.from_texts([\"harrison worked at kensho\"], embedding=OpenAIEmbeddings())\n", "retriever = vectorstore.as_retriever()" ] }, { "cell_type": "code", "execution_count": 22, "id": "bfc47ec1", "metadata": {}, "outputs": [], "source": [ "template = \"\"\"Answer the question based only on the following context:\n", "{context}\n", "\n", "Question: {question}\n", "\"\"\"\n", "prompt = ChatPromptTemplate.from_template(template)" ] }, { "cell_type": "code", "execution_count": 24, "id": "eae31755", "metadata": {}, "outputs": [], "source": [ "chain = (\n", " {\"context\": retriever, \"question\": RunnablePassthrough()} \n", " | prompt \n", " | model \n", " | StrOutputParser()\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "id": "f3040b0c", "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Number of requested results 4 is greater than number of elements in index 1, updating n_results = 1\n" ] }, { "data": { "text/plain": [ "'Harrison worked at Kensho.'" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke(\"where did harrison work?\")" ] }, { "cell_type": "code", "execution_count": 27, "id": "e1d20c7c", "metadata": {}, "outputs": [], "source": [ "template = \"\"\"Answer the question based only on the following context:\n", "{context}\n", "\n", "Question: {question}\n", "\n", "Answer in the following language: {language}\n", "\"\"\"\n", "prompt = ChatPromptTemplate.from_template(template)\n", "\n", "chain = {\n", " \"context\": itemgetter(\"question\") | retriever, \n", " \"question\": itemgetter(\"question\"), \n", " \"language\": itemgetter(\"language\")\n", "} | prompt | model | StrOutputParser()" ] }, { "cell_type": "code", "execution_count": 28, "id": "7ee8b2d4", "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Number of requested results 4 is greater than number of elements in index 1, updating n_results = 1\n" ] }, { "data": { "text/plain": [ "'Harrison ha lavorato a Kensho.'" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"question\": \"where did harrison work\", \"language\": \"italian\"})" ] }, { "cell_type": "markdown", "id": "0f2bf8d3", "metadata": {}, "source": [ "## Multiple LLM Chains\n", "\n", "This can also be used to string together multiple LLMChains." ] }, { "cell_type": "code", "execution_count": 31, "id": "d65d4e9e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'El país en el que nació la ciudad de Honolulu, Hawái, donde nació Barack Obama, el 44º presidente de los Estados Unidos, es Estados Unidos.'" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from operator import itemgetter\n", "\n", "prompt1 = ChatPromptTemplate.from_template(\"what is the city {person} is from?\")\n", "prompt2 = ChatPromptTemplate.from_template(\"what country is the city {city} in? respond in {language}\")\n", "\n", "chain1 = prompt1 | model | StrOutputParser()\n", "\n", "chain2 = {\"city\": chain1, \"language\": itemgetter(\"language\")} | prompt2 | model | StrOutputParser()\n", "\n", "chain2.invoke({\"person\": \"obama\", \"language\": \"spanish\"})" ] }, { "cell_type": "code", "execution_count": 32, "id": "878f8176", "metadata": {}, "outputs": [], "source": [ "from langchain.schema.runnable import RunnableMap\n", "prompt1 = ChatPromptTemplate.from_template(\"generate a random color\")\n", "prompt2 = ChatPromptTemplate.from_template(\"what is a fruit of color: {color}\")\n", "prompt3 = ChatPromptTemplate.from_template(\"what country's flag has the color: {color}\")\n", "prompt4 = ChatPromptTemplate.from_template(\"What is the color of {fruit} and {country}\")\n", "chain1 = prompt1 | model | StrOutputParser()\n", "chain2 = RunnableMap(steps={\"color\": chain1}) | {\n", " \"fruit\": prompt2 | model | StrOutputParser(),\n", " \"country\": prompt3 | model | StrOutputParser(),\n", "} | prompt4" ] }, { "cell_type": "code", "execution_count": 33, "id": "d621a870", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ChatPromptValue(messages=[HumanMessage(content=\"What is the color of A fruit that has a color similar to #7E7DE6 is the Peruvian Apple Cactus (Cereus repandus). It is a tropical fruit with a vibrant purple or violet exterior. and The country's flag that has the color #7E7DE6 is North Macedonia.\", additional_kwargs={}, example=False)])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain2.invoke({})" ] }, { "cell_type": "markdown", "id": "fbc4bf6e", "metadata": {}, "source": [ "## Arbitrary Functions\n", "\n", "You can use arbitrary functions in the pipeline.\n", "\n", "Note that these functions must accept a SINGLE argument. If you have a function that accepts multiple arguments, you should write a wrapper that accepts a single input and unpacks it into multiple arguments."
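] }, { "cell_type": "markdown", "id": "c7a6e5f3", "metadata": {}, "source": [ "For example, here's a minimal sketch of the wrapper pattern on its own (the `_add` helper is purely illustrative):" ] }, { "cell_type": "code", "execution_count": null, "id": "d8b7f6a4", "metadata": {}, "outputs": [], "source": [ "from langchain.schema.runnable import RunnableLambda\n", "\n", "def _add(a, b):\n", "    return a + b\n", "\n", "# The wrapper accepts a single dict and unpacks it into multiple arguments\n", "add = RunnableLambda(lambda d: _add(d[\"a\"], d[\"b\"]))\n", "\n", "add.invoke({\"a\": 1, \"b\": 2})  # -> 3"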
] }, { "cell_type": "code", "execution_count": 35, "id": "6bb221b3", "metadata": {}, "outputs": [], "source": [ "from langchain.schema.runnable import RunnableLambda\n", "\n", "def length_function(text):\n", " return len(text)\n", "\n", "def _multiple_length_function(text1, text2):\n", " return len(text1) * len(text2)\n", "\n", "def multiple_length_function(_dict):\n", " return _multiple_length_function(_dict[\"text1\"], _dict[\"text2\"])\n", "\n", "prompt = ChatPromptTemplate.from_template(\"what is {a} + {b}\")\n", "\n", "chain1 = prompt | model\n", "\n", "chain = {\n", " \"a\": itemgetter(\"foo\") | RunnableLambda(length_function),\n", " \"b\": {\"text1\": itemgetter(\"foo\"), \"text2\": itemgetter(\"bar\")} | RunnableLambda(multiple_length_function)\n", "} | prompt | model" ] }, { "cell_type": "code", "execution_count": 36, "id": "5488ec85", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AIMessage(content='3 + 9 is equal to 12.', additional_kwargs={}, example=False)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"foo\": \"bar\", \"bar\": \"gah\"})" ] }, { "cell_type": "markdown", "id": "506e9636", "metadata": {}, "source": [ "## SQL Database\n", "\n", "We can also try to replicate our SQLDatabaseChain using this style." ] }, { "cell_type": "code", "execution_count": 37, "id": "7a927516", "metadata": {}, "outputs": [], "source": [ "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n", "{schema}\n", "\n", "Question: {question}\"\"\"\n", "prompt = ChatPromptTemplate.from_template(template)" ] }, { "cell_type": "code", "execution_count": 38, "id": "3f51f386", "metadata": {}, "outputs": [], "source": [ "from langchain.utilities import SQLDatabase" ] }, { "cell_type": "code", "execution_count": 41, "id": "2ccca6fc", "metadata": {}, "outputs": [], "source": [ "db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")" ] }, { "cell_type": "code", "execution_count": 42, "id": "05ba88ee", "metadata": {}, "outputs": [], "source": [ "def get_schema(_):\n", " return db.get_table_info()" ] }, { "cell_type": "code", "execution_count": 43, "id": "a4eda902", "metadata": {}, "outputs": [], "source": [ "def run_query(query):\n", " return db.run(query)" ] }, { "cell_type": "code", "execution_count": 47, "id": "5046cb17", "metadata": {}, "outputs": [], "source": [ "inputs = {\n", " \"schema\": RunnableLambda(get_schema),\n", " \"question\": itemgetter(\"question\")\n", "}\n", "sql_response = (\n", " RunnableMap(inputs)\n", " | prompt\n", " | model.bind(stop=[\"\\nSQLResult:\"])\n", " | StrOutputParser()\n", " )" ] }, { "cell_type": "code", "execution_count": 48, "id": "a5552039", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'SELECT COUNT(*) \\nFROM Employee;'" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sql_response.invoke({\"question\": \"How many employees are there?\"})" ] }, { "cell_type": "code", "execution_count": 49, "id": "d6fee130", "metadata": {}, "outputs": [], "source": [ "template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n", "{schema}\n", "\n", "Question: {question}\n", "SQL Query: {query}\n", "SQL Response: {response}\"\"\"\n", "prompt_response = ChatPromptTemplate.from_template(template)" ] }, { "cell_type": "code", "execution_count": 52, "id": "923aa634", "metadata": {}, "outputs": [], "source": [ "full_chain = 
(\n", " RunnableMap({\n", " \"question\": itemgetter(\"question\"),\n", " \"query\": sql_response,\n", " }) \n", " | {\n", " \"schema\": RunnableLambda(get_schema),\n", " \"question\": itemgetter(\"question\"),\n", " \"query\": itemgetter(\"query\"),\n", " \"response\": lambda x: db.run(x[\"query\"]) \n", " } \n", " | prompt_response \n", " | model\n", ")" ] }, { "cell_type": "code", "execution_count": 53, "id": "e94963d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AIMessage(content='There are 8 employees.', additional_kwargs={}, example=False)" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "full_chain.invoke({\"question\": \"How many employees are there?\"})" ] }, { "cell_type": "markdown", "id": "f09fd305", "metadata": {}, "source": [ "## Code Writing" ] }, { "cell_type": "code", "execution_count": 57, "id": "bd7c259a", "metadata": {}, "outputs": [], "source": [ "from langchain.utilities import PythonREPL\n", "from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate" ] }, { "cell_type": "code", "execution_count": 58, "id": "73795d2d", "metadata": {}, "outputs": [], "source": [ "template = \"\"\"Write some python code to solve the user's problem. \n", "\n", "Return only python code in Markdown format, eg:\n", "\n", "```python\n", "....\n", "```\"\"\"\n", "prompt = ChatPromptTemplate(messages=[\n", " SystemMessagePromptTemplate.from_template(template),\n", " HumanMessagePromptTemplate.from_template(\"{input}\")\n", "])" ] }, { "cell_type": "code", "execution_count": 64, "id": "42859e8a", "metadata": {}, "outputs": [], "source": [ "def _sanitize_output(text: str):\n", " _, after = text.split(\"```python\")\n", " return after.split(\"```\")[0]" ] }, { "cell_type": "code", "execution_count": 67, "id": "5ded1a86", "metadata": {}, "outputs": [], "source": [ "chain = prompt | model | StrOutputParser() | _sanitize_output | PythonREPL().run" ] }, { "cell_type": "code", "execution_count": 68, "id": "208c2b75", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Python REPL can execute arbitrary code. Use with caution.\n" ] }, { "data": { "text/plain": [ "'4\\n'" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"input\": \"whats 2 plus 2\"})" ] }, { "cell_type": "code", "execution_count": null, "id": "9be88499", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.1" } }, "nbformat": 4, "nbformat_minor": 5 }