{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "80e7ce94-33a4-480b-af65-2e7f78eab032",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from lxml import html\n",
"from pydantic import BaseModel\n",
"from typing import Any, Optional\n",
"from unstructured.partition.pdf import partition_pdf"
]
},
{
"cell_type": "markdown",
"id": "b6d466cc-aa8b-4baf-a80a-fef01921ca8d",
"metadata": {},
"source": [
"# Use Case\n",
"\n",
"Many documents contain a mixture of content types, including text, images, and tables.\n",
"\n",
"* `Semi-structured data`: RAG on text and tables has remained a challenge\n",
"* `Image`: images often contain valuable information that is excluded from RAG due to model limitations\n",
"\n",
"Here, we show how Unstructured can be used to partition the text and tables in a document, and how to run RAG on that semi-structured data.\n",
"\n",
"In a follow-up notebook, we show how this can be extended to images."
]
},
{
"cell_type": "markdown",
"id": "c2dd1ac7-f553-4d7c-919d-5b088cac6981",
"metadata": {},
"source": [
"[Figure: img_optional_flow.png]"
]
},
{
"cell_type": "markdown",
"id": "dc871e89-649b-41af-902d-a7bc94808d16",
"metadata": {},
"source": [
"## Data Loading\n",
"\n",
"### Partition PDF tables and text with Unstructured\n",
" \n",
"* `LLaMA2` Paper: https://arxiv.org/pdf/2307.09288.pdf\n",
"* Use `chunking_strategy=\"by_title\"`, which rolls up subsequent non-Table elements under a Title into a `CompositeElement`\n",
"* Unstructured uses a [table recognition module](https://unstructured-io.github.io/unstructured/introduction.html#tables)."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "57790bd1-7a11-48d5-90de-4813e4776578",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at microsoft/table-transformer-structure-recognition were not used when initializing TableTransformerForObjectDetection: ['model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked']\n",
"- This IS expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"# Get elements\n",
"raw_pdf_elements = partition_pdf(\n",
"    filename=\"/Users/rlm/Desktop/Papers/2307.09288.pdf\",\n",
"    # Do not extract embedded image blocks in this notebook\n",
"    extract_images_in_pdf=False,\n",
"    # Use layout model (YOLO-X) to get bounding boxes (for tables) and find titles\n",
"    # Titles are any sub-section of the document\n",
"    infer_table_structure=True,\n",
"    # Post-processing to aggregate text once we have the title\n",
"    chunking_strategy=\"by_title\",\n",
"    # Chunking params to aggregate text blocks\n",
"    # Hard max on chunk size\n",
"    max_characters=4000,\n",
"    # Attempt to start a new chunk after 3800 chars\n",
"    new_after_n_chars=3800,\n",
"    # Attempt to keep chunks > 2000 chars by combining smaller ones\n",
"    combine_text_under_n_chars=2000,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "628abfc6-4057-434b-b880-d88e3ba44657",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{\"<class 'unstructured.documents.elements.CompositeElement'>\": 184,\n",
" \"<class 'unstructured.documents.elements.Table'>\": 47,\n",
" \"<class 'unstructured.documents.elements.TableChunk'>\": 2}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a dictionary to store counts of each type\n",
"category_counts = {}\n",
"\n",
"for element in raw_pdf_elements:\n",
"    category = str(type(element))\n",
"    if category in category_counts:\n",
"        category_counts[category] += 1\n",
"    else:\n",
"        category_counts[category] = 1\n",
"\n",
"# unique_categories holds the distinct element types\n",
"unique_categories = set(category_counts.keys())\n",
"category_counts"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5462f29e-fd59-4e0e-9493-ea3b560e523e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"49\n",
"184\n"
]
}
],
"source": [
"class Element(BaseModel):\n",
"    type: str\n",
"    text: Any\n",
"\n",
"# Categorize by type\n",
"categorized_elements = []\n",
"for element in raw_pdf_elements:\n",
"    # This substring check matches both Table and TableChunk elements\n",
"    if \"unstructured.documents.elements.Table\" in str(type(element)):\n",
"        categorized_elements.append(Element(type=\"table\", text=str(element)))\n",
"    elif \"unstructured.documents.elements.CompositeElement\" in str(type(element)):\n",
"        categorized_elements.append(Element(type=\"text\", text=str(element)))\n",
"\n",
"# Tables\n",
"table_elements = [e for e in categorized_elements if e.type == \"table\"]\n",
"print(len(table_elements))\n",
"\n",
"# Text\n",
"text_elements = [e for e in categorized_elements if e.type == \"text\"]\n",
"print(len(text_elements))"
]
},
{
"cell_type": "markdown",
"id": "731b3dfc-7ddf-4a11-9a30-9a79b7c66e16",
"metadata": {},
"source": [
"## Multi-vector retriever\n",
"\n",
"Use [multi-vector-retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary).\n",
"\n",
"### Summaries"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8e275736-3408-4d7a-990e-4362c88e81f8",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema.output_parser import StrOutputParser"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1b12536a-1303-41ad-9948-4eb5a5f32614",
"metadata": {},
"outputs": [],
"source": [
"# Prompt\n",
"prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text. \\\n",
"Give a concise summary of the table or text. Table or text chunk: {element} \"\"\"\n",
"prompt = ChatPromptTemplate.from_template(prompt_text) \n",
"\n",
"# Summary chain \n",
"model = ChatOpenAI(temperature=0,model=\"gpt-4\")\n",
"summarize_chain = {\"element\": lambda x:x} | prompt | model | StrOutputParser()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d8b567c-b442-4bf0-b639-04bd89effc62",
"metadata": {},
"outputs": [],
"source": [
"# Apply to tables\n",
"tables = [i.text for i in table_elements]\n",
"table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 5})"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "3e9c176c-3d46-4034-b169-0d7305d42d27",
"metadata": {},
"outputs": [],
"source": [
"# Apply to texts\n",
"texts = [i.text for i in text_elements]\n",
"text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 5})"
]
},
{
"cell_type": "markdown",
"id": "60524010-754f-4924-ad75-78cb54ca7257",
"metadata": {},
"source": [
"### Add to vectorstore\n",
"\n",
"Use [Multi Vector Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary) with summaries."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "346c3a02-8fea-4f75-a69e-fc9542b99dbc",
"metadata": {},
"outputs": [],
"source": [
"import uuid\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.storage import InMemoryStore\n",
"from langchain.schema.document import Document\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
"\n",
"# The vectorstore to use to index the child chunks\n",
"vectorstore = Chroma(\n",
" collection_name=\"summaries\",\n",
" embedding_function=OpenAIEmbeddings()\n",
")\n",
"\n",
"# The storage layer for the parent documents\n",
"store = InMemoryStore()\n",
"id_key = \"doc_id\"\n",
"\n",
"# The retriever (empty to start)\n",
"retriever = MultiVectorRetriever(\n",
" vectorstore=vectorstore, \n",
" docstore=store, \n",
" id_key=id_key,\n",
")\n",
"\n",
"# Add texts\n",
"doc_ids = [str(uuid.uuid4()) for _ in texts]\n",
"summary_texts = [Document(page_content=s,metadata={id_key: doc_ids[i]}) for i, s in enumerate(text_summaries)]\n",
"retriever.vectorstore.add_documents(summary_texts)\n",
"retriever.docstore.mset(list(zip(doc_ids, texts)))\n",
"\n",
"# Add tables\n",
"table_ids = [str(uuid.uuid4()) for _ in tables]\n",
"summary_tables = [Document(page_content=s,metadata={id_key: table_ids[i]}) for i, s in enumerate(table_summaries)]\n",
"retriever.vectorstore.add_documents(summary_tables)\n",
"retriever.docstore.mset(list(zip(table_ids, tables)))"
]
},
{
"cell_type": "markdown",
"id": "1d8bbbd9-009b-4b34-a206-5874a60adbda",
"metadata": {},
"source": [
"## RAG\n",
"\n",
"Run [RAG pipeline](https://python.langchain.com/docs/expression_language/cookbook/retrieval)."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "f2489de4-51e3-48b4-bbcd-ed9171deadf3",
"metadata": {},
"outputs": [],
"source": [
"from operator import itemgetter\n",
"from langchain.schema.runnable import RunnablePassthrough\n",
"\n",
"# Prompt template\n",
"template = \"\"\"Answer the question based only on the following context, which can include text and tables:\n",
"{context}\n",
"Question: {question}\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)\n",
"\n",
"# LLM\n",
"model = ChatOpenAI(temperature=0,model=\"gpt-4\")\n",
"\n",
"# RAG pipeline\n",
"chain = (\n",
" {\"context\": retriever, \"question\": RunnablePassthrough()} \n",
" | prompt \n",
" | model \n",
" | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "90e3d100-10e8-4ee6-ae46-2480b1524ec8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The number of training tokens for LLaMA2 is 2.0T.'"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke(\"What is the number of training tokens for LLaMA2?\")"
]
},
{
"cell_type": "markdown",
"id": "37f46054-e239-4ba8-af81-22d0d6a9bc32",
"metadata": {},
"source": [
"We can check the [trace](https://smith.langchain.com/public/4739ae7c-1a13-406d-bc4e-3462670ebc01/r) to see which chunks were retrieved.\n",
"\n",
"They include Table 1 of the paper, which shows the number of tokens used for training:\n",
"\n",
"```\n",
"Training Data Params Context GQA Tokens LR Length 7B 2k 1.0T 3.0x 10-4 See Touvron et al. 13B 2k 1.0T 3.0 x 10-4 LiaMa 1 (2023) 33B 2k 14T 1.5 x 10-4 65B 2k 1.4T 1.5 x 10-4 7B 4k 2.0T 3.0x 10-4 Liama 2 A new mix of publicly 13B 4k 2.0T 3.0 x 10-4 available online data 34B 4k v 2.0T 1.5 x 10-4 70B 4k v 2.0T 1.5 x 10-4\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}