{
"cells": [
{
"attachments": {
"9bbbcfe4-2b85-4e76-996a-ce8d1497d34e.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABnkAAAMxCAYAAAAnrNaWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkJAQIICAlNCbIFIDSAmhBZBeBBshCRBKjIGgYkcXFVy7iIANXRVR7IDYETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMuTP33gGAfpwnkeSimgDkiQukcaGBzNEpqUzSU0AEdEAFVkCLx8+XsGNiIgEsA+3fy7vrAJG3VxzlWv/s/69FSyDM5wOAxECcLsjn50G8HwC8mi+RFgBAlPMWkwskcgwr0JHCACFeIMeZSlwtx+lKvFthkxDHgbgVADUqjyfNBEDjEuSZhfxMqKHRC7GzWCASA0BnQuyXlzdRAHEaxLbQRgKxXJ+V/oNO5t800wc1ebzMQayci6KoBYnyJbm8qf9nOv53ycuVDfiwhpWaJQ2Lk88Z5u1mzsQIOaZC3CNOj4qGWBviDyKBwh5ilJIlC0tU2qNG/HwOzBnQg9hZwAuKgNgI4hBxblSkik/PEIVwIYYrBJ0iKuAmQKwP8QJhfnC8ymaDdGKcyhfakCHlsFX8WZ5U4Vfu674sJ5Gt0n+dJeSq9DGNoqyEZIgpEFsWipKiINaA2Ck/Jz5CZTOyKIsTNWAjlcXJ47eEOE4oDg1U6mOFGdKQOJV9aV7+wHyxDVkibpQK7y3ISghT5gdr5fMU8cO5YJeEYnbigI4wf3TkwFwEwqBg5dyxZ0JxYrxK54OkIDBOORanSHJjVPa4uTA3VM6bQ+yWXxivGosnFcAFqdTHMyQFMQnKOPGibF54jDIefCmIBBwQBJhABms6mAiygai9p7EH3il7QgAPSEEmEAJHFTMwIlnRI4bXeFAE/oRICPIHxwUqeoWgEPJfB1nl1RFkKHoLFSNywBOI80AEyIX3MsUo8aC3JPAYMqJ/eOfByofx5sIq7//3/AD7nWFDJlLFyAY8MukDlsRgYhAxjBhCtMMNcT/cB4+E1wBYXXAW7jUwj+/2hCeEDsJDwjVCJ+HWBFGx9KcoR4FOqB+iykX6j7nAraGmOx6I+0J1qIzr4YbAEXeDfti4P/TsDlmOKm55Vpg/af9tBj88DZUd2ZmMkoeQA8i2P4/UsNdwH1SR5/rH/ChjTR/MN2ew52f/nB+yL4BtxM+W2AJsH3YGO4Gdww5jjYCJHcOasDbsiBwPrq7HitU14C1OEU8O1BH9w9/Ak5VnMt+5zrnb+Yuyr0A4Rf6OBpyJkqlSUWZWAZMNvwhCJlfMdxrGdHF2cQVA/n1Rvr7exCq+G4he23du7h8A+B7r7+8/9J0LPwbAHk+4/Q9+52xZ8NOhDsDZg3yZtFDJ4fILAb4l6HCnGQATYAFs4XxcgAfwAQEgGISDaJAAUsB4GH0WXOdSMBlMB3NACSgDS8EqUAnWg01gG9gJ9oJGcBicAKfBBXAJXAN34OrpAi9AL3gHPiMIQkJoCAMxQEwRK8QBcUFYiB8SjEQicUgKkoZkImJEhkxH5iJlyHKkEtmI1CJ7kIPICeQc0oHcQh4g3chr5BOKoVRUBzVGrdHhKAtloxFoAjoOzUQnoUXoPHQxWoHWoDvQBvQEegG9hnaiL9A+DGDqmB5mhjliLIyDRWOpWAYmxWZipVg5VoPVY83wOV/BOrEe7CNOxBk4E3eEKzgMT8T5+CR8Jr4Ir8S34Q14K34Ff4D34t8INIIRwYHgTeASRhMyCZMJJYRywhbCAcIpuJe6CO+IRKIe0YboCfdiCjGbOI24iLiWuIt4nNhBfETsI5FIBiQHki8pmsQjFZBKSGtIO0jHSJdJXaQPaupqpmouaiFqqWpitWK1crXtakfVLqs9VftM1iRbkb3J0WQBeSp5CXkzuZl8kdxF/kzRothQfCkJlGzKHEoFpZ5yinKX8kZdXd1c3Us9Vl2kPlu9Qn23+ln1B+ofqdpUeyqHOpYqoy6mbqUep96ivqHRaNa0AFoqrYC2mFZLO0m7T/ugwdBw0uBqCDRmaVRpNGhc1nhJJ9Ot6Gz6eHoRvZy+j36R3qNJ1rTW5GjyNGdqVmke1Lyh2afF0BqhFa2Vp7VIa7vWOa1n2iRta+1gbYH2PO1N2ie1HzEwhgWDw+Az5jI2M04xunSIOjY6XJ1snTKdnTrtOr262rpuukm6U3SrdI/oduphetZ6XL1cvSV6e/Wu630aYjyEPUQ4ZOGQ+iGXh7zXH6ofoC/UL9XfpX9N/5MB0yDYIMdgmUGjwT1D3NDeMNZwsuE6w1OGPUN1hvoM5Q8tHbp36G0j1MjeKM5omtEmozajPmMT41BjifEa45PGPSZ6JgEm2SYrTY6adJsyTP1MRaYrTY+ZPmfqMtnMXGYFs5XZa2ZkFmYmM9to1m722dzGPNG82HyX+T0LigXLIsNipUWLRa+lqeUoy+mWdZa3rchWLKssq9VWZ6zeW9tYJ1vPt260fmajb8O1KbKps7lrS7P1t51kW2N71Y5ox7LLsVtrd8ketXe3z7Kvsr/ogDp4OIgc1jp0DCMM8xomHlYz7IYj1ZHtWOhY5/jASc8p0qnYqdHp5XDL4anDlw0/M/ybs7tzrvNm5zsjtEeEjyge0TzitYu9C9+lyuWqK801xHWWa5PrKzcHN6HbOreb7gz3Ue7z3Vvcv3p4ekg96j26PS090zyrPW+wdFgxrEWss14Er0CvWV6HvT56e3gXeO/1/svH0SfHZ7vPs5E2I4UjN4985Gvuy/Pd6Nvpx/RL89vg1+lv5s/zr/F/GGARIAjYEvCUbcfOZu9gvwx0DpQGHgh8z/HmzOAcD8KCQoNKg9qDtYMTgyuD74eYh2SG1IX0hrqHTgs9HkYIiwhbFnaDa8zlc2u5veGe4TPCWyOoEfERlREPI+0jpZHNo9BR4aNWjLobZRUljmqMBtHc6BXR92JsYibFHIolxsbEVsU+iRsRNz3uTDwjfkL89vh3CYEJSxLuJNomyhJbkuhJY5Nqk94nByUvT+4cPXz0jNEXUgxTRClNqaTUpNQtqX1jgsesGtM11n1sydjr42zGTRl3brzh+NzxRybQJ/Am7EsjpCWnbU/7wovm1fD60rnp1em9fA5/Nf+FIECwUtAt9BUuFz7N8M1YnvEs0zdzRWZ3ln9WeVaPiCOqFL3KDsten/0+Jzpna05/bnLurjy1vLS8g2JtcY64daLJxCkTOyQOkhJJ5yTvSasm9UojpFvykfxx+U0FOvBHvk1mK/tF9qDQr7Cq8MPkpMn7pmhNEU9pm2o/deHUp0UhRb9Nw6fxp7VMN5s+Z/qDGewZG2ciM9NntsyymDVvVtfs0Nnb5lDm5Mz5vdi5eHnx27nJc5vnGc+bPe/RL6G/1JVolEhLbsz3mb9+Ab5AtKB9oevCNQu/lQpKz5c5l5WXfVnEX3T+1xG/VvzavzhjcfsSjyXrlhKXipdeX+a/bNtyreVFyx+tGLWiYSVzZenKt6smrDpX7la+fjVltWx1Z0VkRdMayzVL13ypzKq8VhVYtavaqHph9fu1grWX1wWsq19vvL5s/
acNog03N4ZubKixrinfRNxUuOnJ5qTNZ35j/Va7xXBL2ZavW8VbO7fFbWut9ayt3W60fUkdWier694xdselnUE7m+od6zfu0ttVthvslu1+vidtz/W9EXtb9rH21e+32l99gHGgtAFpmNrQ25jV2NmU0tRxMPxgS7NP84FDToe2HjY7XHVE98iSo5Sj8472Hys61ndccrznROaJRy0TWu6cHH3yamtsa/upiFNnT4ecPnmGfebYWd+zh895nzt4nnW+8YLHhYY297YDv7v/fqDdo73houfFpktel5o7RnYcvex/+cSVoCunr3KvXrgWda3jeuL1mzfG3ui8Kbj57FburVe3C29/vjP7LuFu6T3Ne+X3je7X/GH3x65Oj84jD4IetD2Mf3jnEf/Ri8f5j790zXtCe1L+1PRp7TOXZ4e7Q7ovPR/zvOuF5MXnnpI/tf6sfmn7cv9fAX+19Y7u7XolfdX/etEbgzdb37q9bemL6bv/Lu/d5/elHww+bPvI+njmU/Knp58nfyF9qfhq97X5
}
},
"cell_type": "markdown",
"id": "812a4dbc-fe04-4b84-bdf9-390045e30806",
"metadata": {},
"source": [
"## Multi-modal RAG\n",
"\n",
"[See Trace of Option 3](https://smith.langchain.com/public/db0441a8-2c17-4070-bdf7-45d4fdf8f517/r/80cb0f89-1766-4caf-8959-fc43ec4b071c)\n",
"\n",
"Many documents contain a mixture of content types, including text and images. \n",
"\n",
"Yet, information captured in images is lost in most RAG applications.\n",
"\n",
"With the emergence of multimodal LLMs, like [GPT-4V](https://openai.com/research/gpt-4v-system-card), it is worth considering how to utilize images in RAG:\n",
"\n",
"`Option 1:` \n",
"\n",
"* Use multimodal embeddings (such as [CLIP](https://openai.com/research/clip)) to embed images and text\n",
"* Retrieve both using similarity search\n",
"* Pass raw images and text chunks to a multimodal LLM for answer synthesis \n",
"\n",
"`Option 2:` \n",
"\n",
"* Use a multimodal LLM (such as [GPT-4V](https://openai.com/research/gpt-4v-system-card), [LLaVA](https://llava.hliu.cc/), or [FUYU-8b](https://www.adept.ai/blog/fuyu-8b)) to produce text summaries from images\n",
"* Embed and retrieve text \n",
"* Pass text chunks to an LLM for answer synthesis \n",
"\n",
"`Option 3` (Shown) \n",
"\n",
"* Use a multimodal LLM (such as [GPT-4V](https://openai.com/research/gpt-4v-system-card), [LLaVA](https://llava.hliu.cc/), or [FUYU-8b](https://www.adept.ai/blog/fuyu-8b)) to produce text summaries from images\n",
"* Embed and retrieve image summaries with a reference to the raw image \n",
"* Pass raw images and text chunks to a multimodal LLM for answer synthesis \n",
"\n",
"This cookbook highlights `Option 3`. \n",
"\n",
"* We will use [Unstructured](https://unstructured.io/) to parse images, text, and tables from documents (PDFs).\n",
"* We will use the [multi-vector retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector) with [Chroma](https://www.trychroma.com/) to store raw text and images along with their summaries for retrieval.\n",
"\n",
"A seperate cookbook highlights `Option 1` [here](https://github.com/langchain-ai/langchain/blob/master/cookbook/multi_modal_RAG_chroma.ipynb).\n",
"\n",
"And option `Option 2` is appropriate for cases when a multi-modal LLM cannot be used for answer synthesis (e.g., cost, etc).\n",
"\n",
"![ss_mm_rag.png](attachment:9bbbcfe4-2b85-4e76-996a-ce8d1497d34e.png)\n",
"\n",
"## Packages\n",
"\n",
"In addition to the below pip packages, you will also need `poppler` ([installation instructions](https://pdf2image.readthedocs.io/en/latest/installation.html)) and `tesseract` ([installation instructions](https://tesseract-ocr.github.io/tessdoc/Installation.html)) in your system."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98f9ee74-395f-4aa4-9695-c00ade01195a",
"metadata": {},
"outputs": [],
"source": [
"! pip install \"openai>=1\" \"langchain>=0.0.331rc2\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "140580ef-5db0-43cc-a524-9c39e04d4df0",
"metadata": {},
"outputs": [],
"source": [
"! pip install \"unstructured[all-docs]\" pillow pydantic lxml pillow matplotlib chromadb tiktoken"
]
},
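{
"cell_type": "markdown",
"id": "env-check-note",
"metadata": {},
"source": [
"Optionally, run a quick check for the system dependencies. This is a minimal sketch (not part of the original flow): it only verifies that the `poppler` utilities (`pdftoppm`, `pdfinfo`) and the `tesseract` binary are on your `PATH`, which the PDF partitioning below relies on."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "env-check-code",
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"\n",
"# Check that the system binaries used for PDF image extraction and OCR are available\n",
"for binary in [\"pdftoppm\", \"pdfinfo\", \"tesseract\"]:\n",
"    status = \"found\" if shutil.which(binary) else \"MISSING\"\n",
"    print(f\"{binary}: {status}\")"
]
},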
{
"cell_type": "markdown",
"id": "74b56bde-1ba0-4525-a11d-cab02c5659e4",
"metadata": {},
"source": [
"## Data Loading\n",
"\n",
"### Partition PDF tables, text, and images\n",
" \n",
"Let's look at an [example whitepaper](https://sgp.fas.org/crs/misc/IF10244.pdf) that provides a mixture of tables, text, and images about Wildfires in the US.\n",
"\n",
"We use Unstructured to partition it (see [blog post](https://blog.langchain.dev/semi-structured-multi-modal-rag/))."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "61cbb874-ecc0-4d5d-9954-f0a41f65e0d7",
"metadata": {},
"outputs": [],
"source": [
"path = \"/Users/rlm/Desktop/wildfire_stats/\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a046528-8d22-4f4e-a520-962026562939",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from unstructured.partition.pdf import partition_pdf\n",
"\n",
"# Extract images, tables, and chunk text\n",
"raw_pdf_elements = partition_pdf(\n",
" filename=path + \"wildfire_stats.pdf\",\n",
" extract_images_in_pdf=True,\n",
" infer_table_structure=True,\n",
" chunking_strategy=\"by_title\",\n",
" max_characters=4000,\n",
" new_after_n_chars=3800,\n",
" combine_text_under_n_chars=2000,\n",
" image_output_dir_path=path,\n",
")"
]
},
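{
"cell_type": "markdown",
"id": "element-types-note",
"metadata": {},
"source": [
"Optionally, inspect which element types Unstructured returned before categorizing them. A quick sketch that just counts the element class names (it assumes `raw_pdf_elements` from the cell above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "element-types-code",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Count the element types produced by partition_pdf (e.g. CompositeElement, Table)\n",
"Counter(type(element).__name__ for element in raw_pdf_elements)"
]
},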
{
"cell_type": "code",
"execution_count": 4,
"id": "2ad9fcc5-57e0-495a-9632-28a1b368f9cd",
"metadata": {},
"outputs": [],
"source": [
"# Categorize by type\n",
"tables = []\n",
"texts = []\n",
"for element in raw_pdf_elements:\n",
" if \"unstructured.documents.elements.Table\" in str(type(element)):\n",
" tables.append(str(element))\n",
" elif \"unstructured.documents.elements.CompositeElement\" in str(type(element)):\n",
" texts.append(str(element))"
]
},
{
"cell_type": "markdown",
"id": "0aa7f52f-bf5c-4ba4-af72-b2ccba59a4cf",
"metadata": {},
"source": [
"## Multi-vector retriever\n",
"\n",
"Use [multi-vector-retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary).\n",
"\n",
"### Text and Table summaries\n",
"\n",
"Summaries are used to retrieve raw tables and / or raw chunks of text."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "523e6ed2-2132-4748-bdb7-db765f20648d",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema.output_parser import StrOutputParser"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "22c22e3f-42fb-4a4a-a87a-89f10ba8ab99",
"metadata": {},
"outputs": [],
"source": [
"# Prompt\n",
"prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text. \\ \n",
"Give a concise summary of the table or text. Table or text chunk: {element} \"\"\"\n",
"prompt = ChatPromptTemplate.from_template(prompt_text)\n",
"\n",
"# Summary chain\n",
"model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
"summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f176b374-aef0-48f4-a104-fb26b1dd6922",
"metadata": {},
"outputs": [],
"source": [
"# Apply to text\n",
"# Typically this is reccomended only if you have large text chunks\n",
"text_summaries = texts # Skip it\n",
"\n",
"# Apply to tables\n",
"table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 5})"
]
},
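{
"cell_type": "markdown",
"id": "text-summary-note",
"metadata": {},
"source": [
"If your text chunks are large, you can summarize them with the same chain instead of reusing the raw text. A minimal optional sketch; the `summarize_texts` flag is ours, not part of the original flow:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "text-summary-code",
"metadata": {},
"outputs": [],
"source": [
"# Optional: summarize large text chunks with the same chain (flip the flag to enable)\n",
"summarize_texts = False\n",
"\n",
"if summarize_texts:\n",
"    text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 5})"
]
},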
{
"attachments": {},
"cell_type": "markdown",
"id": "b1feadda-8171-4aed-9a60-320a88dc9ee1",
"metadata": {},
"source": [
"### Image summaries \n",
"\n",
"We will use [GPT4-V](https://openai.com/research/gpt-4v-system-card) to produce the image summaries.\n",
"\n",
"See the traces for each of the 5 ingested images here ([1](https://smith.langchain.com/public/f5548212-2e70-4fa8-91d6-c3e7d768d52b/r), \n",
"[2](https://smith.langchain.com/public/8b198178-5b83-4960-bbc1-c10516779208/r), \n",
"[3](https://smith.langchain.com/public/c4fcbcd5-38fb-462a-9ed1-e90b1d009fa9/r), \n",
"[4](https://smith.langchain.com/public/1df53c23-63b8-4f87-b5ae-e9d59b2a54ab/r), \n",
"[5](https://smith.langchain.com/public/f93efd6c-f9f6-46c9-b169-29270d33ad63/r))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9e6b1d97-4245-45ac-95ba-9bc1cfd10182",
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"import io\n",
"import os\n",
"\n",
"import numpy as np\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.messages import HumanMessage, SystemMessage\n",
"from PIL import Image\n",
"\n",
"\n",
"def encode_image(image_path):\n",
" \"\"\"Getting the base64 string\"\"\"\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
"\n",
"def image_summarize(img_base64, prompt):\n",
" \"\"\"Image summary\"\"\"\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
" msg = chat.invoke(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\": prompt},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
" },\n",
" ]\n",
" )\n",
" ]\n",
" )\n",
" return msg.content\n",
"\n",
"\n",
"# Store base64 encoded images\n",
"img_base64_list = []\n",
"\n",
"# Store image summaries\n",
"image_summaries = []\n",
"\n",
"# Prompt\n",
"prompt = \"Describe the image in detail. Be specific about graphs, such as bar plots.\"\n",
"\n",
"# Read images, encode to base64 strings\n",
"for img_file in sorted(os.listdir(path)):\n",
" if img_file.endswith(\".jpg\"):\n",
" img_path = os.path.join(path, img_file)\n",
" base64_image = encode_image(img_path)\n",
" img_base64_list.append(base64_image)\n",
" image_summaries.append(image_summarize(base64_image, prompt))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7f0b5405-fe45-4aa4-b5b4-6c973a25c168",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAFiApQDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD1Pw74d0i78PWU89hDJK8eWYg5JzWp/wAIroX/AEDIPyNHhX/kVtO/65D+Zq/fahZ6bb/aL65it4c43ysFGfqaAKH/AAiuhf8AQMg/I0f8IroX/QMg/I1PY6/pGpO6WOpWtwyDcwilDYHqcVbtrqC8gWe2mSaJujocg0AZv/CK6F/0DIPyNH/CK6F/0DIPyNbFFAGP/wAIroX/AEDIPyNH/CK6F/0DIPyNbFFAGP8A8IroX/QMg/I0f8IroX/QMg/I1sVRj1e0l1eXS0ZjcxRiRxt4AJIHP4UAVf8AhFdC/wCgZB+Ro/4RXQv+gZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf/AIRXQv8AoGQfkaP+EV0L/oGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/wCEV0L/AKBkH5Gj/hFdC/6BkH5GrepapbaVDHLdFgryLGu1ckkkAfzq6ORmgDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf8A4RXQv+gZB+Ro/wCEV0L/AKBkH5GtiigDH/4RXQv+gZB+Ro/4RXQv+gZB+RrYooAx/wDhFdC/6BkH5Gj/AIRXQv8AoGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf8A4RXQv+gZB+Ro/wCEV0L/AKBkH5GtiigDH/4RXQv+gZB+Ro/4RXQv+gZB+RrYooAx/wDhFdC/6BkH5Gj/AIRXQv8AoGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf8A4RXQv+gZB+Ro/wCEV0L/AKBkH5GtiigDH/4RXQv+gZB+Ro/4RXQv+gZB+RrYooAx/wDhFdC/6BkH5Gj/AIRXQv8AoGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+Ron8TaZb6/DokkzfbpV3KgXIx7mtigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf8A4RXQv+gZB+Ro/wCEV0L/AKBkH5GrOtapHoukz38qF1iXOwHBY+lWbWc3NpFOUKeYgbaTnGaAM3/hFdC/6BkH5Gj/AIRXQv8AoGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfkalutYS21uy0wRF5LlWbcGxsABOcd+ladAGP8A8IroX/QMg/I0f8IroX/QMg/I1sUUAY//AAiuhf8AQMg/I0f8IroX/QMg/I1au9XtLK/tLKZm8+6JEQVc5wCTn06VljxZEbfV5/sreXpz7Cd/+sOAeOOOuKALX/CK6F/0DIPyNH/CK6F/0DIPyNKniG1A0xJ1eOfUFBijA3Y+XdyfoK16AMf/AIRXQv8AoGQfkaP+EV0L/oGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDxzxfaQWHiOe3tYlihVUIRegyooqfx1/yNlz/up/6CKKAPRPCv/Irad/1yH8zWF42ukGraHbNDNOonM0kUMZkYhcH7o61u+Ff+RW07/rkP5mq50i7l8bDVJAotIrby4/m5LEnPH5UAQajrVtb+F7/ULbT5IHjjYKlzbGIsccDHBIqpJqd3pn9j6RpNrbrLcxMzJtO2PjIOM9Mk1reKNMudX0+C0gVWQ3EbTbmx8gPP6VHBo9yPGD6jKii2jtUhhO7nOWzx9CKAM3VvFFzpMltpVxqOmxalKu97iVdkSL/ulsknnvTbPxZfroGp31wkFyLViIrmBSsUox1GSeB9asX+lajbeKpNVtdOh1CGeBYnikkVChBJyCc+taLnWI9MiaPTrV5CT5toGCjHoG6fpQA3w9eanexCe6uLK7tpUDRzWo2gf7JBJz9a3a4W3hu/C1jqF4sMcV3fSf6Lp6NvVW7dOvPXFdlZNcNYwtdhRcFQZAowAaAOJ0NdbvvGep3QvrQxW7rA48huVHzYX5uD83Wq9pc6tPrfiHUdLFvGIn2+dcIXBVVB2gAjvnn3rf8ADlhqumT6hDc2SbJpmlS4WYHdkAAFeo6VFbaFqFv4MvLAKgv7kSZ+cYy2QOfpigBW8VTjw7p9wlur6jfHy4ohwN2cFvoOtCarrml63Y2mryWdxFe5VTbxNGUYDPOWOajvvD9/bxaJcafEk82nDDQM4XfkYOGPHrVq10zUdS12PVdWhS3W2UrbWyuHwSOWLDr1/SgB+k6/cX0mtSzIi2llKUiIHJAQEk+vJNZ03iy8g8JWupyi3iluZdiyOCI4xv2hm59Peqn9ieJY9G1DSLeGGMTu7C780HeD229vTNbc1pqNlpNrYwaZBf24gEcsLyBCGxycngigBRrN5pWhT3+rNb3ATHlSW3yrNnGMAk4yTisq+1rxPptna3sxsHju5URYViYNGGYAZO75uD6CmyeENQHg86fHIpuBcLcLCW+V
cOG2Z9OKsT6XrmsanpMt7aw21latukhEoc7gODkdeccUAWLzW9Wk8QjR9NS3Li2EkksikiNs9xnkY7VSg8YX0OnTR3UUM2oi6+yxeUCqOxxzgk8DPPPatXS9Ju7bW9a1CdFDXLAQYbPyhR+XINYX/CLatbWun3UUUc95b3ck8sLSABwxIGG9gRQA3UYtcu/Eeiadq0lrKhlM++3QpjCn5SCTnkda0oNf1TUtduraxlsUS1kCvaSg+c68ZYHPA59KfaaTrFx4nfWL5YogLYpbxBtwjbPQ461DPo+qatrVjdXOnw2L2rkvdRyhjMPTA6D65oA7AZwMjB9KWiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKo6vqlvoulz39ycRQqWPv7VerG1nRH1m6svNnAsoJPMkh2/wCsIxjJ9KAPPlvdNdtJ1E6paXGqzagkkqpICyqSBtx6DH612Oo393qniJNCsrl7V
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import HTML, display\n",
"\n",
"\n",
"def plt_img_base64(img_base64):\n",
" # Create an HTML img tag with the base64 string as the source\n",
" image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n",
"\n",
" # Display the image by rendering the HTML\n",
" display(HTML(image_html))\n",
"\n",
"\n",
"plt_img_base64(img_base64_list[1])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "ffd987d8-c182-4b06-b15e-8a305ec82fcb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"This is a dual-axis line chart that illustrates two different datasets over a 30-year period from 1993 to 2022. On the left vertical axis, we have the number of fires, measured in thousands, which goes from 0 to 120. On the right vertical axis, there's the number of acres burned, measured in millions, which ranges from 0 to 12.\\n\\nThe chart shows two lines: one for the number of fires and one for the acres burned. The line representing fires is a solid red line, and it fluctuates over the years, with some peaks and troughs. It appears to have a slight overall downward trend over the 30-year span. The line for acres burned is represented by a gray area chart with a darker outline, and it also shows variability. Notably, the acres burned appear to have an overall upward trend, with some years having significantly higher values than others.\\n\\nThe data points are marked annually, with each year labeled at the bottom of the chart from 1993 to 2022. The chart seems to indicate that while the number of fires has not consistently increased, the number of acres burned per year has generally been on the rise, suggesting that fires may have become larger or more severe over time.\""
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"image_summaries[1]"
]
},
{
"cell_type": "markdown",
"id": "67b030d4-2ac5-41b6-9245-fc3ba5771d87",
"metadata": {},
"source": [
"### Add to vectorstore\n",
"\n",
"Add raw docs and doc summaries to [Multi Vector Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary)."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9d8d7a34-69e0-49a2-b9f7-1a4e7b26d78f",
"metadata": {},
"outputs": [],
"source": [
"import uuid\n",
"\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
"from langchain.schema.document import Document\n",
"from langchain.storage import InMemoryStore\n",
"from langchain.vectorstores import Chroma\n",
"\n",
"# The vectorstore to use to index the child chunks\n",
"vectorstore = Chroma(\n",
" collection_name=\"multi_modal_rag\", embedding_function=OpenAIEmbeddings()\n",
")\n",
"\n",
"# The storage layer for the parent documents\n",
"store = InMemoryStore()\n",
"id_key = \"doc_id\"\n",
"\n",
"# The retriever (empty to start)\n",
"retriever = MultiVectorRetriever(\n",
" vectorstore=vectorstore,\n",
" docstore=store,\n",
" id_key=id_key,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b40d6eb6-c256-47a2-858c-c102557e83ce",
"metadata": {},
"source": [
"* Store the raw texts and tables in the `docstore`.\n",
"* Store the table summaries in the `vectorstore` for semantic retrieval.\n",
"* Use the tables in answer synthesis."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "1792e683",
"metadata": {},
"outputs": [],
"source": [
"# Add texts\n",
"doc_ids = [str(uuid.uuid4()) for _ in texts]\n",
"summary_texts = [\n",
" Document(page_content=s, metadata={id_key: doc_ids[i]})\n",
" for i, s in enumerate(text_summaries)\n",
"]\n",
"retriever.vectorstore.add_documents(summary_texts)\n",
"retriever.docstore.mset(list(zip(doc_ids, texts)))\n",
"\n",
"# Add tables\n",
"table_ids = [str(uuid.uuid4()) for _ in tables]\n",
"summary_tables = [\n",
" Document(page_content=s, metadata={id_key: table_ids[i]})\n",
" for i, s in enumerate(table_summaries)\n",
"]\n",
"retriever.vectorstore.add_documents(summary_tables)\n",
"retriever.docstore.mset(list(zip(table_ids, tables)))"
]
},
{
"cell_type": "markdown",
"id": "6d667e5c-5385-48c4-b878-51dcc03cc4d0",
"metadata": {},
"source": [
"* Store the images in the `docstore`.\n",
"* Store the image summaries in the `vectorstore` for semantic retrieval.\n",
"* Using the image in answer synthesis with a multimodal LLM."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "24a0a289-b970-49fe-b04f-5d857a4c159b",
"metadata": {},
"outputs": [],
"source": [
"# Add image summaries\n",
"img_ids = [str(uuid.uuid4()) for _ in img_base64_list]\n",
"summary_img = [\n",
" Document(page_content=s, metadata={id_key: img_ids[i]})\n",
" for i, s in enumerate(image_summaries)\n",
"]\n",
"retriever.vectorstore.add_documents(summary_img)\n",
"retriever.docstore.mset(list(zip(img_ids, img_base64_list)))"
]
},
{
"cell_type": "markdown",
"id": "4b45fb81-46b1-426e-aa2c-01aed4eac700",
"metadata": {},
"source": [
"### Check retrieval\n",
"\n",
"The mult-vector retriever will return base64-encoded images or text documents.\n",
"\n",
"Confirm we can get images back based on natural language search.\n",
"\n",
"Here is our retrieval of that table from the natural language query:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1bea75fe-85af-4955-a80c-6e0b44a8e215",
"metadata": {},
"outputs": [],
"source": [
"# Retrieve\n",
"docs = retriever.get_relevant_documents(\n",
" \"What is the change in wild fires from 1993 to 2022?\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "174ed58c-101a-401f-bd9b-7bf7bafe2827",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAFiApQDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD1Pw74d0i78PWU89hDJK8eWYg5JzWp/wAIroX/AEDIPyNHhX/kVtO/65D+Zq/fahZ6bb/aL65it4c43ysFGfqaAKH/AAiuhf8AQMg/I0f8IroX/QMg/I1PY6/pGpO6WOpWtwyDcwilDYHqcVbtrqC8gWe2mSaJujocg0AZv/CK6F/0DIPyNH/CK6F/0DIPyNbFFAGP/wAIroX/AEDIPyNH/CK6F/0DIPyNbFFAGP8A8IroX/QMg/I0f8IroX/QMg/I1sVRj1e0l1eXS0ZjcxRiRxt4AJIHP4UAVf8AhFdC/wCgZB+Ro/4RXQv+gZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf/AIRXQv8AoGQfkaP+EV0L/oGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/wCEV0L/AKBkH5Gj/hFdC/6BkH5GrepapbaVDHLdFgryLGu1ckkkAfzq6ORmgDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf8A4RXQv+gZB+Ro/wCEV0L/AKBkH5GtiigDH/4RXQv+gZB+Ro/4RXQv+gZB+RrYooAx/wDhFdC/6BkH5Gj/AIRXQv8AoGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf8A4RXQv+gZB+Ro/wCEV0L/AKBkH5GtiigDH/4RXQv+gZB+Ro/4RXQv+gZB+RrYooAx/wDhFdC/6BkH5Gj/AIRXQv8AoGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf8A4RXQv+gZB+Ro/wCEV0L/AKBkH5GtiigDH/4RXQv+gZB+Ro/4RXQv+gZB+RrYooAx/wDhFdC/6BkH5Gj/AIRXQv8AoGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+Ron8TaZb6/DokkzfbpV3KgXIx7mtigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfka2KKAMf8A4RXQv+gZB+Ro/wCEV0L/AKBkH5GrOtapHoukz38qF1iXOwHBY+lWbWc3NpFOUKeYgbaTnGaAM3/hFdC/6BkH5Gj/AIRXQv8AoGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDH/AOEV0L/oGQfkaP8AhFdC/wCgZB+RrYooAx/+EV0L/oGQfkaP+EV0L/oGQfkalutYS21uy0wRF5LlWbcGxsABOcd+ladAGP8A8IroX/QMg/I0f8IroX/QMg/I1sUUAY//AAiuhf8AQMg/I0f8IroX/QMg/I1au9XtLK/tLKZm8+6JEQVc5wCTn06VljxZEbfV5/sreXpz7Cd/+sOAeOOOuKALX/CK6F/0DIPyNH/CK6F/0DIPyNKniG1A0xJ1eOfUFBijA3Y+XdyfoK16AMf/AIRXQv8AoGQfkaP+EV0L/oGQfka2KKAMf/hFdC/6BkH5Gj/hFdC/6BkH5GtiigDxzxfaQWHiOe3tYlihVUIRegyooqfx1/yNlz/up/6CKKAPRPCv/Irad/1yH8zWF42ukGraHbNDNOonM0kUMZkYhcH7o61u+Ff+RW07/rkP5mq50i7l8bDVJAotIrby4/m5LEnPH5UAQajrVtb+F7/ULbT5IHjjYKlzbGIsccDHBIqpJqd3pn9j6RpNrbrLcxMzJtO2PjIOM9Mk1reKNMudX0+C0gVWQ3EbTbmx8gPP6VHBo9yPGD6jKii2jtUhhO7nOWzx9CKAM3VvFFzpMltpVxqOmxalKu97iVdkSL/ulsknnvTbPxZfroGp31wkFyLViIrmBSsUox1GSeB9asX+lajbeKpNVtdOh1CGeBYnikkVChBJyCc+taLnWI9MiaPTrV5CT5toGCjHoG6fpQA3w9eanexCe6uLK7tpUDRzWo2gf7JBJz9a3a4W3hu/C1jqF4sMcV3fSf6Lp6NvVW7dOvPXFdlZNcNYwtdhRcFQZAowAaAOJ0NdbvvGep3QvrQxW7rA48huVHzYX5uD83Wq9pc6tPrfiHUdLFvGIn2+dcIXBVVB2gAjvnn3rf8ADlhqumT6hDc2SbJpmlS4WYHdkAAFeo6VFbaFqFv4MvLAKgv7kSZ+cYy2QOfpigBW8VTjw7p9wlur6jfHy4ohwN2cFvoOtCarrml63Y2mryWdxFe5VTbxNGUYDPOWOajvvD9/bxaJcafEk82nDDQM4XfkYOGPHrVq10zUdS12PVdWhS3W2UrbWyuHwSOWLDr1/SgB+k6/cX0mtSzIi2llKUiIHJAQEk+vJNZ03iy8g8JWupyi3iluZdiyOCI4xv2hm59Peqn9ieJY9G1DSLeGGMTu7C780HeD229vTNbc1pqNlpNrYwaZBf24gEcsLyBCGxycngigBRrN5pWhT3+rNb3ATHlSW3yrNnGMAk4yTisq+1rxPptna3sxsHju5URYViYNGGYAZO75uD6CmyeENQHg86fHIpuBcLcLCW+V
cOG2Z9OKsT6XrmsanpMt7aw21latukhEoc7gODkdeccUAWLzW9Wk8QjR9NS3Li2EkksikiNs9xnkY7VSg8YX0OnTR3UUM2oi6+yxeUCqOxxzgk8DPPPatXS9Ju7bW9a1CdFDXLAQYbPyhR+XINYX/CLatbWun3UUUc95b3ck8sLSABwxIGG9gRQA3UYtcu/Eeiadq0lrKhlM++3QpjCn5SCTnkda0oNf1TUtduraxlsUS1kCvaSg+c68ZYHPA59KfaaTrFx4nfWL5YogLYpbxBtwjbPQ461DPo+qatrVjdXOnw2L2rkvdRyhjMPTA6D65oA7AZwMjB9KWiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKo6vqlvoulz39ycRQqWPv7VerG1nRH1m6svNnAsoJPMkh2/wCsIxjJ9KAPPlvdNdtJ1E6paXGqzagkkqpICyqSBtx6DH612Oo393qniJNCsrl7V
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from base64 import b64decode\n",
"\n",
"\n",
"def split_image_text_types(docs):\n",
" \"\"\"Split base64-encoded images and texts\"\"\"\n",
" b64 = []\n",
" text = []\n",
" for doc in docs:\n",
" try:\n",
" b64decode(doc)\n",
" b64.append(doc)\n",
" except Exception as e:\n",
" text.append(doc)\n",
" return {\"images\": b64, \"texts\": text}\n",
"\n",
"\n",
"docs_by_type = split_image_text_types(docs)\n",
"plt_img_base64(docs_by_type[\"images\"][0])"
]
},
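{
"cell_type": "markdown",
"id": "b64-check-note",
"metadata": {},
"source": [
"Note that the `try` / `except` around `b64decode` is a loose heuristic: some short plain-text strings also decode as valid base64. If you see misclassifications, a stricter check is to validate the decode and inspect the image magic bytes. A sketch (the helper name `looks_like_base64_image` is ours, not part of the notebook above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b64-check-code",
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"\n",
"\n",
"def looks_like_base64_image(s: str) -> bool:\n",
"    \"\"\"Return True if the string decodes as base64 to a known image format.\"\"\"\n",
"    try:\n",
"        data = base64.b64decode(s, validate=True)\n",
"    except Exception:\n",
"        return False\n",
"    # Check magic bytes for JPEG, PNG, and GIF\n",
"    return (\n",
"        data[:3] == b\"\\xff\\xd8\\xff\"\n",
"        or data[:8] == b\"\\x89PNG\\r\\n\\x1a\\n\"\n",
"        or data[:6] in (b\"GIF87a\", b\"GIF89a\")\n",
"    )\n",
"\n",
"\n",
"# How many of the retrieved docs look like images under the stricter check?\n",
"sum(looks_like_base64_image(d) for d in docs)"
]
},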
{
"cell_type": "markdown",
"id": "69060724-e390-4dda-8250-5f86025c874a",
"metadata": {},
"source": [
"## RAG\n",
"\n",
"Currently, we format the inputs using a `RunnableLambda` while we add image support to `ChatPromptTemplates`.\n",
"\n",
"Our runnable follows the classic RAG flow - \n",
"\n",
"* We first compute the context (both \"texts\" and \"images\" in this case) and the question (just a RunnablePassthrough here) \n",
"* Then we pass this into our prompt template, which is a custom function that formats the message for the gpt-4-vision-preview model. \n",
"* And finally we parse the output as a string."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "771a47fa-1267-4db8-a6ae-5fde48bbc069",
"metadata": {},
"outputs": [],
"source": [
"from operator import itemgetter\n",
"\n",
"from langchain.schema.runnable import RunnableLambda, RunnablePassthrough\n",
"\n",
"\n",
"def prompt_func(dict):\n",
" format_texts = \"\\n\".join(dict[\"context\"][\"texts\"])\n",
" return [\n",
" HumanMessage(\n",
" content=[\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": f\"\"\"Answer the question based only on the following context, which can include text, tables, and the below image:\n",
"Question: {dict[\"question\"]}\n",
"\n",
"Text and tables:\n",
"{format_texts}\n",
"\"\"\",\n",
" },\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{dict['context']['images'][0]}\"\n",
" },\n",
" },\n",
" ]\n",
" )\n",
" ]\n",
"\n",
"\n",
"model = ChatOpenAI(temperature=0, model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
"# RAG pipeline\n",
"chain = (\n",
" {\n",
" \"context\": retriever | RunnableLambda(split_image_text_types),\n",
" \"question\": RunnablePassthrough(),\n",
" }\n",
" | RunnableLambda(prompt_func)\n",
" | model\n",
" | StrOutputParser()\n",
")"
]
},
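{
"cell_type": "markdown",
"id": "multi-image-note",
"metadata": {},
"source": [
"`prompt_func` above only attaches the first retrieved image. If you want the model to see every retrieved image, you could build one `image_url` part per image and swap the function into the chain. A minimal sketch (the name `multi_image_prompt_func` is ours, and the prompt wording differs slightly from the original):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "multi-image-code",
"metadata": {},
"outputs": [],
"source": [
"def multi_image_prompt_func(data_dict):\n",
"    \"\"\"Like prompt_func, but attaches every retrieved image instead of only the first.\"\"\"\n",
"    format_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n",
"    content = [\n",
"        {\n",
"            \"type\": \"text\",\n",
"            \"text\": (\n",
"                \"Answer the question based only on the following context, \"\n",
"                \"which can include text, tables, and images:\\n\"\n",
"                f\"Question: {data_dict['question']}\\n\\n\"\n",
"                f\"Text and tables:\\n{format_texts}\\n\"\n",
"            ),\n",
"        }\n",
"    ]\n",
"    # One image_url part per retrieved image\n",
"    for img in data_dict[\"context\"][\"images\"]:\n",
"        content.append(\n",
"            {\"type\": \"image_url\", \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img}\"}}\n",
"        )\n",
"    return [HumanMessage(content=content)]"
]
},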
{
"cell_type": "code",
"execution_count": 20,
"id": "ea8414a8-65ee-4e11-8154-029b454f46af",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: api.smith.langchain.com. Connection pool size: 10\n"
]
},
{
"data": {
"text/plain": [
"'Based on the provided image, which shows a graph of annual wildfires and acres burned from 1993 to 2022, there is a noticeable fluctuation in both the number of fires and the acres burned over the years. In 1993, the number of fires was around 80,000, and the acres burned were approximately 2 million. By 2022, the number of fires appears to have decreased slightly to just under 80,000, while the acres burned have increased significantly to around 10 million.\\n\\nTherefore, from 1993 to 2022, the change in wildfires shows a slight decrease in the number of fires but a substantial increase in the total acres burned.'"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke(\"What is the change in wild fires from 1993 to 2022?\")"
]
},
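{
"cell_type": "markdown",
"id": "rag-sources-note",
"metadata": {},
"source": [
"To see which retrieved documents actually fed an answer, you can run retrieval and generation separately and keep both. A minimal sketch that reuses the pieces defined above (not part of the original notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "rag-sources-code",
"metadata": {},
"outputs": [],
"source": [
"# Run retrieval and generation separately so the retrieved context stays inspectable\n",
"question = \"What is the change in wild fires from 1993 to 2022?\"\n",
"context = split_image_text_types(retriever.get_relevant_documents(question))\n",
"answer = (RunnableLambda(prompt_func) | model | StrOutputParser()).invoke(\n",
"    {\"context\": context, \"question\": question}\n",
")\n",
"print(f\"Used {len(context['images'])} image(s) and {len(context['texts'])} text chunk(s)\")\n",
"answer"
]
},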
{
"cell_type": "markdown",
"id": "dea241f1-bd11-45cb-bb33-c4e2e8286855",
"metadata": {},
"source": [
"Here is the [trace](https://smith.langchain.com/public/db0441a8-2c17-4070-bdf7-45d4fdf8f517/r/80cb0f89-1766-4caf-8959-fc43ec4b071c). "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}