mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
748 lines
210 KiB
Plaintext
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "3058e9ca-07c3-4eef-b98c-bc2f2dbb9cc6",
"metadata": {},
"outputs": [],
"source": [
"pip install -U langchain umap-learn scikit-learn langchain_community tiktoken langchain-openai langchainhub chromadb langchain-anthropic"
]
},
{
"attachments": {
"72039e0c-e8c4-4b17-8780-04ad9fc584f3.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA88AAAMDCAYAAACcoScMAAAMP2lDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBCCSAgJfQmCEgJICWEFkB6EWyEJEAoMQaCiB1dVHDtYgEbuiqi2AGxI3YWwd4XRRSUdbFgV96kgK77yvfO9829//3nzH/OnDu3DADqp7hicQ6qAUCuKF8SGxLAGJucwiB1AwTggAYIgMDl5YlZ0dERANrg+e/27ib0hnbNQab1z/7/app8QR4PACQa4jR+Hi8X4kMA4JU8sSQfAKKMN5+aL5Zh2IC2BCYI8UIZzlDgShlOU+B9cp/4WDbEzQCoqHG5kgwAaG2QZxTwMqAGrQ9iJxFfKAJAnQGxb27uZD7EqRDbQB8xxDJ9ZtoPOhl/00wb0uRyM4awYi5yUwkU5olzuNP+z3L8b8vNkQ7GsIJNLVMSGiubM6zb7ezJ4TKsBnGvKC0yCmItiD8I+XJ/iFFKpjQ0QeGPGvLy2LBmQBdiJz43MBxiQ4iDRTmREUo+LV0YzIEYrhC0UJjPiYdYD+KFgrygOKXPZsnkWGUstC5dwmYp+QtciTyuLNZDaXYCS6n/OlPAUepjtKLM+CSIKRBbFAgTIyGmQeyYlx0XrvQZXZTJjhz0kUhjZflbQBwrEIUEKPSxgnRJcKzSvzQ3b3C+2OZMISdSiQ/kZ8aHKuqDNfO48vzhXLA2gYiVMKgjyBsbMTgXviAwSDF3rFsgSohT6nwQ5wfEKsbiFHFOtNIfNxPkhMh4M4hd8wrilGPxxHy4IBX6eLo4PzpekSdelMUNi1bkgy8DEYANAgEDSGFLA5NBFhC29tb3witFTzDgAgnIAALgoGQGRyTJe0TwGAeKwJ8QCUDe0LgAea8AFED+6xCrODqAdHlvgXxENngKcS4IBznwWiofJRqKlgieQEb4j+hc2Hgw3xzYZP3/nh9kvzMsyEQoGelgRIb6oCcxiBhIDCUGE21xA9wX98Yj4NEfNheciXsOzuO7P+EpoZ3wmHCD0EG4M0lYLPkpyzGgA+oHK2uR9mMtcCuo6YYH4D5QHSrjurgBcMBdYRwW7gcju0GWrcxbVhXGT9p/m8EPd0PpR3Yio+RhZH+yzc8jaXY0tyEVWa1/rI8i17SherOHen6Oz/6h+nx4Dv/ZE1uIHcTOY6exi9gxrB4wsJNYA9aCHZfhodX1RL66BqPFyvPJhjrCf8QbvLOySuY51Tj1OH1R9OULCmXvaMCeLJ4mEWZk5jNY8IsgYHBEPMcRDBcnF1cAZN8XxevrTYz8u4Hotnzn5v0BgM/JgYGBo9+5sJMA7PeAj/+R75wNE346VAG4cIQnlRQoOFx2IMC3hDp80vSBMTAHNnA+LsAdeAN/EATCQBSIB8lgIsw+E65zCZgKZoC5oASUgWVgNVgPNoGtYCfYAw6AenAMnAbnwGXQBm6Ae3D1dIEXoA+8A58RBCEhVISO6CMmiCVij7ggTMQXCUIikFgkGUlFMhARIkVmIPOQMmQFsh7ZglQj+5EjyGnkItKO3EEeIT3Ia+QTiqFqqDZqhFqhI1EmykLD0Xh0ApqBTkGL0PnoEnQtWoXuRuvQ0+hl9Abagb5A+zGAqWK6mCnmgDExNhaFpWDpmASbhZVi5VgVVos1wvt8DevAerGPOBGn4wzcAa7gUDwB5+FT8Fn4Ynw9vhOvw5vxa/gjvA//RqASDAn2BC8ChzCWkEGYSighlBO2Ew4TzsJnqYvwjkgk6hKtiR7wWUwmZhGnExcTNxD3Ek8R24mdxH4SiaRPsif5kKJIXFI+qYS0jrSbdJJ0ldRF+qCiqmKi4qISrJKiIlIpVilX2aVyQuWqyjOVz2QNsiXZixxF5pOnkZeSt5EbyVfIXeTPFE2KNcWHEk/JosylrKXUUs5S7lPeqKqqmql6qsaoClXnqK5V3ad6QfWR6kc1LTU7NbbaeDWp2hK1HWqn1O6ovaFSqVZUf2
oKNZ+6hFpNPUN9SP1Ao9McaRwanzabVkGro12lvVQnq1uqs9Qnqhepl6sfVL+i3qtB1rDSYGtwNWZpVGgc0bil0a9J13TWjNLM1VysuUvzoma3FknLSitIi681X2ur1hmtTjpGN6ez6Tz6PPo2+ll6lzZR21qbo52lXaa9R7tVu09HS8dVJ1GnUKdC57hOhy6ma6XL0c3RXap7QPem7qdhRsNYwwTDFg2rHXZ12Hu94Xr+egK9Ur29ejf0Pukz9IP0s/WX69frPzDADewMYgymGmw0OGvQO1x7uPdw3vDS4QeG3zVEDe0MYw2nG241bDHsNzI2CjESG60zOmPUa6xr7G+cZbzK+IRxjwndxNdEaLLK5KTJc4YOg8XIYaxlNDP6TA1NQ02lpltMW00/m1mbJZgVm+01e2BOMWeap5uvMm8y77MwsRhjMcOixuKuJdmSaZlpucbyvOV7K2urJKsFVvVW3dZ61hzrIusa6/s2VBs/myk2VTbXbYm2TNts2w22bXaonZtdpl2F3RV71N7dXmi/wb59BGGE5wjRiKoRtxzUHFgOBQ41Do8cdR0jHIsd6x1fjrQYmTJy+cjzI785uTnlOG1zuues5RzmXOzc6Pzaxc6F51Lhcn0UdVTwqNmjGka9crV3FbhudL3tRncb47bArcntq7uHu8S91r3Hw8Ij1aPS4xZTmxnNXMy84EnwDPCc7XnM86OXu1e+1wGvv7wdvLO9d3l3j7YeLRi9bXSnj5kP12eLT4cvwzfVd7Nvh5+pH9evyu+xv7k/33+7/zOWLSuLtZv1MsApQBJwOOA924s9k30qEAsMCSwNbA3SCkoIWh/0MNgsOCO4JrgvxC1kesipUEJoeOjy0FscIw6PU83pC/MImxnWHK4WHhe+PvxxhF2EJKJxDDombMzKMfcjLSNFkfVRIIoTtTLqQbR19JToozHEmOiYipinsc6xM2LPx9HjJsXtinsXHxC/NP5egk2CNKEpUT1xfGJ14vukwKQVSR1jR46dOfZyskGyMLkhhZSSmLI9pX9c0LjV47rGu40vGX9zgvWEwgkXJxpMzJl4fJL6JO6kg6mE1KTUXalfuFHcKm5/GietMq2Px+at4b3g+/NX8XsEPoIVgmfpPukr0rszfDJWZvRk+mWWZ/YK2cL1wldZoVmbst5nR2XvyB7IScrZm6uSm5p7RKQlyhY1TzaeXDi5XWwvLhF3TPGasnpKnyRcsj0PyZuQ15CvDX/kW6Q20l+kjwp8CyoKPkxNnHqwULNQVNgyzW7aomnPioKLfpuOT+dNb5phOmPujEczWTO3zEJmpc1qmm0+e/7srjkhc3bOpczNnvt7sVPxiuK385LmNc43mj9nfucvIb/UlNBKJCW3Fngv2LQQXyhc2Lpo1KJ1i76V8ksvlTmVlZd9WcxbfOlX51/X/jqwJH1J61L3pRuXEZeJlt1c7rd85wrNFUUrOleOWVm3irGqdNXb1ZNWXyx3Ld+0hrJGuqZjbcTahnUW65at+7I+c/2NioCKvZWGlYsq32/gb7i60X9j7SajTWWbPm0Wbr69JWRLXZVVVflW4taCrU+3JW47/xvzt+rtBtvLtn/dIdrRsTN2Z3O1R3X1LsNdS2vQGmlNz+7xu9v2BO5pqHWo3bJXd2/ZPrBPuu/5/tT9Nw+EH2g6yDxYe8jyUOVh+uHSOqRuWl1ffWZ9R0NyQ/uRsCNNjd6Nh486Ht1xzPRYxXGd40tPUE7MPzFwsuhk/ynxqd7TGac7myY13Tsz9sz15pjm1rPhZy+cCz535jzr/MkLPheOXfS6eOQS81L9ZffLdS1uLYd/d/v9cKt7a90VjysNbZ5tje2j209c9bt6+lrgtXPXOdcv34i80X4z4ebtW+Nvddzm3+6+k3Pn1d2Cu5/vzblPuF/6QONB+UPDh1V/2P6xt8O94/ijwEctj+Me3+vkdb54kvfkS9f8p9Sn5c9MnlV3u3Qf6wnuaXs+7nnXC/GLz70lf2
r+WfnS5uWhv/z/aukb29f1SvJq4PXiN/pvdrx1fdvUH93/8F3uu8/vSz/of9j5kfnx/KekT88+T/1C+rL2q+3Xxm/h
}
},
"cell_type": "markdown",
"id": "ea54c848-0df6-474e-b266-218a2acf67d3",
"metadata": {},
"source": [
"# RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval\n",
"\n",
"The [RAPTOR](https://arxiv.org/pdf/2401.18059.pdf) paper presents an interesting approach to indexing and retrieval of documents:\n",
"\n",
"* The `leafs` are a set of starting documents\n",
"* Leafs are embedded and clustered\n",
"* Clusters are then summarized into higher-level (more abstract) consolidations of information across similar documents\n",
"\n",
"This process is done recursively, resulting in a \"tree\" going from raw docs (`leafs`) to more abstract summaries.\n",
" \n",
"We can apply this at varying scales; `leafs` can be:\n",
"\n",
"* Text chunks from a single doc (as shown in the paper)\n",
"* Full docs (as we show below)\n",
"\n",
"With longer-context LLMs, it's possible to perform this over full documents. \n",
"\n",
"![Screenshot 2024-03-04 at 12.45.25 PM.png](attachment:72039e0c-e8c4-4b17-8780-04ad9fc584f3.png)"
]
},
{
"cell_type": "markdown",
"id": "083dd961-b401-4fc6-867c-8f8950059b02",
"metadata": {},
"source": [
"### Docs\n",
"\n",
"Let's apply this to LangChain's LCEL documentation.\n",
"\n",
"In this case, each `doc` is a unique web page of the LCEL docs.\n",
"\n",
"The context varies from < 2k tokens up to > 10k tokens."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b17c1331-373f-491d-8b53-ccf634e68c8e",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAIjCAYAAADWYVDIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAABGN0lEQVR4nO3deVxUZf//8ffAwLDIoriSqLjnkpZmmZaZFqm5tWqWaLZbaprZcleSmZZlVpbl3Z3LnWXZ3farNNcy01RUXAs1DS23cmEEERnm+v3hg/meEVTEgQF8PR8PHndznetc53PO5T3O23PmwmaMMQIAAAAASJIC/F0AAAAAAJQmhCQAAAAAsCAkAQAAAIAFIQkAAAAALAhJAAAAAGBBSAIAAAAAC0ISAAAAAFgQkgAAAADAgpAEAAAAABaEJAAoAXXq1NGAAQP8XUa5N2HCBNWtW1eBgYFq2bJlsR7rhx9+kM1m02effVasxwEAlDxCEgCco+nTp8tmsyk5ObnA7ddee62aNWt23sf57rvvNHr06PMe50Ixf/58PfHEE2rXrp2mTZuml156KV+fvGBTmJ+y6Pjx43r99dd1xRVXKCoqSiEhIWrYsKEeeeQRbd261d/lSZKWL1+u0aNH68iRI/4uBQBOy+7vAgDgQpCamqqAgHP7d6nvvvtOb7/9NkGpkBYvXqyAgAD95z//UXBwcIF9Lr74Yv33v//1anvqqadUoUIFPfPMMyVRZrH5559/dOONN2rNmjW66aabdOedd6pChQpKTU3V7NmzNXXqVJ04ccLfZWr58uVKSkrSgAEDFB0d7e9yAKBAhCQAKAEOh8PfJZyzzMxMhYeH+7uMQjtw4IBCQ0NPG5AkqVq1arrrrru82saPH6/KlSvnay9rBgwYoHXr1umzzz7TLbfc4rVtzJgxZT4EAkBJ4nE7ACgBp34nKScnR0lJSWrQoIFCQkIUExOj9u3ba8GCBZJOfuB9++23JanAR8AyMzM1YsQIxcXFyeFwqFGjRnr11VdljPE6blZWloYMGaLKlSsrIiJCPXr00F9//SWbzeZ1h2r06NGy2WzasmWL7rzzTlWsWFHt27eXJG3YsEEDBgxQ3bp1FRISourVq+uee+7RwYMHvY6VN8bWrVt11113KSoqSlWqVNGzzz4rY4x2796tnj17KjIyUtWrV9drr71WqGvncrk0ZswY1atXTw6HQ3Xq1NHTTz+t7OxsTx+bzaZp06YpMzPTc62mT59eqPELsmPHDt12222qVKmSwsLCdOWVV+rbb789637Z2dm66aabFBUVpeXLl0uS3G63Jk2apKZNmyokJETVqlXTAw88oMOHD3vtW6dOHd10001atmyZ2rRpo5CQENWtW1czZ84863FXrlypb7/9VoMGDcoXkKSTIf3VV1/1alu8eLGuvvpqhYeHKzo6Wj179tSvv/7q1WfAgAGqU6dOvvHy5trKZrPpkUce0ZdffqlmzZrJ4XCoadOmmjdvntd+I0eOlCTFx8d75uqPP/6QJC1YsEDt27dXdHS0KlSooEaNGunpp58+6/kDgK9xJwkAiig9PV3//PNPvvacnJyz7jt69GiNGzdO9957r9q0aSOn06nk5GStXbtW119/vR544AHt2bNHCxYsyPd4mDFGPXr00JIlSzRo0CC1bNlS33//vUaOHKm//vpLr7/+uqfvgAED9Omnn+ruu+/WlVdeqR9//FHdunU7bV233XabGjRooJdeeskTuBYsWKAdO3Zo4MCBql69ujZv3qypU6dq8+bN+uWXX/J9WL7jjjt08cUXa/z48fr222/14osvqlKlSnrvvfd03XXX6eWXX9asWbP0+OOP6/LLL9c111xzxmt17733asaMGbr11ls1YsQIrVy5UuPGjdOvv/6qL774QpL03//+V1OnTtWqVav0/vvvS5Kuuuqqs85DQfbv36+rrrpKx44d05AhQxQTE6MZM2aoR48e+uyzz9
S7d+8C98vKylLPnj2VnJyshQsX6vLLL5ckPfDAA5o+fboGDhyoIUOGaOfOnZo8ebLWrVunn3/+WUFBQZ4xtm/frltvvVWDBg1SYmKiPvjgAw0YMECtWrVS06ZNT1vz119/LUm6++67C3WOCxcuVJcuXVS3bl2NHj1aWVlZeuutt9SuXTutXbu2wGBUGMuWLdPnn3+uhx9+WBEREXrzzTd1yy23aNeuXYqJidHNN9+srVu36uOPP9brr7+uypUrS5KqVKmizZs366abbtIll1yiF154QQ6HQ9u3b9fPP/9cpFoA4LwYAMA5mTZtmpF0xp+mTZt67VO7dm2TmJjoed2iRQvTrVu3Mx5n8ODBpqC36S+//NJIMi+++KJX+6233mpsNpvZvn27McaYNWvWGElm2LBhXv0GDBhgJJnnn3/e0/b8888bSaZv3775jnfs2LF8bR9//LGRZJYuXZpvjPvvv9/T5nK5TM2aNY3NZjPjx4/3tB8+fNiEhoZ6XZOCpKSkGEnm3nvv9Wp//PHHjSSzePFiT1tiYqIJDw8/43gFadq0qenQoYPn9bBhw4wk89NPP3najh49auLj402dOnVMbm6uMcaYJUuWGElmzpw55ujRo6ZDhw6mcuXKZt26dZ79fvrpJyPJzJo1y+uY8+bNy9deu3btfNf0wIEDxuFwmBEjRpzxHHr37m0kmcOHDxfqnFu2bGmqVq1qDh486Glbv369CQgIMP379/e0JSYmmtq1a+fbP2+urSSZ4OBgz5+/vDElmbfeesvTNmHCBCPJ7Ny502v/119/3Ugyf//9d6HOAQCKE4/bAUARvf3221qwYEG+n0suueSs+0ZHR2vz5s3atm3bOR/3u+++U2BgoIYMGeLVPmLECBljNHfuXEnyPOb08MMPe/V79NFHTzv2gw8+mK8tNDTU89/Hjx/XP//8oyuvvFKStHbt2nz97733Xs9/BwYGqnXr1jLGaNCgQZ726OhoNWrUSDt27DhtLdLJc5Wk4cOHe7WPGDFCkgr1CNy5+u6779SmTRvP44aSVKFCBd1///36448/tGXLFq/+6enpuuGGG/Tbb7/phx9+8Fp6fM6cOYqKitL111+vf/75x/PTqlUrVahQQUuWLPEaq0mTJrr66qs9r6tUqVKo6+R0OiVJERERZz2/vXv3KiUlRQMGDFClSpU87Zdccomuv/56zzUvis6dO6tevXpeY0ZGRp61fkmeRRy++uorud3uItcAAL5ASAKAImrTpo06d+6c76dixYpn3feFF17QkSNH1LBhQzVv3lwjR47Uhg0bCnXctLQ0xcbG5vtAfPHFF3u25/1vQECA4uPjvfrVr1//tGOf2leSDh06pKFDh6patWoKDQ1VlSpVPP3S09Pz9a9Vq5bX67ylqPMerbK2n/q9nFPlncOpNVevXl3R0dGec/WltLQ0NWrUKF/7qdc3z7Bhw7R69WotXLgw3yNx27ZtU3p6uqpWraoqVap4/WRkZOjAgQNe/U+9dpJUsWLFs16nyMhISdLRo0cLdX6STnuO//zzjzIzM886TkGKWr908jHNdu3a6d5771W1atXUp08fffrppwQmAH7Bd5IAwA+uueYa/f777/rqq680f/58vf/++3r99df17rvvet2JKWnWu0Z5br/9di1fvlwjR45Uy5YtVaFCBbndbt14440FfoANDAwsVJukfAtNnE5p/r1FPXv21OzZszV+/HjNnDnTa6l3t9utqlWratasWQXuW6VKFa/XRb1OjRs3liRt3LjR607U+Trddc/NzS2w/XzmOTQ0VEuXLtWSJUv07bffat68efrkk0903XXXaf78+acdGwCKA3eSAMBPKlWqpIEDB+rjjz/W7t27dckll3itOHe6D6i1a9fWnj178t01+O233zzb8/7X7XZr586dXv22b99e6BoPHz6sRYsW6cknn1RSUpJ69+6t66+/XnXr1i30GOcj7xxOfSxx//79OnLkiOdcfX3M1N
TUfO2nXt88vXr10gcffKCPPvpIgwcP9tpWr149HTx4UO3atSvwrmOLFi18UnP37t0lSR9++OFZ++bVf7pzrFy5smf
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import tiktoken\n",
"from bs4 import BeautifulSoup as Soup\n",
"from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader\n",
"\n",
"\n",
"def num_tokens_from_string(string: str, encoding_name: str) -> int:\n",
"    \"\"\"Returns the number of tokens in a text string.\"\"\"\n",
"    encoding = tiktoken.get_encoding(encoding_name)\n",
"    num_tokens = len(encoding.encode(string))\n",
"    return num_tokens\n",
"\n",
"\n",
"# LCEL docs\n",
"url = \"https://python.langchain.com/docs/expression_language/\"\n",
"loader = RecursiveUrlLoader(\n",
"    url=url, max_depth=20, extractor=lambda x: Soup(x, \"html.parser\").text\n",
")\n",
"docs = loader.load()\n",
"\n",
"# LCEL w/ PydanticOutputParser (outside the primary LCEL docs)\n",
"url = \"https://python.langchain.com/docs/modules/model_io/output_parsers/quick_start\"\n",
"loader = RecursiveUrlLoader(\n",
"    url=url, max_depth=1, extractor=lambda x: Soup(x, \"html.parser\").text\n",
")\n",
"docs_pydantic = loader.load()\n",
"\n",
"# LCEL w/ Self Query (outside the primary LCEL docs)\n",
"url = \"https://python.langchain.com/docs/modules/data_connection/retrievers/self_query/\"\n",
"loader = RecursiveUrlLoader(\n",
"    url=url, max_depth=1, extractor=lambda x: Soup(x, \"html.parser\").text\n",
")\n",
"docs_sq = loader.load()\n",
"\n",
"# Doc texts\n",
"docs.extend([*docs_pydantic, *docs_sq])\n",
"docs_texts = [d.page_content for d in docs]\n",
"\n",
"# Calculate the number of tokens for each document\n",
"counts = [num_tokens_from_string(d, \"cl100k_base\") for d in docs_texts]\n",
"\n",
"# Plotting the histogram of token counts\n",
"plt.figure(figsize=(10, 6))\n",
"plt.hist(counts, bins=30, color=\"blue\", edgecolor=\"black\", alpha=0.7)\n",
"plt.title(\"Histogram of Token Counts\")\n",
"plt.xlabel(\"Token Count\")\n",
"plt.ylabel(\"Frequency\")\n",
"plt.grid(axis=\"y\", alpha=0.75)\n",
"\n",
"# Display the histogram\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "70750603-ec82-4439-9b32-d22014b5ff2c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num tokens in all context: 68705\n"
]
}
],
"source": [
"# Doc texts concat\n",
"d_sorted = sorted(docs, key=lambda x: x.metadata[\"source\"])\n",
"d_reversed = list(reversed(d_sorted))\n",
"concatenated_content = \"\\n\\n\\n --- \\n\\n\\n\".join(\n",
"    [doc.page_content for doc in d_reversed]\n",
")\n",
"print(\n",
"    \"Num tokens in all context: %s\"\n",
"    % num_tokens_from_string(concatenated_content, \"cl100k_base\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 155,
"id": "25ca3cf2-0f6b-40f9-a2ff-285a8dcb33dc",
"metadata": {},
"outputs": [],
"source": [
"# Doc texts split\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"\n",
"chunk_size_tok = 2000\n",
"text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
"    chunk_size=chunk_size_tok, chunk_overlap=0\n",
")\n",
"texts_split = text_splitter.split_text(concatenated_content)"
]
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "797a5469-0942-45a5-adb6-f12e05d76798",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Models\n",
|
||
|
"\n",
|
||
|
"We can test various models, including the new [Claude3](https://www.anthropic.com/news/claude-3-family) family.\n",
|
||
|
"\n",
|
||
|
"Be sure to set the relevant API keys:\n",
|
||
|
"\n",
|
||
|
"* `ANTHROPIC_API_KEY`\n",
|
||
|
"* `OPENAI_API_KEY`"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "033e71d3-5dc8-42a3-a0b7-4df116048c14",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from langchain_openai import OpenAIEmbeddings\n",
|
||
|
"\n",
|
||
|
"embd = OpenAIEmbeddings()\n",
|
||
|
"\n",
|
||
|
"# from langchain_openai import ChatOpenAI\n",
|
||
|
"\n",
|
||
|
"# model = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
|
||
|
"\n",
|
||
|
"from langchain_anthropic import ChatAnthropic\n",
|
||
|
"\n",
|
||
|
"model = ChatAnthropic(temperature=0, model=\"claude-3-opus-20240229\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "5c63db01-cf95-4c17-ae5d-8dc7267ad58a",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Tree Constrution\n",
|
||
|
"\n",
|
||
|
"The clustering approach in tree construction includes a few interesting ideas.\n",
|
||
|
"\n",
|
||
|
"**GMM (Gaussian Mixture Model)** \n",
|
||
|
"\n",
|
||
|
"- Model the distribution of data points across different clusters\n",
|
||
|
"- Optimal number of clusters by evaluating the model's Bayesian Information Criterion (BIC)\n",
|
||
|
"\n",
|
||
|
"**UMAP (Uniform Manifold Approximation and Projection)** \n",
|
||
|
"\n",
|
||
|
"- Supports clustering\n",
|
||
|
"- Reduces the dimensionality of high-dimensional data\n",
|
||
|
"- UMAP helps to highlight the natural grouping of data points based on their similarities\n",
|
||
|
"\n",
|
||
|
"**Local and Global Clustering** \n",
|
||
|
"\n",
|
||
|
"- Used to analyze data at different scales\n",
|
||
|
"- Both fine-grained and broader patterns within the data are captured effectively\n",
|
||
|
"\n",
|
||
|
"**Thresholding** \n",
|
||
|
"\n",
|
||
|
"- Apply in the context of GMM to determine cluster membership\n",
|
||
|
"- Based on the probability distribution (assignment of data points to ≥ 1 cluster)\n",
|
||
|
"---\n",
|
||
|
"\n",
|
||
|
"Code for GMM and thresholding is from Sarthi et al, as noted in the below two sources:\n",
|
||
|
" \n",
|
||
|
"* [Origional repo](https://github.com/parthsarthi03/raptor/blob/master/raptor/cluster_tree_builder.py)\n",
|
||
|
"* [Minor tweaks](https://github.com/run-llama/llama_index/blob/main/llama-index-packs/llama-index-packs-raptor/llama_index/packs/raptor/clustering.py)\n",
|
||
|
"\n",
|
||
|
"Full credit to both authors."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "a849980c-27d4-48e0-87a0-c2a5143cb8c0",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from typing import Dict, List, Optional, Tuple\n",
|
||
|
"\n",
|
||
|
"import numpy as np\n",
|
||
|
"import pandas as pd\n",
|
||
|
"import umap\n",
|
||
|
"from langchain.prompts import ChatPromptTemplate\n",
|
||
|
"from langchain_core.output_parsers import StrOutputParser\n",
|
||
|
"from sklearn.mixture import GaussianMixture\n",
|
||
|
"\n",
|
||
|
"RANDOM_SEED = 224 # Fixed seed for reproducibility\n",
|
||
|
"\n",
|
||
|
"### --- Code from citations referenced above (added comments and docstrings) --- ###\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def global_cluster_embeddings(\n",
|
||
|
" embeddings: np.ndarray,\n",
|
||
|
" dim: int,\n",
|
||
|
" n_neighbors: Optional[int] = None,\n",
|
||
|
" metric: str = \"cosine\",\n",
|
||
|
") -> np.ndarray:\n",
|
||
|
" \"\"\"\n",
|
||
|
" Perform global dimensionality reduction on the embeddings using UMAP.\n",
|
||
|
"\n",
|
||
|
" Parameters:\n",
|
||
|
" - embeddings: The input embeddings as a numpy array.\n",
|
||
|
" - dim: The target dimensionality for the reduced space.\n",
|
||
|
" - n_neighbors: Optional; the number of neighbors to consider for each point.\n",
|
||
|
" If not provided, it defaults to the square root of the number of embeddings.\n",
|
||
|
" - metric: The distance metric to use for UMAP.\n",
|
||
|
"\n",
|
||
|
" Returns:\n",
|
||
|
" - A numpy array of the embeddings reduced to the specified dimensionality.\n",
|
||
|
" \"\"\"\n",
|
||
|
" if n_neighbors is None:\n",
|
||
|
" n_neighbors = int((len(embeddings) - 1) ** 0.5)\n",
|
||
|
" return umap.UMAP(\n",
|
||
|
" n_neighbors=n_neighbors, n_components=dim, metric=metric\n",
|
||
|
" ).fit_transform(embeddings)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def local_cluster_embeddings(\n",
|
||
|
" embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = \"cosine\"\n",
|
||
|
") -> np.ndarray:\n",
|
||
|
" \"\"\"\n",
|
||
|
" Perform local dimensionality reduction on the embeddings using UMAP, typically after global clustering.\n",
|
||
|
"\n",
|
||
|
" Parameters:\n",
|
||
|
" - embeddings: The input embeddings as a numpy array.\n",
|
||
|
" - dim: The target dimensionality for the reduced space.\n",
|
||
|
" - num_neighbors: The number of neighbors to consider for each point.\n",
|
||
|
" - metric: The distance metric to use for UMAP.\n",
|
||
|
"\n",
|
||
|
" Returns:\n",
|
||
|
" - A numpy array of the embeddings reduced to the specified dimensionality.\n",
|
||
|
" \"\"\"\n",
|
||
|
" return umap.UMAP(\n",
|
||
|
" n_neighbors=num_neighbors, n_components=dim, metric=metric\n",
|
||
|
" ).fit_transform(embeddings)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def get_optimal_clusters(\n",
|
||
|
" embeddings: np.ndarray, max_clusters: int = 50, random_state: int = RANDOM_SEED\n",
|
||
|
") -> int:\n",
|
||
|
" \"\"\"\n",
|
||
|
" Determine the optimal number of clusters using the Bayesian Information Criterion (BIC) with a Gaussian Mixture Model.\n",
|
||
|
"\n",
|
||
|
" Parameters:\n",
|
||
|
" - embeddings: The input embeddings as a numpy array.\n",
|
||
|
" - max_clusters: The maximum number of clusters to consider.\n",
|
||
|
" - random_state: Seed for reproducibility.\n",
|
||
|
"\n",
|
||
|
" Returns:\n",
|
||
|
" - An integer representing the optimal number of clusters found.\n",
|
||
|
" \"\"\"\n",
|
||
|
" max_clusters = min(max_clusters, len(embeddings))\n",
|
||
|
" n_clusters = np.arange(1, max_clusters)\n",
|
||
|
" bics = []\n",
|
||
|
" for n in n_clusters:\n",
|
||
|
" gm = GaussianMixture(n_components=n, random_state=random_state)\n",
|
||
|
" gm.fit(embeddings)\n",
|
||
|
" bics.append(gm.bic(embeddings))\n",
|
||
|
" return n_clusters[np.argmin(bics)]\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Cluster embeddings using a Gaussian Mixture Model (GMM) based on a probability threshold.\n",
|
||
|
"\n",
|
||
|
" Parameters:\n",
|
||
|
" - embeddings: The input embeddings as a numpy array.\n",
|
||
|
" - threshold: The probability threshold for assigning an embedding to a cluster.\n",
|
||
|
" - random_state: Seed for reproducibility.\n",
|
||
|
"\n",
|
||
|
" Returns:\n",
|
||
|
" - A tuple containing the cluster labels and the number of clusters determined.\n",
|
||
|
" \"\"\"\n",
|
||
|
" n_clusters = get_optimal_clusters(embeddings)\n",
|
||
|
" gm = GaussianMixture(n_components=n_clusters, random_state=random_state)\n",
|
||
|
" gm.fit(embeddings)\n",
|
||
|
" probs = gm.predict_proba(embeddings)\n",
|
||
|
" labels = [np.where(prob > threshold)[0] for prob in probs]\n",
|
||
|
" return labels, n_clusters\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def perform_clustering(\n",
|
||
|
" embeddings: np.ndarray,\n",
|
||
|
" dim: int,\n",
|
||
|
" threshold: float,\n",
|
||
|
") -> List[np.ndarray]:\n",
|
||
|
" \"\"\"\n",
|
||
|
" Perform clustering on the embeddings by first reducing their dimensionality globally, then clustering\n",
|
||
|
" using a Gaussian Mixture Model, and finally performing local clustering within each global cluster.\n",
|
||
|
"\n",
|
||
|
" Parameters:\n",
|
||
|
" - embeddings: The input embeddings as a numpy array.\n",
|
||
|
" - dim: The target dimensionality for UMAP reduction.\n",
|
||
|
" - threshold: The probability threshold for assigning an embedding to a cluster in GMM.\n",
|
||
|
"\n",
|
||
|
" Returns:\n",
|
||
|
" - A list of numpy arrays, where each array contains the cluster IDs for each embedding.\n",
|
||
|
" \"\"\"\n",
|
||
|
" if len(embeddings) <= dim + 1:\n",
|
||
|
" # Avoid clustering when there's insufficient data\n",
|
||
|
" return [np.array([0]) for _ in range(len(embeddings))]\n",
|
||
|
"\n",
|
||
|
" # Global dimensionality reduction\n",
|
||
|
" reduced_embeddings_global = global_cluster_embeddings(embeddings, dim)\n",
|
||
|
" # Global clustering\n",
|
||
|
" global_clusters, n_global_clusters = GMM_cluster(\n",
|
||
|
" reduced_embeddings_global, threshold\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" all_local_clusters = [np.array([]) for _ in range(len(embeddings))]\n",
|
||
|
" total_clusters = 0\n",
|
||
|
"\n",
|
||
|
" # Iterate through each global cluster to perform local clustering\n",
|
||
|
" for i in range(n_global_clusters):\n",
|
||
|
" # Extract embeddings belonging to the current global cluster\n",
|
||
|
" global_cluster_embeddings_ = embeddings[\n",
|
||
|
" np.array([i in gc for gc in global_clusters])\n",
|
||
|
" ]\n",
|
||
|
"\n",
|
||
|
" if len(global_cluster_embeddings_) == 0:\n",
|
||
|
" continue\n",
|
||
|
" if len(global_cluster_embeddings_) <= dim + 1:\n",
|
||
|
" # Handle small clusters with direct assignment\n",
|
||
|
" local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]\n",
|
||
|
" n_local_clusters = 1\n",
|
||
|
" else:\n",
|
||
|
" # Local dimensionality reduction and clustering\n",
|
||
|
" reduced_embeddings_local = local_cluster_embeddings(\n",
|
||
|
" global_cluster_embeddings_, dim\n",
|
||
|
" )\n",
|
||
|
" local_clusters, n_local_clusters = GMM_cluster(\n",
|
||
|
" reduced_embeddings_local, threshold\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" # Assign local cluster IDs, adjusting for total clusters already processed\n",
|
||
|
" for j in range(n_local_clusters):\n",
|
||
|
" local_cluster_embeddings_ = global_cluster_embeddings_[\n",
|
||
|
" np.array([j in lc for lc in local_clusters])\n",
|
||
|
" ]\n",
|
||
|
" indices = np.where(\n",
|
||
|
" (embeddings == local_cluster_embeddings_[:, None]).all(-1)\n",
|
||
|
" )[1]\n",
|
||
|
" for idx in indices:\n",
|
||
|
" all_local_clusters[idx] = np.append(\n",
|
||
|
" all_local_clusters[idx], j + total_clusters\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" total_clusters += n_local_clusters\n",
|
||
|
"\n",
|
||
|
" return all_local_clusters\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"### --- Our code below --- ###\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def embed(texts):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Generate embeddings for a list of text documents.\n",
|
||
|
"\n",
|
||
|
" This function assumes the existence of an `embd` object with a method `embed_documents`\n",
|
||
|
" that takes a list of texts and returns their embeddings.\n",
|
||
|
"\n",
|
||
|
" Parameters:\n",
|
||
|
" - texts: List[str], a list of text documents to be embedded.\n",
|
||
|
"\n",
|
||
|
" Returns:\n",
|
||
|
" - numpy.ndarray: An array of embeddings for the given text documents.\n",
|
||
|
" \"\"\"\n",
|
||
|
" text_embeddings = embd.embed_documents(texts)\n",
|
||
|
" text_embeddings_np = np.array(text_embeddings)\n",
|
||
|
" return text_embeddings_np\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def embed_cluster_texts(texts):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Embeds a list of texts and clusters them, returning a DataFrame with texts, their embeddings, and cluster labels.\n",
|
||
|
"\n",
|
||
|
" This function combines embedding generation and clustering into a single step. It assumes the existence\n",
|
||
|
" of a previously defined `perform_clustering` function that performs clustering on the embeddings.\n",
|
||
|
"\n",
|
||
|
" Parameters:\n",
|
||
|
" - texts: List[str], a list of text documents to be processed.\n",
|
||
|
"\n",
|
||
|
" Returns:\n",
|
||
|
" - pandas.DataFrame: A DataFrame containing the original texts, their embeddings, and the assigned cluster labels.\n",
|
||
|
" \"\"\"\n",
|
||
|
" text_embeddings_np = embed(texts) # Generate embeddings\n",
|
||
|
" cluster_labels = perform_clustering(\n",
|
||
|
" text_embeddings_np, 10, 0.1\n",
|
||
|
" ) # Perform clustering on the embeddings\n",
|
||
|
" df = pd.DataFrame() # Initialize a DataFrame to store the results\n",
|
||
|
" df[\"text\"] = texts # Store original texts\n",
|
||
|
" df[\"embd\"] = list(text_embeddings_np) # Store embeddings as a list in the DataFrame\n",
|
||
|
" df[\"cluster\"] = cluster_labels # Store cluster labels\n",
|
||
|
" return df\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def fmt_txt(df: pd.DataFrame) -> str:\n",
|
||
|
" \"\"\"\n",
|
||
|
" Formats the text documents in a DataFrame into a single string.\n",
|
||
|
"\n",
|
||
|
" Parameters:\n",
|
||
|
" - df: DataFrame containing the 'text' column with text documents to format.\n",
|
||
|
"\n",
|
||
|
" Returns:\n",
|
||
|
" - A single string where all text documents are joined by a specific delimiter.\n",
|
||
|
" \"\"\"\n",
|
||
|
" unique_txt = df[\"text\"].tolist()\n",
|
||
|
" return \"--- --- \\n --- --- \".join(unique_txt)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def embed_cluster_summarize_texts(\n",
|
||
|
" texts: List[str], level: int\n",
|
||
|
") -> Tuple[pd.DataFrame, pd.DataFrame]:\n",
|
||
|
" \"\"\"\n",
|
||
|
" Embeds, clusters, and summarizes a list of texts. This function first generates embeddings for the texts,\n",
|
||
|
" clusters them based on similarity, expands the cluster assignments for easier processing, and then summarizes\n",
|
||
|
" the content within each cluster.\n",
|
||
|
"\n",
|
||
|
" Parameters:\n",
|
||
|
" - texts: A list of text documents to be processed.\n",
|
||
|
" - level: An integer parameter that could define the depth or detail of processing.\n",
|
||
|
"\n",
|
||
|
" Returns:\n",
|
||
|
" - Tuple containing two DataFrames:\n",
|
||
|
" 1. The first DataFrame (`df_clusters`) includes the original texts, their embeddings, and cluster assignments.\n",
|
||
|
" 2. The second DataFrame (`df_summary`) contains summaries for each cluster, the specified level of detail,\n",
|
||
|
" and the cluster identifiers.\n",
|
||
|
" \"\"\"\n",
|
||
|
"\n",
|
||
|
" # Embed and cluster the texts, resulting in a DataFrame with 'text', 'embd', and 'cluster' columns\n",
|
||
|
" df_clusters = embed_cluster_texts(texts)\n",
|
||
|
"\n",
|
||
|
" # Prepare to expand the DataFrame for easier manipulation of clusters\n",
|
||
|
" expanded_list = []\n",
|
||
|
"\n",
|
||
|
" # Expand DataFrame entries to document-cluster pairings for straightforward processing\n",
|
||
|
" for index, row in df_clusters.iterrows():\n",
|
||
|
" for cluster in row[\"cluster\"]:\n",
|
||
|
" expanded_list.append(\n",
|
||
|
" {\"text\": row[\"text\"], \"embd\": row[\"embd\"], \"cluster\": cluster}\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" # Create a new DataFrame from the expanded list\n",
" expanded_df = pd.DataFrame(expanded_list)\n",
"\n",
" # Retrieve unique cluster identifiers for processing\n",
" all_clusters = expanded_df[\"cluster\"].unique()\n",
"\n",
" print(f\"--Generated {len(all_clusters)} clusters--\")\n",
"\n",
" # Summarization\n",
" template = \"\"\"Here is a subset of the LangChain Expression Language docs. \n",
" \n",
" LangChain Expression Language provides a way to compose chains in LangChain.\n",
" \n",
" Give a detailed summary of the documentation provided.\n",
" \n",
" Documentation:\n",
" {context}\n",
" \"\"\"\n",
" prompt = ChatPromptTemplate.from_template(template)\n",
" chain = prompt | model | StrOutputParser()\n",
"\n",
" # Format text within each cluster for summarization\n",
" summaries = []\n",
" for i in all_clusters:\n",
" df_cluster = expanded_df[expanded_df[\"cluster\"] == i]\n",
" formatted_txt = fmt_txt(df_cluster)\n",
" summaries.append(chain.invoke({\"context\": formatted_txt}))\n",
"\n",
" # Create a DataFrame to store summaries with their corresponding cluster and level\n",
" df_summary = pd.DataFrame(\n",
" {\n",
" \"summaries\": summaries,\n",
" \"level\": [level] * len(summaries),\n",
" \"cluster\": list(all_clusters),\n",
" }\n",
" )\n",
"\n",
" return df_clusters, df_summary\n",
"\n",
"\n",
"def recursive_embed_cluster_summarize(\n",
" texts: List[str], level: int = 1, n_levels: int = 3\n",
") -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:\n",
" \"\"\"\n",
" Recursively embeds, clusters, and summarizes texts up to a specified level or until\n",
" the number of unique clusters becomes 1, storing the results at each level.\n",
"\n",
" Parameters:\n",
" - texts: List[str], texts to be processed.\n",
" - level: int, current recursion level (starts at 1).\n",
" - n_levels: int, maximum depth of recursion.\n",
"\n",
" Returns:\n",
" - Dict[int, Tuple[pd.DataFrame, pd.DataFrame]], a dictionary where keys are the recursion\n",
" levels and values are tuples containing the clusters DataFrame and summaries DataFrame at that level.\n",
" \"\"\"\n",
" results = {} # Dictionary to store results at each level\n",
"\n",
" # Perform embedding, clustering, and summarization for the current level\n",
" df_clusters, df_summary = embed_cluster_summarize_texts(texts, level)\n",
"\n",
" # Store the results of the current level\n",
" results[level] = (df_clusters, df_summary)\n",
"\n",
" # Determine if further recursion is possible and meaningful\n",
" unique_clusters = df_summary[\"cluster\"].nunique()\n",
" if level < n_levels and unique_clusters > 1:\n",
" # Use summaries as the input texts for the next level of recursion\n",
" new_texts = df_summary[\"summaries\"].tolist()\n",
" next_level_results = recursive_embed_cluster_summarize(\n",
" new_texts, level + 1, n_levels\n",
" )\n",
"\n",
" # Merge the results from the next level into the current results dictionary\n",
" results.update(next_level_results)\n",
"\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f0d8cd3e-cd49-484d-9617-1b9811cc08b3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--Generated 7 clusters--\n",
"--Generated 1 clusters--\n"
]
}
],
"source": [
"# Build tree\n",
"leaf_texts = docs_texts\n",
"results = recursive_embed_cluster_summarize(leaf_texts, level=1, n_levels=3)"
]
},
{
"cell_type": "markdown",
"id": "e80d7098-5d16-4fa6-837c-968e5c9f118d",
"metadata": {},
"source": [
"The paper reports best performance from `collapsed tree retrieval`. \n",
"\n",
"This involves flattening the tree structure into a single layer and then applying a k-nearest neighbors (kNN) search across all nodes simultaneously. \n",
"\n",
"We simply do this below."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d28ba9e6-9124-41a8-b4fd-55a6ef4ac062",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores import Chroma\n",
"\n",
"# Initialize all_texts with leaf_texts\n",
"all_texts = leaf_texts.copy()\n",
"\n",
"# Iterate through the results to extract summaries from each level and add them to all_texts\n",
"for level in sorted(results.keys()):\n",
" # Extract summaries from the current level's DataFrame\n",
" summaries = results[level][1][\"summaries\"].tolist()\n",
" # Extend all_texts with the summaries from the current level\n",
" all_texts.extend(summaries)\n",
"\n",
"# Now, use all_texts to build the vectorstore with Chroma\n",
"vectorstore = Chroma.from_texts(texts=all_texts, embedding=embd)\n",
"retriever = vectorstore.as_retriever()"
]
},
{
"cell_type": "markdown",
"id": "0d497627-44c6-41f7-bb63-1d858d3f188f",
"metadata": {},
"source": [
"Now we can use our flattened, indexed tree in a RAG chain."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9d6c894b-b3a3-4a01-b779-3e98ea382ff5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Here is a code example of how to define a RAG (Retrieval Augmented Generation) chain in LangChain:\\n\\n```python\\nfrom langchain.vectorstores import FAISS\\nfrom langchain.embeddings import OpenAIEmbeddings\\nfrom langchain.prompts import ChatPromptTemplate\\nfrom langchain.chat_models import ChatOpenAI\\nfrom langchain.output_parsers import StrOutputParser\\n\\n# Load documents into vector store\\nvectorstore = FAISS.from_texts(\\n [\"harrison worked at kensho\"], embedding=OpenAIEmbeddings()\\n)\\nretriever = vectorstore.as_retriever()\\n\\n# Define prompt template\\ntemplate = \"\"\"Answer the question based only on the following context:\\n{context}\\nQuestion: {question}\"\"\"\\nprompt = ChatPromptTemplate.from_template(template)\\n\\n# Define model and output parser\\nmodel = ChatOpenAI()\\noutput_parser = StrOutputParser()\\n\\n# Define RAG chain\\nchain = (\\n {\"context\": retriever, \"question\": RunnablePassthrough()}\\n | prompt\\n | model '"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain import hub\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"# Prompt\n",
"prompt = hub.pull(\"rlm/rag-prompt\")\n",
"\n",
"\n",
"# Post-processing\n",
"def format_docs(docs):\n",
" return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
"\n",
"\n",
"# Chain\n",
"rag_chain = (\n",
" {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
" | prompt\n",
" | model\n",
" | StrOutputParser()\n",
")\n",
"\n",
"# Question\n",
"rag_chain.invoke(\"How to define a RAG chain? Give me a specific code example.\")"
]
},
{
"cell_type": "markdown",
"id": "0c585b37-ad83-4069-8f5d-4a6a3e15128d",
"metadata": {},
"source": [
"Trace: \n",
"\n",
"https://smith.langchain.com/public/1dabf475-1675-4494-b16c-928fbf079851/r"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}