{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "a384cc48-0425-4e8f-aafc-cfb8e56025c9",
"metadata": {},
"outputs": [],
"source": [
"! pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph"
]
},
{
"attachments": {
"ea6a57d2-f2ec-4061-840a-98deb3207248.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA20AAAFMCAYAAABYnVRwAAAMP2lDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBCCSAgJfQmCEgJICWEFkB6EWyEJEAoMQaCiB1dVHDtYgEbuiqi2AGxI3YWwd4XRRSUdbFgV96kgK77yvfO9829//3nzH/OnDu3DADqp7hicQ6qAUCuKF8SGxLAGJucwiB1AwTggAYIgMDl5YlZ0dERANrg+e/27ib0hnbNQab1z/7/app8QR4PACQa4jR+Hi8X4kMA4JU8sSQfAKKMN5+aL5Zh2IC2BCYI8UIZzlDgShlOU+B9cp/4WDbEzQCoqHG5kgwAaG2QZxTwMqAGrQ9iJxFfKAJAnQGxb27uZD7EqRDbQB8xxDJ9ZtoPOhl/00wb0uRyM4awYi5yUwkU5olzuNP+z3L8b8vNkQ7GsIJNLVMSGiubM6zb7ezJ4TKsBnGvKC0yCmItiD8I+XJ/iFFKpjQ0QeGPGvLy2LBmQBdiJz43MBxiQ4iDRTmREUo+LV0YzIEYrhC0UJjPiYdYD+KFgrygOKXPZsnkWGUstC5dwmYp+QtciTyuLNZDaXYCS6n/OlPAUepjtKLM+CSIKRBbFAgTIyGmQeyYlx0XrvQZXZTJjhz0kUhjZflbQBwrEIUEKPSxgnRJcKzSvzQ3b3C+2OZMISdSiQ/kZ8aHKuqDNfO48vzhXLA2gYiVMKgjyBsbMTgXviAwSDF3rFsgSohT6nwQ5wfEKsbiFHFOtNIfNxPkhMh4M4hd8wrilGPxxHy4IBX6eLo4PzpekSdelMUNi1bkgy8DEYANAgEDSGFLA5NBFhC29tb3witFTzDgAgnIAALgoGQGRyTJe0TwGAeKwJ8QCUDe0LgAea8AFED+6xCrODqAdHlvgXxENngKcS4IBznwWiofJRqKlgieQEb4j+hc2Hgw3xzYZP3/nh9kvzMsyEQoGelgRIb6oCcxiBhIDCUGE21xA9wX98Yj4NEfNheciXsOzuO7P+EpoZ3wmHCD0EG4M0lYLPkpyzGgA+oHK2uR9mMtcCuo6YYH4D5QHSrjurgBcMBdYRwW7gcju0GWrcxbVhXGT9p/m8EPd0PpR3Yio+RhZH+yzc8jaXY0tyEVWa1/rI8i17SherOHen6Oz/6h+nx4Dv/ZE1uIHcTOY6exi9gxrB4wsJNYA9aCHZfhodX1RL66BqPFyvPJhjrCf8QbvLOySuY51Tj1OH1R9OULCmXvaMCeLJ4mEWZk5jNY8IsgYHBEPMcRDBcnF1cAZN8XxevrTYz8u4Hotnzn5v0BgM/JgYGBo9+5sJMA7PeAj/+R75wNE346VAG4cIQnlRQoOFx2IMC3hDp80vSBMTAHNnA+LsAdeAN/EATCQBSIB8lgIsw+E65zCZgKZoC5oASUgWVgNVgPNoGtYCfYAw6AenAMnAbnwGXQBm6Ae3D1dIEXoA+8A58RBCEhVISO6CMmiCVij7ggTMQXCUIikFgkGUlFMhARIkVmIPOQMmQFsh7ZglQj+5EjyGnkItKO3EEeIT3Ia+QTiqFqqDZqhFqhI1EmykLD0Xh0ApqBTkGL0PnoEnQtWoXuRuvQ0+hl9Abagb5A+zGAqWK6mCnmgDExNhaFpWDpmASbhZVi5VgVVos1wvt8DevAerGPOBGn4wzcAa7gUDwB5+FT8Fn4Ynw9vhOvw5vxa/gjvA//RqASDAn2BC8ChzCWkEGYSighlBO2Ew4TzsJnqYvwjkgk6hKtiR7wWUwmZhGnExcTNxD3Ek8R24mdxH4SiaRPsif5kKJIXFI+qYS0jrSbdJJ0ldRF+qCiqmKi4qISrJKiIlIpVilX2aVyQuWqyjOVz2QNsiXZixxF5pOnkZeSt5EbyVfIXeTPFE2KNcWHEk/JosylrKXUUs5S7lPeqKqqmql6qsaoClXnqK5V3ad6QfWR6kc1LTU7NbbaeDWp2hK1HWqn1O6ovaFSqVZUf2oKNZ+6hFpNPUN9SP1Ao9McaRwanzabVkGro12lvVQnq1uqs9Qnqhepl6sfVL+i3qtB1rDSYGtwNWZpVGgc0bil0a9J13TWjNLM1VysuUvzoma3FknLSitIi681X2ur1hmtTjpGN6ez6Tz6PPo2+ll6lzZR21qbo52lXaa9R7tVu09HS8dVJ1GnUKdC57hOhy6ma6XL0c3RXap7QPem7qdhRsNYwwTDFg2rHXZ12Hu94Xr+egK9Ur29ejf0Pukz9IP0s/WX69frPzDADewMYgymGmw0OGvQO1x7uPdw3vDS4QeG3zVEDe0MYw2nG241bDHsNzI2CjESG60zOmPUa6xr7G+cZbzK+IRxjwndxNdEaLLK5KTJc4YOg8XIYaxlNDP6TA1NQ02lpltMW00/m1mbJZgVm+01e2BOMWeap5uvMm8y77MwsRhjMcOixuKuJdmSaZlpucbyvOV7K2urJKsFVvVW3dZ61hzrIusa6/s2VBs/myk2VTbXbYm2TNts2w22bXaonZtdpl2F3RV71N7dXmi/wb59BGGE5wjRiKoRtxzUHFgOBQ41Do8cdR0jHIsd6x1fjrQYmTJy+cjzI785uTnlOG1zuues5RzmXOzc6Pzaxc6F51Lhcn0UdVTwqNmjGka9crV3FbhudL3tRncb47bArcntq7uHu8S91r3Hw8Ij1aPS4xZTmxnNXMy84EnwDPCc7XnM86OXu1e+1wGvv7wdvLO9d3l3j7YeLRi9bXSnj5kP12eLT4cvwzfVd7Nvh5+pH9evyu+xv7k/33+7/zOWLSuLtZv1MsApQBJwOOA924s9k30qEAsMCSwNbA3SCkoIWh/0MNgsOCO4JrgvxC1kesipUEJoeOjy0FscIw6PU83pC/MImxnWHK4WHhe+PvxxhF2EJKJxDDombMzKMfcjLSNFkfVRIIoTtTLqQbR19JToozHEmOiYipinsc6xM2LPx9HjJsXtinsXHxC/NP5egk2CNKEpUT1xfGJ14vukwKQVSR1jR46dOfZyskGyMLkhhZSSmLI9pX9c0LjV47rGu40vGX9zgvWEwgkXJxpMzJl4fJL6JO6kg6mE1KTUXalfuFHcKm5/GietMq2Px+at4b3g+/NX8XsEPoIVgmfpPukr0rszfDJWZvRk+mWWZ/YK2cL1wldZoVmbst5nR2XvyB7IScrZm6uSm5p7RKQlyhY1TzaeXDi5XWwvLhF3TPGasnpKnyRcsj0PyZuQ15CvDX/kW6Q20l+kjwp8CyoKPkxNnHqwULNQVNgyzW7aomnPioKLfpuOT+dNb5phOmPujEczWTO3zEJmpc1qmm0+e/7srjkhc3bOpczNnvt7sVPxiuK385LmNc43mj9nfucvIb/UlNBKJCW3Fngv2LQQXyhc2Lpo1KJ1i76V8ksvlTmVlZd9WcxbfOlX51/X/jqwJH1J61L3pRuXEZeJlt1c7rd85wrNFUUrOleOWVm3irGqdNXb1ZNWXyx3Ld+0hrJGuqZjbcTahnUW65at+7I+c/2NioCKvZWGlYsq32/gb7i60X9j7SajTWWbP
m0Wbr69JWRLXZVVVflW4taCrU+3JW47/xvzt+rtBtvLtn/dIdrRsTN2Z3O1R3X1LsNdS2vQGmlNz+7xu9v2BO5pqHWo3bJXd2/ZPrBPuu/5/tT9Nw+EH2g6yDxYe8jyUOVh+uHSOqRuWl1ffWZ9R0NyQ/uRsCNNjd6Nh486Ht1xzPRYxXGd40tPUE7MPzFwsuhk/ynxqd7TGac7myY13Tsz9sz15pjm1rPhZy+cCz535jzr/MkLPheOXfS6eOQS81L9ZffLdS1uLYd/d/v9cKt7a90VjysNbZ5tje2j209c9bt6+lrgtXPXOdcv34i80X4z4ebtW+Nvddzm3+6+k3Pn1d2Cu5/vzblPuF/6QONB+UPDh1V/2P6xt8O94/ijwEctj+Me3+vkdb54kvfkS9f8p9Sn5c9MnlV3u3Qf6wnuaXs+7nnXC/GLz70lf2r+WfnS5uWhv/z/aukb29f1SvJq4PXiN/pvdrx1fdvUH93/8F3uu8/vSz/of9j5kfnx/KekT88+T/1C+rL2q+3Xxm/h
}
},
"cell_type": "markdown",
"id": "919fe33c-0149-4f7d-b200-544a18986c9a",
"metadata": {},
"source": [
"# Self-RAG\n",
"\n",
"Self-RAG is a recent paper that introduces an interesting approach for active RAG. \n",
"\n",
"The framework trains a single arbitrary LM (LLaMA2-7b, 13b) to generate tokens that govern the RAG process:\n",
"\n",
"1. Should I retrieve from retriever, `R` -\n",
"\n",
"* Token: `Retrieve`\n",
"* Input: `x (question)` OR `x (question)`, `y (generation)`\n",
"* Decides when to retrieve `D` chunks with `R`\n",
"* Output: `yes, no, continue`\n",
"\n",
"2. Are the retrieved passages `D` relevant to the question `x` -\n",
"\n",
"* Token: `ISREL`\n",
"* * Input: (`x (question)`, `d (chunk)`) for `d` in `D`\n",
"* `d` provides useful information to solve `x`\n",
"* Output: `relevant, irrelevant`\n",
"\n",
"\n",
"3. Are the LLM generation from each chunk in `D` is relevant to the chunk (hallucinations, etc) -\n",
"\n",
"* Token: `ISSUP`\n",
"* Input: `x (question)`, `d (chunk)`, `y (generation)` for `d` in `D`\n",
"* All of the verification-worthy statements in `y (generation)` are supported by `d`\n",
"* Output: `{fully supported, partially supported, no support`\n",
"\n",
"4. The LLM generation from each chunk in `D` is a useful response to `x (question)` -\n",
"\n",
"* Token: `ISUSE`\n",
"* Input: `x (question)`, `y (generation)` for `d` in `D`\n",
"* `y (generation)` is a useful response to `x (question)`.\n",
"* Output: `{5, 4, 3, 2, 1}`\n",
"\n",
"We can represent this as a graph:\n",
"\n",
"![Screenshot 2024-02-02 at 1.36.44 PM.png](attachment:ea6a57d2-f2ec-4061-840a-98deb3207248.png)\n",
"\n",
"Paper -\n",
"\n",
"https://arxiv.org/abs/2310.11511\n",
"\n",
"---\n",
"\n",
"Let's implement this from scratch using [LangGraph](https://python.langchain.com/docs/langgraph)."
]
},
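{
"cell_type": "markdown",
"id": "self-rag-sketch-md",
"metadata": {},
"source": [
"Before the LangGraph implementation, here is a minimal plain-Python sketch of the control flow these tokens imply. The callables (`retrieve_docs`, `is_relevant`, `generate`, `is_supported`, `is_useful`, `rewrite`) are hypothetical placeholders for graded LLM calls; in the paper, these decisions are reflection tokens emitted by the LM itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "self-rag-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# A sketch only (not the paper's implementation): the four Self-RAG\n",
"# decisions expressed as plain Python control flow. All callables passed in\n",
"# are hypothetical stand-ins for graded LLM calls.\n",
"def self_rag(question, retrieve_docs, is_relevant, generate, is_supported, is_useful, rewrite, max_iters=3):\n",
"    answer = None\n",
"    for _ in range(max_iters):\n",
"        # Retrieve, then keep only the chunks graded relevant (ISREL)\n",
"        docs = [d for d in retrieve_docs(question) if is_relevant(question, d)]\n",
"        if not docs:\n",
"            question = rewrite(question)  # no relevant chunks: transform the query\n",
"            continue\n",
"        answer = generate(question, docs)\n",
"        if not is_supported(answer, docs):  # ISSUP: is the generation grounded in the docs?\n",
"            continue  # not grounded: loop and generate again\n",
"        if is_useful(question, answer):  # ISUSE: does it address the question?\n",
"            return answer\n",
"        question = rewrite(question)  # usefulness check failed: transform the query\n",
"    return answer"
]
},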
{
"cell_type": "markdown",
"id": "c27bebdc-be71-4130-ab9d-42f09f87658b",
"metadata": {},
"source": [
"## Retriever\n",
" \n",
"Let's index 3 blog posts."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "565a6d44-2c9f-4fff-b1ec-eea05df9350d",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain_community.document_loaders import WebBaseLoader\n",
"from langchain_community.vectorstores import Chroma\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"urls = [\n",
" \"https://lilianweng.github.io/posts/2023-06-23-agent/\",\n",
" \"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/\",\n",
" \"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/\",\n",
"]\n",
"\n",
"docs = [WebBaseLoader(url).load() for url in urls]\n",
"docs_list = [item for sublist in docs for item in sublist]\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
" chunk_size=250, chunk_overlap=0\n",
")\n",
"doc_splits = text_splitter.split_documents(docs_list)\n",
"\n",
"# Add to vectorDB\n",
"vectorstore = Chroma.from_documents(\n",
" documents=doc_splits,\n",
" collection_name=\"rag-chroma\",\n",
" embedding=OpenAIEmbeddings(),\n",
")\n",
"retriever = vectorstore.as_retriever()"
]
},
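{
"cell_type": "markdown",
"id": "retriever-check-md",
"metadata": {},
"source": [
"Optionally, we can sanity-check the retriever directly before wiring it into the graph (sample question made up for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "retriever-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: fetch chunks for a sample question.\n",
"sample_docs = retriever.invoke(\"What are the types of agent memory?\")\n",
"print(len(sample_docs), \"chunks retrieved\")\n",
"print(sample_docs[0].page_content[:200])"
]
},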
{
"cell_type": "markdown",
"id": "276001c5-c079-4e5b-9f42-81a06704d200",
"metadata": {},
"source": [
"## State\n",
" \n",
"We will define a graph.\n",
"\n",
"Our state will be a `dict`.\n",
"\n",
"We can access this from any graph node as `state['keys']`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1617e9e-66a8-4c1a-a1fe-cc936284c085",
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, TypedDict\n",
"\n",
"from langchain_core.messages import BaseMessage\n",
"\n",
"\n",
"class GraphState(TypedDict):\n",
" \"\"\"\n",
" Represents the state of an agent in the conversation.\n",
"\n",
" Attributes:\n",
" keys: A dictionary where each key is a string and the value is expected to be a list or another structure\n",
" that supports addition with `operator.add`. This could be used, for instance, to accumulate messages\n",
" or other pieces of data throughout the graph.\n",
" \"\"\"\n",
"\n",
" keys: Dict[str, any]"
]
},
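{
"cell_type": "markdown",
"id": "state-demo-md",
"metadata": {},
"source": [
"A quick illustration (with a made-up question): each node receives this whole state and returns it with updated values under `keys`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "state-demo-code",
"metadata": {},
"outputs": [],
"source": [
"# Illustration only: build a state and read it the way a node would.\n",
"example_state = GraphState(keys={\"question\": \"What is task decomposition?\"})\n",
"example_state[\"keys\"][\"question\"]"
]
},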
{
"attachments": {
"e61fbd0c-e667-4160-a96c-82f95a560b44.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABGwAAAG8CAYAAACL9jPUAAAMP2lDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBCCSAgJfQmCEgJICWEFkB6EWyEJEAoMQaCiB1dVHDtYgEbuiqi2AGxI3YWwd4XRRSUdbFgV96kgK77yvfO9829//3nzH/OnDu3DADqp7hicQ6qAUCuKF8SGxLAGJucwiB1AwTggAYIgMDl5YlZ0dERANrg+e/27ib0hnbNQab1z/7/app8QR4PACQa4jR+Hi8X4kMA4JU8sSQfAKKMN5+aL5Zh2IC2BCYI8UIZzlDgShlOU+B9cp/4WDbEzQCoqHG5kgwAaG2QZxTwMqAGrQ9iJxFfKAJAnQGxb27uZD7EqRDbQB8xxDJ9ZtoPOhl/00wb0uRyM4awYi5yUwkU5olzuNP+z3L8b8vNkQ7GsIJNLVMSGiubM6zb7ezJ4TKsBnGvKC0yCmItiD8I+XJ/iFFKpjQ0QeGPGvLy2LBmQBdiJz43MBxiQ4iDRTmREUo+LV0YzIEYrhC0UJjPiYdYD+KFgrygOKXPZsnkWGUstC5dwmYp+QtciTyuLNZDaXYCS6n/OlPAUepjtKLM+CSIKRBbFAgTIyGmQeyYlx0XrvQZXZTJjhz0kUhjZflbQBwrEIUEKPSxgnRJcKzSvzQ3b3C+2OZMISdSiQ/kZ8aHKuqDNfO48vzhXLA2gYiVMKgjyBsbMTgXviAwSDF3rFsgSohT6nwQ5wfEKsbiFHFOtNIfNxPkhMh4M4hd8wrilGPxxHy4IBX6eLo4PzpekSdelMUNi1bkgy8DEYANAgEDSGFLA5NBFhC29tb3witFTzDgAgnIAALgoGQGRyTJe0TwGAeKwJ8QCUDe0LgAea8AFED+6xCrODqAdHlvgXxENngKcS4IBznwWiofJRqKlgieQEb4j+hc2Hgw3xzYZP3/nh9kvzMsyEQoGelgRIb6oCcxiBhIDCUGE21xA9wX98Yj4NEfNheciXsOzuO7P+EpoZ3wmHCD0EG4M0lYLPkpyzGgA+oHK2uR9mMtcCuo6YYH4D5QHSrjurgBcMBdYRwW7gcju0GWrcxbVhXGT9p/m8EPd0PpR3Yio+RhZH+yzc8jaXY0tyEVWa1/rI8i17SherOHen6Oz/6h+nx4Dv/ZE1uIHcTOY6exi9gxrB4wsJNYA9aCHZfhodX1RL66BqPFyvPJhjrCf8QbvLOySuY51Tj1OH1R9OULCmXvaMCeLJ4mEWZk5jNY8IsgYHBEPMcRDBcnF1cAZN8XxevrTYz8u4Hotnzn5v0BgM/JgYGBo9+5sJMA7PeAj/+R75wNE346VAG4cIQnlRQoOFx2IMC3hDp80vSBMTAHNnA+LsAdeAN/EATCQBSIB8lgIsw+E65zCZgKZoC5oASUgWVgNVgPNoGtYCfYAw6AenAMnAbnwGXQBm6Ae3D1dIEXoA+8A58RBCEhVISO6CMmiCVij7ggTMQXCUIikFgkGUlFMhARIkVmIPOQMmQFsh7ZglQj+5EjyGnkItKO3EEeIT3Ia+QTiqFqqDZqhFqhI1EmykLD0Xh0ApqBTkGL0PnoEnQtWoXuRuvQ0+hl9Abagb5A+zGAqWK6mCnmgDExNhaFpWDpmASbhZVi5VgVVos1wvt8DevAerGPOBGn4wzcAa7gUDwB5+FT8Fn4Ynw9vhOvw5vxa/gjvA//RqASDAn2BC8ChzCWkEGYSighlBO2Ew4TzsJnqYvwjkgk6hKtiR7wWUwmZhGnExcTNxD3Ek8R24mdxH4SiaRPsif5kKJIXFI+qYS0jrSbdJJ0ldRF+qCiqmKi4qISrJKiIlIpVilX2aVyQuWqyjOVz2QNsiXZixxF5pOnkZeSt5EbyVfIXeTPFE2KNcWHEk/JosylrKXUUs5S7lPeqKqqmql6qsaoClXnqK5V3ad6QfWR6kc1LTU7NbbaeDWp2hK1HWqn1O6ovaFSqVZUf2oKNZ+6hFpNPUN9SP1Ao9McaRwanzabVkGro12lvVQnq1uqs9Qnqhepl6sfVL+i3qtB1rDSYGtwNWZpVGgc0bil0a9J13TWjNLM1VysuUvzoma3FknLSitIi681X2ur1hmtTjpGN6ez6Tz6PPo2+ll6lzZR21qbo52lXaa9R7tVu09HS8dVJ1GnUKdC57hOhy6ma6XL0c3RXap7QPem7qdhRsNYwwTDFg2rHXZ12Hu94Xr+egK9Ur29ejf0Pukz9IP0s/WX69frPzDADewMYgymGmw0OGvQO1x7uPdw3vDS4QeG3zVEDe0MYw2nG241bDHsNzI2CjESG60zOmPUa6xr7G+cZbzK+IRxjwndxNdEaLLK5KTJc4YOg8XIYaxlNDP6TA1NQ02lpltMW00/m1mbJZgVm+01e2BOMWeap5uvMm8y77MwsRhjMcOixuKuJdmSaZlpucbyvOV7K2urJKsFVvVW3dZ61hzrIusa6/s2VBs/myk2VTbXbYm2TNts2w22bXaonZtdpl2F3RV71N7dXmi/wb59BGGE5wjRiKoRtxzUHFgOBQ41Do8cdR0jHIsd6x1fjrQYmTJy+cjzI785uTnlOG1zuues5RzmXOzc6Pzaxc6F51Lhcn0UdVTwqNmjGka9crV3FbhudL3tRncb47bArcntq7uHu8S91r3Hw8Ij1aPS4xZTmxnNXMy84EnwDPCc7XnM86OXu1e+1wGvv7wdvLO9d3l3j7YeLRi9bXSnj5kP12eLT4cvwzfVd7Nvh5+pH9evyu+xv7k/33+7/zOWLSuLtZv1MsApQBJwOOA924s9k30qEAsMCSwNbA3SCkoIWh/0MNgsOCO4JrgvxC1kesipUEJoeOjy0FscIw6PU83pC/MImxnWHK4WHhe+PvxxhF2EJKJxDDombMzKMfcjLSNFkfVRIIoTtTLqQbR19JToozHEmOiYipinsc6xM2LPx9HjJsXtinsXHxC/NP5egk2CNKEpUT1xfGJ14vukwKQVSR1jR46dOfZyskGyMLkhhZSSmLI9pX9c0LjV47rGu40vGX9zgvWEwgkXJxpMzJl4fJL6JO6kg6mE1KTUXalfuFHcKm5/GietMq2Px+at4b3g+/NX8XsEPoIVgmfpPukr0rszfDJWZvRk+mWWZ/YK2cL1wldZoVmbst5nR2XvyB7IScrZm6uSm5p7RKQlyhY1TzaeXDi5XWwvLhF3TPGasnpKnyRcsj0PyZuQ15CvDX/kW6Q20l+kjwp8CyoKPkxNnHqwULNQVNgyzW7aomnPioKLfpuOT+dNb5phOmPujEczWTO3zEJmpc1qmm0+e/7srjkhc3bOpczNnvt7sVPxiuK385LmNc43mj9nfucvIb/UlNBKJCW3Fngv2LQQXyhc2Lpo1KJ1i76V8ksvlTmVlZd9WcxbfOlX51/X/jqwJH1J61L3pRuXEZeJlt1c7rd85wrNFUUrOleOWVm3irGqdNXb1ZNWXyx3Ld+0hrJGuqZjbcTahnUW65at+7I+c/2NioCKvZWGlYsq32/gb7i60X9j7SajTWWbP
m0Wbr69JWRLXZVVVflW4taCrU+3JW47/xvzt+rtBtvLtn/dIdrRsTN2Z3O1R3X1LsNdS2vQGmlNz+7xu9v2BO5pqHWo3bJXd2/ZPrBPuu/5/tT9Nw+EH2g6yDxYe8jyUOVh+uHSOqRuWl1ffWZ9R0NyQ/uRsCNNjd6Nh486Ht1xzPRYxXGd40tPUE7MPzFwsuhk/ynxqd7TGac7myY13Tsz9sz15pjm1rPhZy+cCz535jzr/MkLPheOXfS6eOQS81L9ZffLdS1uLYd/d/v9cKt7a90VjysNbZ5tje2j209c9bt6+lrgtXPXOdcv34i80X4z4ebtW+Nvddzm3+6+k3Pn1d2Cu5/vzblPuF/6QONB+UPDh1V/2P6xt8O94/ijwEctj+Me3+vkdb54kvfkS9f8p9Sn5c9MnlV3u3Qf6wnuaXs+7nnXC/GLz70lf2r+WfnS5uWhv/z/aukb29f1SvJq4PXiN/pvdrx1fdvUH93/8F3uu8/vSz/of9j5kfnx/KekT88+T/1C+rL2q+3Xxm/h
}
},
"cell_type": "markdown",
"id": "251feeea-c9a0-404a-8b55-bef3020bb5e2",
"metadata": {},
"source": [
"## Nodes and Edges\n",
"\n",
"Each `node` will simply modify the `state`.\n",
"\n",
"Each `edge` will choose which `node` to call next.\n",
"\n",
"We can lay out `self-RAG` as a graph:\n",
"\n",
"![Screenshot 2024-02-02 at 9.01.01 PM.png](attachment:e61fbd0c-e667-4160-a96c-82f95a560b44.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "add509d8-6682-4127-8d95-13dd37d79702",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import operator\n",
"from typing import Annotated, Sequence, TypedDict\n",
"\n",
"from langchain import hub\n",
"from langchain.output_parsers import PydanticOutputParser\n",
"from langchain.output_parsers.openai_tools import PydanticToolsParser\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain_community.vectorstores import Chroma\n",
"from langchain_core.messages import BaseMessage, FunctionMessage\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_core.utils.function_calling import convert_to_openai_tool\n",
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
"from langgraph.prebuilt import ToolInvocation\n",
"\n",
"### Nodes ###\n",
"\n",
"\n",
"def retrieve(state):\n",
" \"\"\"\n",
" Retrieve documents\n",
"\n",
" Args:\n",
" state (dict): The current state of the agent, including all keys.\n",
"\n",
" Returns:\n",
" dict: New key added to state, documents, that contains documents.\n",
" \"\"\"\n",
" print(\"---RETRIEVE---\")\n",
" state_dict = state[\"keys\"]\n",
" question = state_dict[\"question\"]\n",
" documents = retriever.invoke(question)\n",
" return {\"keys\": {\"documents\": documents, \"question\": question}}\n",
"\n",
"\n",
"def generate(state):\n",
" \"\"\"\n",
" Generate answer\n",
"\n",
" Args:\n",
" state (dict): The current state of the agent, including all keys.\n",
"\n",
" Returns:\n",
" dict: New key added to state, generation, that contains generation.\n",
" \"\"\"\n",
" print(\"---GENERATE---\")\n",
" state_dict = state[\"keys\"]\n",
" question = state_dict[\"question\"]\n",
" documents = state_dict[\"documents\"]\n",
"\n",
" # Prompt\n",
" prompt = hub.pull(\"rlm/rag-prompt\")\n",
"\n",
" # LLM\n",
" llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
"\n",
" # Post-processing\n",
" def format_docs(docs):\n",
" return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
"\n",
" # Chain\n",
" rag_chain = prompt | llm | StrOutputParser()\n",
"\n",
" # Run\n",
" generation = rag_chain.invoke({\"context\": documents, \"question\": question})\n",
" return {\n",
" \"keys\": {\"documents\": documents, \"question\": question, \"generation\": generation}\n",
" }\n",
"\n",
"\n",
"def grade_documents(state):\n",
" \"\"\"\n",
" Determines whether the retrieved documents are relevant to the question.\n",
"\n",
" Args:\n",
" state (dict): The current state of the agent, including all keys.\n",
"\n",
" Returns:\n",
" dict: New key added to state, filtered_documents, that contains relevant documents.\n",
" \"\"\"\n",
"\n",
" print(\"---CHECK RELEVANCE---\")\n",
" state_dict = state[\"keys\"]\n",
" question = state_dict[\"question\"]\n",
" documents = state_dict[\"documents\"]\n",
"\n",
" # Data model\n",
" class grade(BaseModel):\n",
" \"\"\"Binary score for relevance check.\"\"\"\n",
"\n",
" binary_score: str = Field(description=\"Relevance score 'yes' or 'no'\")\n",
"\n",
" # LLM\n",
" model = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n",
"\n",
" # Tool\n",
" grade_tool_oai = convert_to_openai_tool(grade)\n",
"\n",
" # LLM with tool and enforce invocation\n",
" llm_with_tool = model.bind(\n",
" tools=[convert_to_openai_tool(grade_tool_oai)],\n",
" tool_choice={\"type\": \"function\", \"function\": {\"name\": \"grade\"}},\n",
" )\n",
"\n",
" # Parser\n",
" parser_tool = PydanticToolsParser(tools=[grade])\n",
"\n",
" # Prompt\n",
" prompt = PromptTemplate(\n",
" template=\"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n \n",
" Here is the retrieved document: \\n\\n {context} \\n\\n\n",
" Here is the user question: {question} \\n\n",
" If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \\n\n",
" Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\"\"\",\n",
" input_variables=[\"context\", \"question\"],\n",
" )\n",
"\n",
" # Chain\n",
" chain = prompt | llm_with_tool | parser_tool\n",
"\n",
" # Score\n",
" filtered_docs = []\n",
" for d in documents:\n",
" score = chain.invoke({\"question\": question, \"context\": d.page_content})\n",
" grade = score[0].binary_score\n",
" if grade == \"yes\":\n",
" print(\"---GRADE: DOCUMENT RELEVANT---\")\n",
" filtered_docs.append(d)\n",
" else:\n",
" print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n",
" continue\n",
"\n",
" return {\"keys\": {\"documents\": filtered_docs, \"question\": question}}\n",
"\n",
"\n",
"def transform_query(state):\n",
" \"\"\"\n",
" Transform the query to produce a better question.\n",
"\n",
" Args:\n",
" state (dict): The current state of the agent, including all keys.\n",
"\n",
" Returns:\n",
" dict: New value saved to question.\n",
" \"\"\"\n",
"\n",
" print(\"---TRANSFORM QUERY---\")\n",
" state_dict = state[\"keys\"]\n",
" question = state_dict[\"question\"]\n",
" documents = state_dict[\"documents\"]\n",
"\n",
" # Create a prompt template with format instructions and the query\n",
" prompt = PromptTemplate(\n",
" template=\"\"\"You are generating questions that is well optimized for retrieval. \\n \n",
" Look at the input and try to reason about the underlying sematic intent / meaning. \\n \n",
" Here is the initial question:\n",
" \\n ------- \\n\n",
" {question} \n",
" \\n ------- \\n\n",
" Formulate an improved question: \"\"\",\n",
" input_variables=[\"question\"],\n",
" )\n",
"\n",
" # Grader\n",
" model = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n",
"\n",
" # Prompt\n",
" chain = prompt | model | StrOutputParser()\n",
" better_question = chain.invoke({\"question\": question})\n",
"\n",
" return {\"keys\": {\"documents\": documents, \"question\": better_question}}\n",
"\n",
"\n",
"def prepare_for_final_grade(state):\n",
" \"\"\"\n",
" Stage for final grade, passthrough state.\n",
"\n",
" Args:\n",
" state (dict): The current state of the agent, including all keys.\n",
"\n",
" Returns:\n",
" state (dict): The current state of the agent, including all keys.\n",
" \"\"\"\n",
"\n",
" print(\"---FINAL GRADE---\")\n",
" state_dict = state[\"keys\"]\n",
" question = state_dict[\"question\"]\n",
" documents = state_dict[\"documents\"]\n",
" generation = state_dict[\"generation\"]\n",
"\n",
" return {\n",
" \"keys\": {\"documents\": documents, \"question\": question, \"generation\": generation}\n",
" }\n",
"\n",
"\n",
"### Edges ###\n",
"\n",
"\n",
"def decide_to_generate(state):\n",
" \"\"\"\n",
" Determines whether to generate an answer, or re-generate a question.\n",
"\n",
" Args:\n",
" state (dict): The current state of the agent, including all keys.\n",
"\n",
" Returns:\n",
" dict: New key added to state, filtered_documents, that contains relevant documents.\n",
" \"\"\"\n",
"\n",
" print(\"---DECIDE TO GENERATE---\")\n",
" state_dict = state[\"keys\"]\n",
" question = state_dict[\"question\"]\n",
" filtered_documents = state_dict[\"documents\"]\n",
"\n",
" if not filtered_documents:\n",
" # All documents have been filtered check_relevance\n",
" # We will re-generate a new query\n",
" print(\"---DECISION: TRANSFORM QUERY---\")\n",
" return \"transform_query\"\n",
" else:\n",
" # We have relevant documents, so generate answer\n",
" print(\"---DECISION: GENERATE---\")\n",
" return \"generate\"\n",
"\n",
"\n",
"def grade_generation_v_documents(state):\n",
" \"\"\"\n",
" Determines whether the generation is grounded in the document.\n",
"\n",
" Args:\n",
" state (dict): The current state of the agent, including all keys.\n",
"\n",
" Returns:\n",
" str: Binary decision score.\n",
" \"\"\"\n",
"\n",
" print(\"---GRADE GENERATION vs DOCUMENTS---\")\n",
" state_dict = state[\"keys\"]\n",
" question = state_dict[\"question\"]\n",
" documents = state_dict[\"documents\"]\n",
" generation = state_dict[\"generation\"]\n",
"\n",
" # Data model\n",
" class grade(BaseModel):\n",
" \"\"\"Binary score for relevance check.\"\"\"\n",
"\n",
" binary_score: str = Field(description=\"Supported score 'yes' or 'no'\")\n",
"\n",
" # LLM\n",
" model = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n",
"\n",
" # Tool\n",
" grade_tool_oai = convert_to_openai_tool(grade)\n",
"\n",
" # LLM with tool and enforce invocation\n",
" llm_with_tool = model.bind(\n",
" tools=[convert_to_openai_tool(grade_tool_oai)],\n",
" tool_choice={\"type\": \"function\", \"function\": {\"name\": \"grade\"}},\n",
" )\n",
"\n",
" # Parser\n",
" parser_tool = PydanticToolsParser(tools=[grade])\n",
"\n",
" # Prompt\n",
" prompt = PromptTemplate(\n",
" template=\"\"\"You are a grader assessing whether an answer is grounded in / supported by a set of facts. \\n \n",
" Here are the facts:\n",
" \\n ------- \\n\n",
" {documents} \n",
" \\n ------- \\n\n",
" Here is the answer: {generation}\n",
" Give a binary score 'yes' or 'no' to indicate whether the answer is grounded in / supported by a set of facts.\"\"\",\n",
" input_variables=[\"generation\", \"documents\"],\n",
" )\n",
"\n",
" # Chain\n",
" chain = prompt | llm_with_tool | parser_tool\n",
"\n",
" score = chain.invoke({\"generation\": generation, \"documents\": documents})\n",
" grade = score[0].binary_score\n",
"\n",
" if grade == \"yes\":\n",
" print(\"---DECISION: SUPPORTED, MOVE TO FINAL GRADE---\")\n",
" return \"supported\"\n",
" else:\n",
" print(\"---DECISION: NOT SUPPORTED, GENERATE AGAIN---\")\n",
" return \"not supported\"\n",
"\n",
"\n",
"def grade_generation_v_question(state):\n",
" \"\"\"\n",
" Determines whether the generation addresses the question.\n",
"\n",
" Args:\n",
" state (dict): The current state of the agent, including all keys.\n",
"\n",
" Returns:\n",
" str: Binary decision score.\n",
" \"\"\"\n",
"\n",
" print(\"---GRADE GENERATION vs QUESTION---\")\n",
" state_dict = state[\"keys\"]\n",
" question = state_dict[\"question\"]\n",
" documents = state_dict[\"documents\"]\n",
" generation = state_dict[\"generation\"]\n",
"\n",
" # Data model\n",
" class grade(BaseModel):\n",
" \"\"\"Binary score for relevance check.\"\"\"\n",
"\n",
" binary_score: str = Field(description=\"Useful score 'yes' or 'no'\")\n",
"\n",
" # LLM\n",
" model = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n",
"\n",
" # Tool\n",
" grade_tool_oai = convert_to_openai_tool(grade)\n",
"\n",
" # LLM with tool and enforce invocation\n",
" llm_with_tool = model.bind(\n",
" tools=[convert_to_openai_tool(grade_tool_oai)],\n",
" tool_choice={\"type\": \"function\", \"function\": {\"name\": \"grade\"}},\n",
" )\n",
"\n",
" # Parser\n",
" parser_tool = PydanticToolsParser(tools=[grade])\n",
"\n",
" # Prompt\n",
" prompt = PromptTemplate(\n",
" template=\"\"\"You are a grader assessing whether an answer is useful to resolve a question. \\n \n",
" Here is the answer:\n",
" \\n ------- \\n\n",
" {generation} \n",
" \\n ------- \\n\n",
" Here is the question: {question}\n",
" Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question.\"\"\",\n",
" input_variables=[\"generation\", \"question\"],\n",
" )\n",
"\n",
" # Prompt\n",
" chain = prompt | llm_with_tool | parser_tool\n",
"\n",
" score = chain.invoke({\"generation\": generation, \"question\": question})\n",
" grade = score[0].binary_score\n",
"\n",
" if grade == \"yes\":\n",
" print(\"---DECISION: USEFUL---\")\n",
" return \"useful\"\n",
" else:\n",
" print(\"---DECISION: NOT USEFUL---\")\n",
" return \"not useful\""
]
},
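{
"cell_type": "markdown",
"id": "node-check-md",
"metadata": {},
"source": [
"Since every node is just a function over the state dict, we can optionally exercise one in isolation before building the graph (sample question made up):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "node-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Optional: run the retrieve node by itself on a hand-built state.\n",
"test_state = retrieve({\"keys\": {\"question\": \"What is agent memory?\"}})\n",
"len(test_state[\"keys\"][\"documents\"])"
]
},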
{
"cell_type": "markdown",
"id": "61cd5797-1782-4d78-a277-8196d13f3e1b",
"metadata": {},
"source": [
"## Graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e09ca9f-e36d-4ef4-a0d5-79fdbada9fe0",
"metadata": {},
"outputs": [],
"source": [
"import pprint\n",
"\n",
"from langgraph.graph import END, StateGraph\n",
"\n",
"workflow = StateGraph(GraphState)\n",
"\n",
"# Define the nodes\n",
"workflow.add_node(\"retrieve\", retrieve) # retrieve\n",
"workflow.add_node(\"grade_documents\", grade_documents) # grade documents\n",
"workflow.add_node(\"generate\", generate) # generatae\n",
"workflow.add_node(\"transform_query\", transform_query) # transform_query\n",
"workflow.add_node(\"prepare_for_final_grade\", prepare_for_final_grade) # passthrough\n",
"\n",
"# Build graph\n",
"workflow.set_entry_point(\"retrieve\")\n",
"workflow.add_edge(\"retrieve\", \"grade_documents\")\n",
"workflow.add_conditional_edges(\n",
" \"grade_documents\",\n",
" decide_to_generate,\n",
" {\n",
" \"transform_query\": \"transform_query\",\n",
" \"generate\": \"generate\",\n",
" },\n",
")\n",
"workflow.add_edge(\"transform_query\", \"retrieve\")\n",
"workflow.add_conditional_edges(\n",
" \"generate\",\n",
" grade_generation_v_documents,\n",
" {\n",
" \"supported\": \"prepare_for_final_grade\",\n",
" \"not supported\": \"generate\",\n",
" },\n",
")\n",
"workflow.add_conditional_edges(\n",
" \"prepare_for_final_grade\",\n",
" grade_generation_v_question,\n",
" {\n",
" \"useful\": END,\n",
" \"not useful\": \"transform_query\",\n",
" },\n",
")\n",
"\n",
"# Compile\n",
"app = workflow.compile()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb69dbb9-91ee-4868-8c3c-93af3cd885be",
"metadata": {},
"outputs": [],
"source": [
"# Run\n",
"inputs = {\"keys\": {\"question\": \"Explain how the different types of agent memory work?\"}}\n",
"for output in app.stream(inputs):\n",
" for key, value in output.items():\n",
" pprint.pprint(f\"Output from node '{key}':\")\n",
" pprint.pprint(\"---\")\n",
" pprint.pprint(value[\"keys\"], indent=2, width=80, depth=None)\n",
" pprint.pprint(\"\\n---\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4138bc51-8c84-4b8a-8d24-f7f470721f6f",
"metadata": {},
"outputs": [],
"source": [
"inputs = {\"keys\": {\"question\": \"Explain how chain of thought prompting works?\"}}\n",
"for output in app.stream(inputs):\n",
" for key, value in output.items():\n",
" pprint.pprint(f\"Output from node '{key}':\")\n",
" pprint.pprint(\"---\")\n",
" pprint.pprint(value[\"keys\"], indent=2, width=80, depth=None)\n",
" pprint.pprint(\"\\n---\\n\")"
]
},
{
"cell_type": "markdown",
"id": "548f1c5b-4108-4aae-8abb-ec171b511b92",
"metadata": {},
"source": [
"Trace - \n",
" \n",
"* https://smith.langchain.com/public/55d6180f-aab8-42bc-8799-dadce6247d9b/r\n",
"* https://smith.langchain.com/public/f85ebc95-81d9-47fc-91c6-b54e5b78f359/r"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}