You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/cookbook/RAPTOR.ipynb

748 lines
210 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "3058e9ca-07c3-4eef-b98c-bc2f2dbb9cc6",
"metadata": {},
"outputs": [],
"source": [
"# %pip installs into the environment of the running kernel (bare `pip` relies on IPython automagic).\n",
"# beautifulsoup4, matplotlib and pandas are imported by later cells, so they are installed here too.\n",
"%pip install -qU langchain langchain_community langchain-openai langchain-anthropic langchainhub umap-learn scikit-learn tiktoken chromadb beautifulsoup4 matplotlib pandas"
]
},
{
"attachments": {
"72039e0c-e8c4-4b17-8780-04ad9fc584f3.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA88AAAMDCAYAAACcoScMAAAMP2lDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBCCSAgJfQmCEgJICWEFkB6EWyEJEAoMQaCiB1dVHDtYgEbuiqi2AGxI3YWwd4XRRSUdbFgV96kgK77yvfO9829//3nzH/OnDu3DADqp7hicQ6qAUCuKF8SGxLAGJucwiB1AwTggAYIgMDl5YlZ0dERANrg+e/27ib0hnbNQab1z/7/app8QR4PACQa4jR+Hi8X4kMA4JU8sSQfAKKMN5+aL5Zh2IC2BCYI8UIZzlDgShlOU+B9cp/4WDbEzQCoqHG5kgwAaG2QZxTwMqAGrQ9iJxFfKAJAnQGxb27uZD7EqRDbQB8xxDJ9ZtoPOhl/00wb0uRyM4awYi5yUwkU5olzuNP+z3L8b8vNkQ7GsIJNLVMSGiubM6zb7ezJ4TKsBnGvKC0yCmItiD8I+XJ/iFFKpjQ0QeGPGvLy2LBmQBdiJz43MBxiQ4iDRTmREUo+LV0YzIEYrhC0UJjPiYdYD+KFgrygOKXPZsnkWGUstC5dwmYp+QtciTyuLNZDaXYCS6n/OlPAUepjtKLM+CSIKRBbFAgTIyGmQeyYlx0XrvQZXZTJjhz0kUhjZflbQBwrEIUEKPSxgnRJcKzSvzQ3b3C+2OZMISdSiQ/kZ8aHKuqDNfO48vzhXLA2gYiVMKgjyBsbMTgXviAwSDF3rFsgSohT6nwQ5wfEKsbiFHFOtNIfNxPkhMh4M4hd8wrilGPxxHy4IBX6eLo4PzpekSdelMUNi1bkgy8DEYANAgEDSGFLA5NBFhC29tb3witFTzDgAgnIAALgoGQGRyTJe0TwGAeKwJ8QCUDe0LgAea8AFED+6xCrODqAdHlvgXxENngKcS4IBznwWiofJRqKlgieQEb4j+hc2Hgw3xzYZP3/nh9kvzMsyEQoGelgRIb6oCcxiBhIDCUGE21xA9wX98Yj4NEfNheciXsOzuO7P+EpoZ3wmHCD0EG4M0lYLPkpyzGgA+oHK2uR9mMtcCuo6YYH4D5QHSrjurgBcMBdYRwW7gcju0GWrcxbVhXGT9p/m8EPd0PpR3Yio+RhZH+yzc8jaXY0tyEVWa1/rI8i17SherOHen6Oz/6h+nx4Dv/ZE1uIHcTOY6exi9gxrB4wsJNYA9aCHZfhodX1RL66BqPFyvPJhjrCf8QbvLOySuY51Tj1OH1R9OULCmXvaMCeLJ4mEWZk5jNY8IsgYHBEPMcRDBcnF1cAZN8XxevrTYz8u4Hotnzn5v0BgM/JgYGBo9+5sJMA7PeAj/+R75wNE346VAG4cIQnlRQoOFx2IMC3hDp80vSBMTAHNnA+LsAdeAN/EATCQBSIB8lgIsw+E65zCZgKZoC5oASUgWVgNVgPNoGtYCfYAw6AenAMnAbnwGXQBm6Ae3D1dIEXoA+8A58RBCEhVISO6CMmiCVij7ggTMQXCUIikFgkGUlFMhARIkVmIPOQMmQFsh7ZglQj+5EjyGnkItKO3EEeIT3Ia+QTiqFqqDZqhFqhI1EmykLD0Xh0ApqBTkGL0PnoEnQtWoXuRuvQ0+hl9Abagb5A+zGAqWK6mCnmgDExNhaFpWDpmASbhZVi5VgVVos1wvt8DevAerGPOBGn4wzcAa7gUDwB5+FT8Fn4Ynw9vhOvw5vxa/gjvA//RqASDAn2BC8ChzCWkEGYSighlBO2Ew4TzsJnqYvwjkgk6hKtiR7wWUwmZhGnExcTNxD3Ek8R24mdxH4SiaRPsif5kKJIXFI+qYS0jrSbdJJ0ldRF+qCiqmKi4qISrJKiIlIpVilX2aVyQuWqyjOVz2QNsiXZixxF5pOnkZeSt5EbyVfIXeTPFE2KNcWHEk/JosylrKXUUs5S7lPeqKqqmql6qsaoClXnqK5V3ad6QfWR6kc1LTU7NbbaeDWp2hK1HWqn1O6ovaFSqVZUf2
oKNZ+6hFpNPUN9SP1Ao9McaRwanzabVkGro12lvVQnq1uqs9Qnqhepl6sfVL+i3qtB1rDSYGtwNWZpVGgc0bil0a9J13TWjNLM1VysuUvzoma3FknLSitIi681X2ur1hmtTjpGN6ez6Tz6PPo2+ll6lzZR21qbo52lXaa9R7tVu09HS8dVJ1GnUKdC57hOhy6ma6XL0c3RXap7QPem7qdhRsNYwwTDFg2rHXZ12Hu94Xr+egK9Ur29ejf0Pukz9IP0s/WX69frPzDADewMYgymGmw0OGvQO1x7uPdw3vDS4QeG3zVEDe0MYw2nG241bDHsNzI2CjESG60zOmPUa6xr7G+cZbzK+IRxjwndxNdEaLLK5KTJc4YOg8XIYaxlNDP6TA1NQ02lpltMW00/m1mbJZgVm+01e2BOMWeap5uvMm8y77MwsRhjMcOixuKuJdmSaZlpucbyvOV7K2urJKsFVvVW3dZ61hzrIusa6/s2VBs/myk2VTbXbYm2TNts2w22bXaonZtdpl2F3RV71N7dXmi/wb59BGGE5wjRiKoRtxzUHFgOBQ41Do8cdR0jHIsd6x1fjrQYmTJy+cjzI785uTnlOG1zuues5RzmXOzc6Pzaxc6F51Lhcn0UdVTwqNmjGka9crV3FbhudL3tRncb47bArcntq7uHu8S91r3Hw8Ij1aPS4xZTmxnNXMy84EnwDPCc7XnM86OXu1e+1wGvv7wdvLO9d3l3j7YeLRi9bXSnj5kP12eLT4cvwzfVd7Nvh5+pH9evyu+xv7k/33+7/zOWLSuLtZv1MsApQBJwOOA924s9k30qEAsMCSwNbA3SCkoIWh/0MNgsOCO4JrgvxC1kesipUEJoeOjy0FscIw6PU83pC/MImxnWHK4WHhe+PvxxhF2EJKJxDDombMzKMfcjLSNFkfVRIIoTtTLqQbR19JToozHEmOiYipinsc6xM2LPx9HjJsXtinsXHxC/NP5egk2CNKEpUT1xfGJ14vukwKQVSR1jR46dOfZyskGyMLkhhZSSmLI9pX9c0LjV47rGu40vGX9zgvWEwgkXJxpMzJl4fJL6JO6kg6mE1KTUXalfuFHcKm5/GietMq2Px+at4b3g+/NX8XsEPoIVgmfpPukr0rszfDJWZvRk+mWWZ/YK2cL1wldZoVmbst5nR2XvyB7IScrZm6uSm5p7RKQlyhY1TzaeXDi5XWwvLhF3TPGasnpKnyRcsj0PyZuQ15CvDX/kW6Q20l+kjwp8CyoKPkxNnHqwULNQVNgyzW7aomnPioKLfpuOT+dNb5phOmPujEczWTO3zEJmpc1qmm0+e/7srjkhc3bOpczNnvt7sVPxiuK385LmNc43mj9nfucvIb/UlNBKJCW3Fngv2LQQXyhc2Lpo1KJ1i76V8ksvlTmVlZd9WcxbfOlX51/X/jqwJH1J61L3pRuXEZeJlt1c7rd85wrNFUUrOleOWVm3irGqdNXb1ZNWXyx3Ld+0hrJGuqZjbcTahnUW65at+7I+c/2NioCKvZWGlYsq32/gb7i60X9j7SajTWWbPm0Wbr69JWRLXZVVVflW4taCrU+3JW47/xvzt+rtBtvLtn/dIdrRsTN2Z3O1R3X1LsNdS2vQGmlNz+7xu9v2BO5pqHWo3bJXd2/ZPrBPuu/5/tT9Nw+EH2g6yDxYe8jyUOVh+uHSOqRuWl1ffWZ9R0NyQ/uRsCNNjd6Nh486Ht1xzPRYxXGd40tPUE7MPzFwsuhk/ynxqd7TGac7myY13Tsz9sz15pjm1rPhZy+cCz535jzr/MkLPheOXfS6eOQS81L9ZffLdS1uLYd/d/v9cKt7a90VjysNbZ5tje2j209c9bt6+lrgtXPXOdcv34i80X4z4ebtW+Nvddzm3+6+k3Pn1d2Cu5/vzblPuF/6QONB+UPDh1V/2P6xt8O94/ijwEctj+Me3+vkdb54kvfkS9f8p9Sn5c9MnlV3u3Qf6wnuaXs+7nnXC/GLz70lf2
r+WfnS5uWhv/z/aukb29f1SvJq4PXiN/pvdrx1fdvUH93/8F3uu8/vSz/of9j5kfnx/KekT88+T/1C+rL2q+3Xxm/h
}
},
"cell_type": "markdown",
"id": "ea54c848-0df6-474e-b266-218a2acf67d3",
"metadata": {},
"source": [
"# RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval\n",
"\n",
"The [RAPTOR](https://arxiv.org/pdf/2401.18059.pdf) paper presents an interesting approach for indexing and retrieval of documents:\n",
"\n",
"* The `leafs` are a set of starting documents\n",
"* Leafs are embedded and clustered\n",
"* Clusters are then summarized into higher level (more abstract) consolidations of information across similar documents\n",
"\n",
"This process is done recursively, resulting in a \"tree\" going from raw docs (`leafs`) to more abstract summaries.\n",
" \n",
"We can apply this at varying scales; `leafs` can be:\n",
"\n",
"* Text chunks from a single doc (as shown in the paper)\n",
"* Full docs (as we show below)\n",
"\n",
"With long-context LLMs, it's possible to perform this over full documents. \n",
"\n",
"![Screenshot 2024-03-04 at 12.45.25 PM.png](attachment:72039e0c-e8c4-4b17-8780-04ad9fc584f3.png)"
]
},
{
"cell_type": "markdown",
"id": "083dd961-b401-4fc6-867c-8f8950059b02",
"metadata": {},
"source": [
"### Docs\n",
"\n",
"Let's apply this to LangChain's LCEL documentation.\n",
"\n",
"In this case, each `doc` is a unique web page of the LCEL docs.\n",
"\n",
"Document length ranges from under 2k tokens to over 10k tokens."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b17c1331-373f-491d-8b53-ccf634e68c8e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function matplotlib.pyplot.show(close=None, block=None)>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAIjCAYAAADWYVDIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAABGN0lEQVR4nO3deVxUZf//8ffAwLDIoriSqLjnkpZmmZaZFqm5tWqWaLZbaprZcleSmZZlVpbl3Z3LnWXZ3farNNcy01RUXAs1DS23cmEEERnm+v3hg/meEVTEgQF8PR8PHndznetc53PO5T3O23PmwmaMMQIAAAAASJIC/F0AAAAAAJQmhCQAAAAAsCAkAQAAAIAFIQkAAAAALAhJAAAAAGBBSAIAAAAAC0ISAAAAAFgQkgAAAADAgpAEAAAAABaEJAAoAXXq1NGAAQP8XUa5N2HCBNWtW1eBgYFq2bJlsR7rhx9+kM1m02effVasxwEAlDxCEgCco+nTp8tmsyk5ObnA7ddee62aNWt23sf57rvvNHr06PMe50Ixf/58PfHEE2rXrp2mTZuml156KV+fvGBTmJ+y6Pjx43r99dd1xRVXKCoqSiEhIWrYsKEeeeQRbd261d/lSZKWL1+u0aNH68iRI/4uBQBOy+7vAgDgQpCamqqAgHP7d6nvvvtOb7/9NkGpkBYvXqyAgAD95z//UXBwcIF9Lr74Yv33v//1anvqqadUoUIFPfPMMyVRZrH5559/dOONN2rNmjW66aabdOedd6pChQpKTU3V7NmzNXXqVJ04ccLfZWr58uVKSkrSgAEDFB0d7e9yAKBAhCQAKAEOh8PfJZyzzMxMhYeH+7uMQjtw4IBCQ0NPG5AkqVq1arrrrru82saPH6/KlSvnay9rBgwYoHXr1umzzz7TLbfc4rVtzJgxZT4EAkBJ4nE7ACgBp34nKScnR0lJSWrQoIFCQkIUExOj9u3ba8GCBZJOfuB9++23JanAR8AyMzM1YsQIxcXFyeFwqFGjRnr11VdljPE6blZWloYMGaLKlSsrIiJCPXr00F9//SWbzeZ1h2r06NGy2WzasmWL7rzzTlWsWFHt27eXJG3YsEEDBgxQ3bp1FRISourVq+uee+7RwYMHvY6VN8bWrVt11113KSoqSlWqVNGzzz4rY4x2796tnj17KjIyUtWrV9drr71WqGvncrk0ZswY1atXTw6HQ3Xq1NHTTz+t7OxsTx+bzaZp06YpMzPTc62mT59eqPELsmPHDt12222qVKmSwsLCdOWVV+rbb789637Z2dm66aabFBUVpeXLl0uS3G63Jk2apKZNmyokJETVqlXTAw88oMOHD3vtW6dOHd10001atmyZ2rRpo5CQENWtW1czZ84863FXrlypb7/9VoMGDcoXkKSTIf3VV1/1alu8eLGuvvpqhYeHKzo6Wj179tSvv/7q1WfAgAGqU6dOvvHy5trKZrPpkUce0ZdffqlmzZrJ4XCoadOmmjdvntd+I0eOlCTFx8d75uqPP/6QJC1YsEDt27dXdHS0KlSooEaNGunpp58+6/kDgK9xJwkAiig9PV3//PNPvvacnJyz7jt69GiNGzdO9957r9q0aSOn06nk5GStXbtW119/vR544AHt2bNHCxYsyPd4mDFGPXr00JIlSzRo0CC1bNlS33//vUaOHKm//vpLr7/+uqfvgAED9Omnn+ruu+/WlVdeqR9//FHdunU7bV233XabGjRooJdeeskTuBYsWKAdO3Zo4MCBql69ujZv3qypU6dq8+bN+uWXX/J9WL7jjjt08cUXa/z48fr222/14osvqlKlSnrvvfd03XXX6eWXX9asWbP0+OOP6/LLL9c111xzxmt17733asaMGbr11ls1YsQIrVy5UuPGjdOvv/6qL774QpL03//+V1OnTtWqVav0/vvvS5Kuuuqqs85DQfbv36+rrrpKx44d05AhQxQTE6MZM2aoR48e+uyzz9
S7d+8C98vKylLPnj2VnJyshQsX6vLLL5ckPfDAA5o+fboGDhyoIUOGaOfOnZo8ebLWrVunn3/+WUFBQZ4xtm/frltvvVWDBg1SYmKiPvjgAw0YMECtWrVS06ZNT1vz119/LUm6++67C3WOCxcuVJcuXVS3bl2NHj1aWVlZeuutt9SuXTutXbu2wGBUGMuWLdPnn3+uhx9+WBEREXrzzTd1yy23aNeuXYqJidHNN9+srVu36uOPP9brr7+uypUrS5KqVKmizZs366abbtIll1yiF154QQ6HQ9u3b9fPP/9cpFoA4LwYAMA5mTZtmpF0xp+mTZt67VO7dm2TmJjoed2iRQvTrVu3Mx5n8ODBpqC36S+//NJIMi+++KJX+6233mpsNpvZvn27McaYNWvWGElm2LBhXv0GDBhgJJnnn3/e0/b8888bSaZv3775jnfs2LF8bR9//LGRZJYuXZpvjPvvv9/T5nK5TM2aNY3NZjPjx4/3tB8+fNiEhoZ6XZOCpKSkGEnm3nvv9Wp//PHHjSSzePFiT1tiYqIJDw8/43gFadq0qenQoYPn9bBhw4wk89NPP3najh49auLj402dOnVMbm6uMcaYJUuWGElmzpw55ujRo6ZDhw6mcuXKZt26dZ79fvrpJyPJzJo1y+uY8+bNy9deu3btfNf0wIEDxuFwmBEjRpzxHHr37m0kmcOHDxfqnFu2bGmqVq1qDh486Glbv369CQgIMP379/e0JSYmmtq1a+fbP2+urSSZ4OBgz5+/vDElmbfeesvTNmHCBCPJ7Ny502v/119/3Ugyf//9d6HOAQCKE4/bAUARvf3221qwYEG+n0suueSs+0ZHR2vz5s3atm3bOR/3u+++U2BgoIYMGeLVPmLECBljNHfuXEnyPOb08MMPe/V79NFHTzv2gw8+mK8tNDTU89/Hjx/XP//8oyuvvFKStHbt2nz97733Xs9/BwYGqnXr1jLGaNCgQZ726OhoNWrUSDt27DhtLdLJc5Wk4cOHe7WPGDFCkgr1CNy5+u6779SmTRvP44aSVKFCBd1///36448/tGXLFq/+6enpuuGGG/Tbb7/phx9+8Fp6fM6cOYqKitL111+vf/75x/PTqlUrVahQQUuWLPEaq0mTJrr66qs9r6tUqVKo6+R0OiVJERERZz2/vXv3KiUlRQMGDFClSpU87Zdccomuv/56zzUvis6dO6tevXpeY0ZGRp61fkmeRRy++uorud3uItcAAL5ASAKAImrTpo06d+6c76dixYpn3feFF17QkSNH1LBhQzVv3lwjR47Uhg0bCnXctLQ0xcbG5vtAfPHFF3u25/1vQECA4uPjvfrVr1//tGOf2leSDh06pKFDh6patWoKDQ1VlSpVPP3S09Pz9a9Vq5bX67ylqPMerbK2n/q9nFPlncOpNVevXl3R0dGec/WltLQ0NWrUKF/7qdc3z7Bhw7R69WotXLgw3yNx27ZtU3p6uqpWraoqVap4/WRkZOjAgQNe/U+9dpJUsWLFs16nyMhISdLRo0cLdX6STnuO//zzjzIzM886TkGKWr908jHNdu3a6d5771W1atXUp08fffrppwQmAH7Bd5IAwA+uueYa/f777/rqq680f/58vf/++3r99df17rvvet2JKWnWu0Z5br/9di1fvlwjR45Uy5YtVaFCBbndbt14440FfoANDAwsVJukfAtNnE5p/r1FPXv21OzZszV+/HjNnDnTa6l3t9utqlWratasWQXuW6VKFa/XRb1OjRs3liRt3LjR607U+Trddc/NzS2w/XzmOTQ0VEuXLtWSJUv07bffat68efrkk0903XXXaf78+acdGwCKA3eSAMBPKlWqpIEDB+rjjz/W7t27dckll3itOHe6D6i1a9fWnj178t01+O233zzb8/7X7XZr586dXv22b99e6BoPHz6sRYsW6cknn1RSUpJ69+6t66+/XnXr1i30GOcj7xxOfSxx//79OnLkiOdcfX3M1N
TUfO2nXt88vXr10gcffKCPPvpIgwcP9tpWr149HTx4UO3atSvwrmOLFi18UnP37t0lSR9++OFZ++bVf7pzrFy5smf
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import tiktoken\n",
"from bs4 import BeautifulSoup as Soup\n",
"from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader\n",
"\n",
"\n",
"def num_tokens_from_string(string: str, encoding_name: str) -> int:\n",
"    \"\"\"Count the tokens in `string` under the named tiktoken encoding.\n",
"\n",
"    Args:\n",
"        string: Text to tokenize.\n",
"        encoding_name: Name of a tiktoken encoding, e.g. \"cl100k_base\".\n",
"\n",
"    Returns:\n",
"        The number of tokens produced by encoding `string`.\n",
"    \"\"\"\n",
"    return len(tiktoken.get_encoding(encoding_name).encode(string))\n",
"\n",
"\n",
"def load_docs(url: str, max_depth: int):\n",
"    \"\"\"Crawl `url` with RecursiveUrlLoader (recursion limited by `max_depth`) and return\n",
"    the loaded documents, with each page's HTML reduced to its text via BeautifulSoup.\"\"\"\n",
"    loader = RecursiveUrlLoader(\n",
"        url=url,\n",
"        max_depth=max_depth,\n",
"        extractor=lambda html: Soup(html, \"html.parser\").text,\n",
"    )\n",
"    return loader.load()\n",
"\n",
"\n",
"# LCEL docs\n",
"docs = load_docs(\"https://python.langchain.com/docs/expression_language/\", max_depth=20)\n",
"\n",
"# LCEL w/ PydanticOutputParser (outside the primary LCEL docs)\n",
"docs_pydantic = load_docs(\n",
"    \"https://python.langchain.com/docs/modules/model_io/output_parsers/quick_start\",\n",
"    max_depth=1,\n",
")\n",
"\n",
"# LCEL w/ Self Query (outside the primary LCEL docs)\n",
"docs_sq = load_docs(\n",
"    \"https://python.langchain.com/docs/modules/data_connection/retrievers/self_query/\",\n",
"    max_depth=1,\n",
")\n",
"\n",
"# Doc texts: one leaf per page\n",
"docs.extend([*docs_pydantic, *docs_sq])\n",
"docs_texts = [d.page_content for d in docs]\n",
"\n",
"# Calculate the number of tokens for each document\n",
"counts = [num_tokens_from_string(d, \"cl100k_base\") for d in docs_texts]\n",
"\n",
"# Plotting the histogram of token counts\n",
"plt.figure(figsize=(10, 6))\n",
"plt.hist(counts, bins=30, color=\"blue\", edgecolor=\"black\", alpha=0.7)\n",
"plt.title(\"Histogram of Token Counts\")\n",
"plt.xlabel(\"Token Count\")\n",
"plt.ylabel(\"Frequency\")\n",
"plt.grid(axis=\"y\", alpha=0.75)\n",
"\n",
"# Display the histogram. show() must be called; a bare `plt.show` only echoes the function object.\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "70750603-ec82-4439-9b32-d22014b5ff2c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num tokens in all context: 68705\n"
]
}
],
"source": [
"# Concatenate every page (sorted by source URL, then reversed) into one long string\n",
"d_sorted = sorted(docs, key=lambda doc: doc.metadata[\"source\"])\n",
"d_reversed = d_sorted[::-1]\n",
"concatenated_content = \"\\n\\n\\n --- \\n\\n\\n\".join(doc.page_content for doc in d_reversed)\n",
"print(f\"Num tokens in all context: {num_tokens_from_string(concatenated_content, 'cl100k_base')}\")"
]
},
{
"cell_type": "code",
"execution_count": 155,
"id": "25ca3cf2-0f6b-40f9-a2ff-285a8dcb33dc",
"metadata": {},
"outputs": [],
"source": [
"# Doc texts split\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"\n",
"# Optional chunk-level leaves: ~2k-token pieces of the concatenated docs.\n",
"# NOTE: the tree below is built with `leaf_texts = docs_texts` (full pages), so\n",
"# `texts_split` is only used if you switch the leaves to chunks, as in the paper.\n",
"chunk_size_tok = 2000\n",
"text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
"    chunk_size=chunk_size_tok,\n",
"    chunk_overlap=0,\n",
")\n",
"texts_split = text_splitter.split_text(concatenated_content)"
]
},
{
"cell_type": "markdown",
"id": "797a5469-0942-45a5-adb6-f12e05d76798",
"metadata": {},
"source": [
"## Models\n",
"\n",
"We can test various models, including the new [Claude3](https://www.anthropic.com/news/claude-3-family) family.\n",
"\n",
"Be sure to set the relevant API keys:\n",
"\n",
"* `ANTHROPIC_API_KEY`\n",
"* `OPENAI_API_KEY`"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "033e71d3-5dc8-42a3-a0b7-4df116048c14",
"metadata": {},
"outputs": [],
"source": [
"from langchain_anthropic import ChatAnthropic\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"# Embedding model used for clustering and for the vector store\n",
"embd = OpenAIEmbeddings()\n",
"\n",
"# Chat model used for the cluster summaries and the RAG answer.\n",
"# To use OpenAI instead:\n",
"#   from langchain_openai import ChatOpenAI\n",
"#   model = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
"model = ChatAnthropic(temperature=0, model=\"claude-3-opus-20240229\")"
]
},
{
"cell_type": "markdown",
"id": "5c63db01-cf95-4c17-ae5d-8dc7267ad58a",
"metadata": {},
"source": [
"### Tree Construction\n",
"\n",
"The clustering approach in tree construction includes a few interesting ideas.\n",
"\n",
"**GMM (Gaussian Mixture Model)** \n",
"\n",
"- Model the distribution of data points across different clusters\n",
"- Optimal number of clusters by evaluating the model's Bayesian Information Criterion (BIC)\n",
"\n",
"**UMAP (Uniform Manifold Approximation and Projection)** \n",
"\n",
"- Supports clustering\n",
"- Reduces the dimensionality of high-dimensional data\n",
"- UMAP helps to highlight the natural grouping of data points based on their similarities\n",
"\n",
"**Local and Global Clustering** \n",
"\n",
"- Used to analyze data at different scales\n",
"- Both fine-grained and broader patterns within the data are captured effectively\n",
"\n",
"**Thresholding** \n",
"\n",
"- Apply in the context of GMM to determine cluster membership\n",
"- Based on the probability distribution (assignment of data points to ≥ 1 cluster)\n",
"---\n",
"\n",
"Code for GMM and thresholding is from Sarthi et al., as noted in the two sources below:\n",
" \n",
"* [Original repo](https://github.com/parthsarthi03/raptor/blob/master/raptor/cluster_tree_builder.py)\n",
"* [Minor tweaks](https://github.com/run-llama/llama_index/blob/main/llama-index-packs/llama-index-packs-raptor/llama_index/packs/raptor/clustering.py)\n",
"\n",
"Full credit to both authors."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a849980c-27d4-48e0-87a0-c2a5143cb8c0",
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, List, Optional, Tuple\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import umap\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from sklearn.mixture import GaussianMixture\n",
"\n",
"RANDOM_SEED = 224 # Fixed seed for reproducibility\n",
"\n",
"### --- Code from citations referenced above (added comments and docstrings) --- ###\n",
"\n",
"\n",
"def global_cluster_embeddings(\n",
"    embeddings: np.ndarray,\n",
"    dim: int,\n",
"    n_neighbors: Optional[int] = None,\n",
"    metric: str = \"cosine\",\n",
"    random_state: Optional[int] = None,\n",
") -> np.ndarray:\n",
"    \"\"\"\n",
"    Perform global dimensionality reduction on the embeddings using UMAP.\n",
"\n",
"    Parameters:\n",
"    - embeddings: The input embeddings as a numpy array.\n",
"    - dim: The target dimensionality for the reduced space.\n",
"    - n_neighbors: Optional; the number of neighbors to consider for each point.\n",
"      If not provided, it defaults to int(sqrt(len(embeddings) - 1)).\n",
"    - metric: The distance metric to use for UMAP.\n",
"    - random_state: Optional seed forwarded to UMAP. None (the default) leaves UMAP\n",
"      unseeded; pass e.g. RANDOM_SEED for reproducible projections (UMAP trades\n",
"      some parallelism for determinism when seeded).\n",
"\n",
"    Returns:\n",
"    - A numpy array of the embeddings reduced to the specified dimensionality.\n",
"    \"\"\"\n",
"    if n_neighbors is None:\n",
"        # sqrt(n - 1) heuristic for the size of the global neighborhood\n",
"        n_neighbors = int((len(embeddings) - 1) ** 0.5)\n",
"    return umap.UMAP(\n",
"        n_neighbors=n_neighbors,\n",
"        n_components=dim,\n",
"        metric=metric,\n",
"        random_state=random_state,\n",
"    ).fit_transform(embeddings)\n",
"\n",
"\n",
"def local_cluster_embeddings(\n",
"    embeddings: np.ndarray,\n",
"    dim: int,\n",
"    num_neighbors: int = 10,\n",
"    metric: str = \"cosine\",\n",
"    random_state: Optional[int] = None,\n",
") -> np.ndarray:\n",
"    \"\"\"\n",
"    Perform local dimensionality reduction on the embeddings using UMAP, typically after global clustering.\n",
"\n",
"    Parameters:\n",
"    - embeddings: The input embeddings as a numpy array.\n",
"    - dim: The target dimensionality for the reduced space.\n",
"    - num_neighbors: The number of neighbors to consider for each point.\n",
"    - metric: The distance metric to use for UMAP.\n",
"    - random_state: Optional seed forwarded to UMAP; None (the default) leaves UMAP unseeded.\n",
"\n",
"    Returns:\n",
"    - A numpy array of the embeddings reduced to the specified dimensionality.\n",
"    \"\"\"\n",
"    return umap.UMAP(\n",
"        n_neighbors=num_neighbors,\n",
"        n_components=dim,\n",
"        metric=metric,\n",
"        random_state=random_state,\n",
"    ).fit_transform(embeddings)\n",
"\n",
"\n",
"def get_optimal_clusters(\n",
"    embeddings: np.ndarray, max_clusters: int = 50, random_state: int = RANDOM_SEED\n",
") -> int:\n",
"    \"\"\"\n",
"    Determine the optimal number of clusters using the Bayesian Information Criterion (BIC) with a Gaussian Mixture Model.\n",
"\n",
"    Candidate counts are 1 .. min(max_clusters, len(embeddings)) - 1 (the upper bound is\n",
"    exclusive); a GMM is fit for each and the count with the lowest BIC is returned.\n",
"\n",
"    Parameters:\n",
"    - embeddings: The input embeddings as a numpy array.\n",
"    - max_clusters: The maximum number of clusters to consider.\n",
"    - random_state: Seed for reproducibility.\n",
"\n",
"    Returns:\n",
"    - The optimal number of clusters as a Python int; 1 when there are fewer than two\n",
"      embeddings (or max_clusters <= 1), since no candidate count can be evaluated.\n",
"    \"\"\"\n",
"    max_clusters = min(max_clusters, len(embeddings))\n",
"    if max_clusters <= 1:\n",
"        # np.arange(1, max_clusters) would be empty and np.argmin([]) would raise.\n",
"        return 1\n",
"    n_clusters = np.arange(1, max_clusters)\n",
"    bics = []\n",
"    for n in n_clusters:\n",
"        gm = GaussianMixture(n_components=n, random_state=random_state)\n",
"        gm.fit(embeddings)\n",
"        bics.append(gm.bic(embeddings))\n",
"    # Cast from a numpy integer so the result matches the `int` annotation.\n",
"    return int(n_clusters[np.argmin(bics)])\n",
"\n",
"\n",
"def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):\n",
"    \"\"\"\n",
"    Soft-cluster embeddings with a Gaussian Mixture Model.\n",
"\n",
"    The component count comes from `get_optimal_clusters` (a BIC sweep). Each embedding is\n",
"    then assigned to every component whose posterior probability exceeds `threshold`,\n",
"    so an embedding may belong to zero, one, or several clusters.\n",
"\n",
"    Parameters:\n",
"    - embeddings: The input embeddings as a numpy array.\n",
"    - threshold: Strict lower bound on the membership probability.\n",
"    - random_state: Seed for the final GMM fit. NOTE(review): it is not forwarded to\n",
"      get_optimal_clusters, whose BIC sweep uses that function's own default seed.\n",
"\n",
"    Returns:\n",
"    - (labels, n_clusters): labels[k] is an integer array of the cluster ids assigned to\n",
"      embedding k; n_clusters is the number of GMM components.\n",
"    \"\"\"\n",
"    n_clusters = get_optimal_clusters(embeddings)\n",
"    mixture = GaussianMixture(n_components=n_clusters, random_state=random_state)\n",
"    probabilities = mixture.fit(embeddings).predict_proba(embeddings)\n",
"    labels = [np.flatnonzero(row > threshold) for row in probabilities]\n",
"    return labels, n_clusters\n",
"\n",
"\n",
"def perform_clustering(\n",
"    embeddings: np.ndarray,\n",
"    dim: int,\n",
"    threshold: float,\n",
") -> List[np.ndarray]:\n",
"    \"\"\"\n",
"    Perform clustering on the embeddings by first reducing their dimensionality globally, then clustering\n",
"    using a Gaussian Mixture Model, and finally performing local clustering within each global cluster.\n",
"\n",
"    Parameters:\n",
"    - embeddings: The input embeddings as a numpy array (one row per text).\n",
"    - dim: The target dimensionality for UMAP reduction.\n",
"    - threshold: The probability threshold for assigning an embedding to a cluster in GMM.\n",
"\n",
"    Returns:\n",
"    - A list with one numpy array per embedding, holding the local cluster IDs it belongs to.\n",
"      IDs are made unique across global clusters by offsetting with the number of local\n",
"      clusters already emitted (and are floats, since they are appended to an empty float\n",
"      array). With too few embeddings to cluster, every embedding gets cluster 0.\n",
"    \"\"\"\n",
"    if len(embeddings) <= dim + 1:\n",
"        # Avoid clustering when there's insufficient data\n",
"        return [np.array([0]) for _ in range(len(embeddings))]\n",
"\n",
"    # Global dimensionality reduction\n",
"    reduced_embeddings_global = global_cluster_embeddings(embeddings, dim)\n",
"    # Global clustering\n",
"    global_clusters, n_global_clusters = GMM_cluster(\n",
"        reduced_embeddings_global, threshold\n",
"    )\n",
"\n",
"    all_local_clusters = [np.array([]) for _ in range(len(embeddings))]\n",
"    total_clusters = 0\n",
"\n",
"    # Iterate through each global cluster to perform local clustering\n",
"    for i in range(n_global_clusters):\n",
"        # Row indices (into `embeddings`) of the members of global cluster i. Keeping the\n",
"        # indices avoids recovering them later by comparing embedding values, which needs a\n",
"        # (members x n x d) boolean array and double-assigns rows whose embeddings are\n",
"        # identical (e.g. duplicate pages).\n",
"        global_indices = np.flatnonzero([i in gc for gc in global_clusters])\n",
"        if len(global_indices) == 0:\n",
"            continue\n",
"        global_cluster_embeddings_ = embeddings[global_indices]\n",
"\n",
"        if len(global_cluster_embeddings_) <= dim + 1:\n",
"            # Handle small clusters with direct assignment\n",
"            local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]\n",
"            n_local_clusters = 1\n",
"        else:\n",
"            # Local dimensionality reduction and clustering\n",
"            reduced_embeddings_local = local_cluster_embeddings(\n",
"                global_cluster_embeddings_, dim\n",
"            )\n",
"            local_clusters, n_local_clusters = GMM_cluster(\n",
"                reduced_embeddings_local, threshold\n",
"            )\n",
"\n",
"        # Assign local cluster IDs, adjusting for total clusters already processed\n",
"        for j in range(n_local_clusters):\n",
"            member_mask = np.array([j in lc for lc in local_clusters], dtype=bool)\n",
"            for idx in global_indices[member_mask]:\n",
"                all_local_clusters[idx] = np.append(\n",
"                    all_local_clusters[idx], j + total_clusters\n",
"                )\n",
"\n",
"        total_clusters += n_local_clusters\n",
"\n",
"    return all_local_clusters\n",
"\n",
"\n",
"### --- Our code below --- ###\n",
"\n",
"\n",
"def embed(texts):\n",
"    \"\"\"\n",
"    Embed a list of texts with the notebook-level `embd` embeddings model.\n",
"\n",
"    Parameters:\n",
"    - texts: List[str], the documents to embed.\n",
"\n",
"    Returns:\n",
"    - numpy.ndarray built from the vectors returned by `embd.embed_documents`.\n",
"    \"\"\"\n",
"    return np.array(embd.embed_documents(texts))\n",
"\n",
"\n",
"def embed_cluster_texts(texts, dim: int = 10, threshold: float = 0.1):\n",
"    \"\"\"\n",
"    Embed a list of texts and cluster them, returning a DataFrame with texts, embeddings, and cluster labels.\n",
"\n",
"    Parameters:\n",
"    - texts: List[str], a list of text documents to be processed.\n",
"    - dim: UMAP target dimensionality passed to `perform_clustering` (default 10).\n",
"    - threshold: GMM membership probability threshold passed to `perform_clustering`\n",
"      (default 0.1).\n",
"\n",
"    Returns:\n",
"    - pandas.DataFrame with columns 'text', 'embd' (one embedding vector per row) and\n",
"      'cluster' (one array of cluster ids per row).\n",
"    \"\"\"\n",
"    text_embeddings_np = embed(texts)  # Generate embeddings\n",
"    cluster_labels = perform_clustering(text_embeddings_np, dim, threshold)\n",
"    df = pd.DataFrame()  # Initialize a DataFrame to store the results\n",
"    df[\"text\"] = texts  # Store original texts\n",
"    df[\"embd\"] = list(text_embeddings_np)  # Store embeddings as a list in the DataFrame\n",
"    df[\"cluster\"] = cluster_labels  # Store cluster labels\n",
"    return df\n",
"\n",
"\n",
"def fmt_txt(df: pd.DataFrame) -> str:\n",
"    \"\"\"\n",
"    Join the 'text' column of a DataFrame into one string for summarization.\n",
"\n",
"    Texts are concatenated in row order with a fixed `--- ---` separator; rows are not\n",
"    deduplicated.\n",
"\n",
"    Parameters:\n",
"    - df: DataFrame with a 'text' column.\n",
"\n",
"    Returns:\n",
"    - The joined string.\n",
"    \"\"\"\n",
"    separator = \"--- --- \\n --- --- \"\n",
"    return separator.join(df[\"text\"].tolist())\n",
"\n",
"\n",
"def embed_cluster_summarize_texts(\n",
"    texts: List[str], level: int\n",
") -> Tuple[pd.DataFrame, pd.DataFrame]:\n",
"    \"\"\"\n",
"    Embed, cluster, and summarize one level of texts.\n",
"\n",
"    Each text is embedded and soft-clustered (a text may belong to several clusters).\n",
"    The texts of every cluster are then joined and summarized by the notebook-level `model`.\n",
"\n",
"    Parameters:\n",
"    - texts: A list of text documents to be processed.\n",
"    - level: Tree level recorded in the 'level' column of the summary DataFrame.\n",
"\n",
"    Returns:\n",
"    - (df_clusters, df_summary):\n",
"      1. df_clusters: columns 'text', 'embd', 'cluster' (one row per input text).\n",
"      2. df_summary: columns 'summaries', 'level', 'cluster' (one row per cluster).\n",
"    \"\"\"\n",
"    # Embed and cluster the texts\n",
"    df_clusters = embed_cluster_texts(texts)\n",
"\n",
"    # One row per (text, cluster) membership\n",
"    expanded_df = pd.DataFrame(\n",
"        [\n",
"            {\"text\": row[\"text\"], \"embd\": row[\"embd\"], \"cluster\": cluster}\n",
"            for _, row in df_clusters.iterrows()\n",
"            for cluster in row[\"cluster\"]\n",
"        ]\n",
"    )\n",
"\n",
"    # Unique cluster ids, in order of first appearance\n",
"    all_clusters = expanded_df[\"cluster\"].unique()\n",
"    print(f\"--Generated {len(all_clusters)} clusters--\")\n",
"\n",
"    # Summarization chain\n",
"    template = \"\"\"Here is a sub-set of LangChain Expression Language doc. \n",
" \n",
" LangChain Expression Language provides a way to compose chain in LangChain.\n",
" \n",
" Give a detailed summary of the documentation provided.\n",
" \n",
" Documentation:\n",
" {context}\n",
" \"\"\"\n",
"    prompt = ChatPromptTemplate.from_template(template)\n",
"    chain = prompt | model | StrOutputParser()\n",
"\n",
"    # Summarize the joined texts of each cluster, one LLM call per cluster\n",
"    summaries = [\n",
"        chain.invoke({\"context\": fmt_txt(expanded_df[expanded_df[\"cluster\"] == cluster_id])})\n",
"        for cluster_id in all_clusters\n",
"    ]\n",
"\n",
"    df_summary = pd.DataFrame(\n",
"        {\n",
"            \"summaries\": summaries,\n",
"            \"level\": [level] * len(summaries),\n",
"            \"cluster\": list(all_clusters),\n",
"        }\n",
"    )\n",
"\n",
"    return df_clusters, df_summary\n",
"\n",
"\n",
"def recursive_embed_cluster_summarize(\n",
"    texts: List[str], level: int = 1, n_levels: int = 3\n",
") -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:\n",
"    \"\"\"\n",
"    Build the tree bottom-up by embedding, clustering and summarizing `texts`, then\n",
"    recursing on the summaries. Recursion stops at `n_levels` or when a level\n",
"    produces a single cluster.\n",
"\n",
"    Parameters:\n",
"    - texts: List[str], texts to be processed at this level.\n",
"    - level: int, current recursion level (starts at 1).\n",
"    - n_levels: int, maximum depth of recursion.\n",
"\n",
"    Returns:\n",
"    - Dict mapping each processed level to its (clusters DataFrame, summaries DataFrame) pair.\n",
"    \"\"\"\n",
"    df_clusters, df_summary = embed_cluster_summarize_texts(texts, level)\n",
"    results = {level: (df_clusters, df_summary)}\n",
"\n",
"    # Stop at the depth limit, or once everything has collapsed into one cluster\n",
"    if level >= n_levels or df_summary[\"cluster\"].nunique() <= 1:\n",
"        return results\n",
"\n",
"    # The summaries become the input texts of the next, more abstract level\n",
"    results.update(\n",
"        recursive_embed_cluster_summarize(\n",
"            df_summary[\"summaries\"].tolist(), level + 1, n_levels\n",
"        )\n",
"    )\n",
"    return results"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f0d8cd3e-cd49-484d-9617-1b9811cc08b3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--Generated 7 clusters--\n",
"--Generated 1 clusters--\n"
]
}
],
"source": [
"# Build the tree using full documentation pages as leaves\n",
"# (set leaf_texts = texts_split for chunk-level leaves instead)\n",
"leaf_texts = docs_texts\n",
"results = recursive_embed_cluster_summarize(texts=leaf_texts, level=1, n_levels=3)"
]
},
{
"cell_type": "markdown",
"id": "e80d7098-5d16-4fa6-837c-968e5c9f118d",
"metadata": {},
"source": [
"The paper reports best performance from `collapsed tree retrieval`. \n",
"\n",
"This involves flattening the tree structure into a single layer and then applying a k-nearest neighbors (kNN) search across all nodes simultaneously. \n",
"\n",
"We do this below."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d28ba9e6-9124-41a8-b4fd-55a6ef4ac062",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores import Chroma\n",
"\n",
"# Collapsed tree: the leaves plus the summaries of every level, in level order\n",
"all_texts = list(leaf_texts)\n",
"for level in sorted(results):\n",
"    summaries = results[level][1][\"summaries\"].tolist()\n",
"    all_texts += summaries\n",
"\n",
"# Index every node of the flattened tree in a single Chroma collection\n",
"vectorstore = Chroma.from_texts(texts=all_texts, embedding=embd)\n",
"retriever = vectorstore.as_retriever()"
]
},
{
"cell_type": "markdown",
"id": "0d497627-44c6-41f7-bb63-1d858d3f188f",
"metadata": {},
"source": [
"Now we can use our flattened, indexed tree in a RAG chain."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9d6c894b-b3a3-4a01-b779-3e98ea382ff5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Here is a code example of how to define a RAG (Retrieval Augmented Generation) chain in LangChain:\\n\\n```python\\nfrom langchain.vectorstores import FAISS\\nfrom langchain.embeddings import OpenAIEmbeddings\\nfrom langchain.prompts import ChatPromptTemplate\\nfrom langchain.chat_models import ChatOpenAI\\nfrom langchain.output_parsers import StrOutputParser\\n\\n# Load documents into vector store\\nvectorstore = FAISS.from_texts(\\n [\"harrison worked at kensho\"], embedding=OpenAIEmbeddings()\\n)\\nretriever = vectorstore.as_retriever()\\n\\n# Define prompt template\\ntemplate = \"\"\"Answer the question based only on the following context:\\n{context}\\nQuestion: {question}\"\"\"\\nprompt = ChatPromptTemplate.from_template(template)\\n\\n# Define model and output parser\\nmodel = ChatOpenAI()\\noutput_parser = StrOutputParser()\\n\\n# Define RAG chain\\nchain = (\\n {\"context\": retriever, \"question\": RunnablePassthrough()}\\n | prompt\\n | model '"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain import hub\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"# RAG prompt pulled from the LangChain Hub\n",
"prompt = hub.pull(\"rlm/rag-prompt\")\n",
"\n",
"\n",
"def format_docs(docs):\n",
"    \"\"\"Join the page contents of the retrieved documents, separated by blank lines.\"\"\"\n",
"    return \"\\n\\n\".join(d.page_content for d in docs)\n",
"\n",
"\n",
"# Retrieve over the collapsed tree -> fill the prompt -> answer -> plain string\n",
"rag_chain = (\n",
"    {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
"    | prompt\n",
"    | model\n",
"    | StrOutputParser()\n",
")\n",
"\n",
"rag_chain.invoke(\"How to define a RAG chain? Give me a specific code example.\")"
]
},
{
"cell_type": "markdown",
"id": "0c585b37-ad83-4069-8f5d-4a6a3e15128d",
"metadata": {},
"source": [
"Trace: \n",
"\n",
"https://smith.langchain.com/public/1dabf475-1675-4494-b16c-928fbf079851/r"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}