mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
700 lines
182 KiB
Plaintext
700 lines
182 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "SzvBjdID1V3m",
|
||
|
"metadata": {
|
||
|
"id": "SzvBjdID1V3m"
|
||
|
},
|
||
|
"source": [
|
||
|
"# Multi-modal RAG with Google Cloud"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "4tfidrmE1Zlo",
|
||
|
"metadata": {
|
||
|
"id": "4tfidrmE1Zlo"
|
||
|
},
|
||
|
"source": [
|
||
|
"This tutorial demonstrates how to implement Option 2, described [here](https://github.com/langchain-ai/langchain/blob/master/cookbook/Multi_modal_RAG.ipynb), with the Generative AI API on Google Cloud (Vertex AI)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "84fcd59f-2eaf-4a76-ad1a-96d6db70bf42",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Setup\n",
|
||
|
"\n",
|
||
|
"Install the required dependencies, and create an API key for your Google service."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "6b1e10dd-25de-4c0a-9577-f36e72518f89",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"%pip install -U --quiet langchain langchain_community openai chromadb langchain-experimental\n",
|
||
|
"%pip install --quiet \"unstructured[all-docs]\" pypdf pillow pydantic lxml pillow matplotlib chromadb tiktoken"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "pSInKtCZ32mt",
|
||
|
"metadata": {
|
||
|
"id": "pSInKtCZ32mt"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Data loading"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "Iv2R8-lJ37dG",
|
||
|
"metadata": {
|
||
|
"id": "Iv2R8-lJ37dG"
|
||
|
},
|
||
|
"source": [
|
||
|
"We use a zip file with a subset of the extracted images and the PDF from [this](https://cloudedjudgement.substack.com/p/clouded-judgement-111023) blog post. If you want to follow the full flow, please use the original [example](https://github.com/langchain-ai/langchain/blob/master/cookbook/Multi_modal_RAG.ipynb)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "d999f3fe-c165-4772-b63e-ffe4dd5b03cf",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# First download\n",
|
||
|
"import logging\n",
|
||
|
"import zipfile\n",
|
||
|
"\n",
|
||
|
"import requests\n",
|
||
|
"\n",
|
||
|
"logging.basicConfig(level=logging.INFO)\n",
|
||
|
"\n",
|
||
|
"data_url = \"https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/cj.zip\"\n",
|
||
|
"result = requests.get(data_url)\n",
|
||
|
"filename = \"cj.zip\"\n",
|
||
|
"with open(filename, \"wb\") as file:\n",
|
||
|
" file.write(result.content)\n",
|
||
|
"\n",
|
||
|
"with zipfile.ZipFile(filename, \"r\") as zip_ref:\n",
|
||
|
" zip_ref.extractall()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "eGUfuevMUA6R",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from langchain.document_loaders import PyPDFLoader\n",
|
||
|
"\n",
|
||
|
"loader = PyPDFLoader(\"./cj/cj.pdf\")\n",
|
||
|
"docs = loader.load()\n",
|
||
|
"tables = []\n",
|
||
|
"texts = [d.page_content for d in docs]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "Fst17fNHWYcq",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"21"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"len(texts)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "vjfcg_Vn3_1C",
|
||
|
"metadata": {
|
||
|
"id": "vjfcg_Vn3_1C"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Multi-vector retriever"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "1ynRqJn04BFG",
|
||
|
"metadata": {
|
||
|
"id": "1ynRqJn04BFG"
|
||
|
},
|
||
|
"source": [
|
||
|
"Let's generate text and image summaries and save them to a ChromaDB vectorstore."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "kWDWfSDBMPl8",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"INFO:numexpr.utils:Note: NumExpr detected 12 cores but \"NUMEXPR_MAX_THREADS\" not set, so enforcing safe limit of 8.\n",
|
||
|
"INFO:numexpr.utils:NumExpr defaulting to 8 threads.\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from langchain.chat_models import ChatVertexAI\n",
|
||
|
"from langchain.llms import VertexAI\n",
|
||
|
"from langchain.prompts import PromptTemplate\n",
|
||
|
"from langchain.schema.output_parser import StrOutputParser\n",
|
||
|
"from langchain_core.messages import AIMessage\n",
|
||
|
"from langchain_core.runnables import RunnableLambda\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# Generate summaries of text elements\n",
|
||
|
"def generate_text_summaries(texts, tables, summarize_texts=False):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Summarize text elements\n",
|
||
|
" texts: List of str\n",
|
||
|
" tables: List of str\n",
|
||
|
" summarize_texts: Bool to summarize texts\n",
|
||
|
" \"\"\"\n",
|
||
|
"\n",
|
||
|
" # Prompt\n",
|
||
|
" prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\\n",
|
||
|
" These summaries will be embedded and used to retrieve the raw text or table elements. \\\n",
|
||
|
" Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} \"\"\"\n",
|
||
|
" prompt = PromptTemplate.from_template(prompt_text)\n",
|
||
|
" empty_response = RunnableLambda(\n",
|
||
|
" lambda x: AIMessage(content=\"Error processing document\")\n",
|
||
|
" )\n",
|
||
|
" # Text summary chain\n",
|
||
|
" model = VertexAI(\n",
|
||
|
" temperature=0, model_name=\"gemini-pro\", max_output_tokens=1024\n",
|
||
|
" ).with_fallbacks([empty_response])\n",
|
||
|
" summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n",
|
||
|
"\n",
|
||
|
" # Initialize empty summaries\n",
|
||
|
" text_summaries = []\n",
|
||
|
" table_summaries = []\n",
|
||
|
"\n",
|
||
|
" # Apply to text if texts are provided and summarization is requested\n",
|
||
|
" if texts and summarize_texts:\n",
|
||
|
" text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 1})\n",
|
||
|
" elif texts:\n",
|
||
|
" text_summaries = texts\n",
|
||
|
"\n",
|
||
|
" # Apply to tables if tables are provided\n",
|
||
|
" if tables:\n",
|
||
|
" table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 1})\n",
|
||
|
"\n",
|
||
|
" return text_summaries, table_summaries\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# Get text, table summaries\n",
|
||
|
"text_summaries, table_summaries = generate_text_summaries(\n",
|
||
|
" texts, tables, summarize_texts=True\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "F0NnyUl48yYb",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"21"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"len(text_summaries)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "PeK9bzXv3olF",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import base64\n",
|
||
|
"import os\n",
|
||
|
"\n",
|
||
|
"from langchain.schema.messages import HumanMessage\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def encode_image(image_path):\n",
|
||
|
" \"\"\"Getting the base64 string\"\"\"\n",
|
||
|
" with open(image_path, \"rb\") as image_file:\n",
|
||
|
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def image_summarize(img_base64, prompt):\n",
|
||
|
" \"\"\"Make image summary\"\"\"\n",
|
||
|
" model = ChatVertexAI(model_name=\"gemini-pro-vision\", max_output_tokens=1024)\n",
|
||
|
"\n",
|
||
|
" msg = model(\n",
|
||
|
" [\n",
|
||
|
" HumanMessage(\n",
|
||
|
" content=[\n",
|
||
|
" {\"type\": \"text\", \"text\": prompt},\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"image_url\",\n",
|
||
|
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
|
||
|
" },\n",
|
||
|
" ]\n",
|
||
|
" )\n",
|
||
|
" ]\n",
|
||
|
" )\n",
|
||
|
" return msg.content\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def generate_img_summaries(path):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Generate summaries and base64 encoded strings for images\n",
|
||
|
" path: Path to list of .jpg files extracted by Unstructured\n",
|
||
|
" \"\"\"\n",
|
||
|
"\n",
|
||
|
" # Store base64 encoded images\n",
|
||
|
" img_base64_list = []\n",
|
||
|
"\n",
|
||
|
" # Store image summaries\n",
|
||
|
" image_summaries = []\n",
|
||
|
"\n",
|
||
|
" # Prompt\n",
|
||
|
" prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n",
|
||
|
" These summaries will be embedded and used to retrieve the raw image. \\\n",
|
||
|
" Give a concise summary of the image that is well optimized for retrieval.\"\"\"\n",
|
||
|
"\n",
|
||
|
" # Apply to images\n",
|
||
|
" for img_file in sorted(os.listdir(path)):\n",
|
||
|
" if img_file.endswith(\".jpg\"):\n",
|
||
|
" img_path = os.path.join(path, img_file)\n",
|
||
|
" base64_image = encode_image(img_path)\n",
|
||
|
" img_base64_list.append(base64_image)\n",
|
||
|
" image_summaries.append(image_summarize(base64_image, prompt))\n",
|
||
|
"\n",
|
||
|
" return img_base64_list, image_summaries\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# Image summaries\n",
|
||
|
"img_base64_list, image_summaries = generate_img_summaries(\"./cj\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "6WDYpDFzjocl",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"5"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"len(image_summaries)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "cWyWfZ-XB6cS",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"INFO:chromadb.telemetry.product.posthog:Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import uuid\n",
|
||
|
"\n",
|
||
|
"from langchain.embeddings import VertexAIEmbeddings\n",
|
||
|
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
|
||
|
"from langchain.schema.document import Document\n",
|
||
|
"from langchain.storage import InMemoryStore\n",
|
||
|
"from langchain.vectorstores import Chroma\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def create_multi_vector_retriever(\n",
|
||
|
" vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images\n",
|
||
|
"):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Create retriever that indexes summaries, but returns raw images or texts\n",
|
||
|
" \"\"\"\n",
|
||
|
"\n",
|
||
|
" # Initialize the storage layer\n",
|
||
|
" store = InMemoryStore()\n",
|
||
|
" id_key = \"doc_id\"\n",
|
||
|
"\n",
|
||
|
" # Create the multi-vector retriever\n",
|
||
|
" retriever = MultiVectorRetriever(\n",
|
||
|
" vectorstore=vectorstore,\n",
|
||
|
" docstore=store,\n",
|
||
|
" id_key=id_key,\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" # Helper function to add documents to the vectorstore and docstore\n",
|
||
|
" def add_documents(retriever, doc_summaries, doc_contents):\n",
|
||
|
" doc_ids = [str(uuid.uuid4()) for _ in doc_contents]\n",
|
||
|
" summary_docs = [\n",
|
||
|
" Document(page_content=s, metadata={id_key: doc_ids[i]})\n",
|
||
|
" for i, s in enumerate(doc_summaries)\n",
|
||
|
" ]\n",
|
||
|
" retriever.vectorstore.add_documents(summary_docs)\n",
|
||
|
" retriever.docstore.mset(list(zip(doc_ids, doc_contents)))\n",
|
||
|
"\n",
|
||
|
" # Add texts, tables, and images\n",
|
||
|
" # Check that text_summaries is not empty before adding\n",
|
||
|
" if text_summaries:\n",
|
||
|
" add_documents(retriever, text_summaries, texts)\n",
|
||
|
" # Check that table_summaries is not empty before adding\n",
|
||
|
" if table_summaries:\n",
|
||
|
" add_documents(retriever, table_summaries, tables)\n",
|
||
|
" # Check that image_summaries is not empty before adding\n",
|
||
|
" if image_summaries:\n",
|
||
|
" add_documents(retriever, image_summaries, images)\n",
|
||
|
"\n",
|
||
|
" return retriever\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# The vectorstore to use to index the summaries\n",
|
||
|
"vectorstore = Chroma(\n",
|
||
|
" collection_name=\"mm_rag_cj_blog\",\n",
|
||
|
" embedding_function=VertexAIEmbeddings(model_name=\"textembedding-gecko@latest\"),\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"# Create retriever\n",
|
||
|
"retriever_multi_vector_img = create_multi_vector_retriever(\n",
|
||
|
" vectorstore,\n",
|
||
|
" text_summaries,\n",
|
||
|
" texts,\n",
|
||
|
" table_summaries,\n",
|
||
|
" tables,\n",
|
||
|
" image_summaries,\n",
|
||
|
" img_base64_list,\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "NGDkkMFfCg4j",
|
||
|
"metadata": {
|
||
|
"id": "NGDkkMFfCg4j"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Building a RAG"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "8TzOcHVsCmBc",
|
||
|
"metadata": {
|
||
|
"id": "8TzOcHVsCmBc"
|
||
|
},
|
||
|
"source": [
|
||
|
"Let's build the RAG chain on top of the retriever:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"id": "GlwCErBaCKQW",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import io\n",
|
||
|
"import re\n",
|
||
|
"\n",
|
||
|
"from IPython.display import HTML, display\n",
|
||
|
"from langchain.schema.runnable import RunnableLambda, RunnablePassthrough\n",
|
||
|
"from PIL import Image\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def plt_img_base64(img_base64):\n",
|
||
|
" \"\"\"Display base64 encoded string as image\"\"\"\n",
|
||
|
" # Create an HTML img tag with the base64 string as the source\n",
|
||
|
" image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n",
|
||
|
" # Display the image by rendering the HTML\n",
|
||
|
" display(HTML(image_html))\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def looks_like_base64(sb):\n",
|
||
|
" \"\"\"Check if the string looks like base64\"\"\"\n",
|
||
|
" return re.match(\"^[A-Za-z0-9+/]+[=]{0,2}$\", sb) is not None\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def is_image_data(b64data):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Check if the base64 data is an image by looking at the start of the data\n",
|
||
|
" \"\"\"\n",
|
||
|
" image_signatures = {\n",
|
||
|
" b\"\\xFF\\xD8\\xFF\": \"jpg\",\n",
|
||
|
" b\"\\x89\\x50\\x4E\\x47\\x0D\\x0A\\x1A\\x0A\": \"png\",\n",
|
||
|
" b\"\\x47\\x49\\x46\\x38\": \"gif\",\n",
|
||
|
" b\"\\x52\\x49\\x46\\x46\": \"webp\",\n",
|
||
|
" }\n",
|
||
|
" try:\n",
|
||
|
" header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes\n",
|
||
|
" for sig, format in image_signatures.items():\n",
|
||
|
" if header.startswith(sig):\n",
|
||
|
" return True\n",
|
||
|
" return False\n",
|
||
|
" except Exception:\n",
|
||
|
" return False\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def resize_base64_image(base64_string, size=(128, 128)):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Resize an image encoded as a Base64 string\n",
|
||
|
" \"\"\"\n",
|
||
|
" # Decode the Base64 string\n",
|
||
|
" img_data = base64.b64decode(base64_string)\n",
|
||
|
" img = Image.open(io.BytesIO(img_data))\n",
|
||
|
"\n",
|
||
|
" # Resize the image\n",
|
||
|
" resized_img = img.resize(size, Image.LANCZOS)\n",
|
||
|
"\n",
|
||
|
" # Save the resized image to a bytes buffer\n",
|
||
|
" buffered = io.BytesIO()\n",
|
||
|
" resized_img.save(buffered, format=img.format)\n",
|
||
|
"\n",
|
||
|
" # Encode the resized image to Base64\n",
|
||
|
" return base64.b64encode(buffered.getvalue()).decode(\"utf-8\")\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def split_image_text_types(docs):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Split base64-encoded images and texts\n",
|
||
|
" \"\"\"\n",
|
||
|
" b64_images = []\n",
|
||
|
" texts = []\n",
|
||
|
" for doc in docs:\n",
|
||
|
" # Check if the document is of type Document and extract page_content if so\n",
|
||
|
" if isinstance(doc, Document):\n",
|
||
|
" doc = doc.page_content\n",
|
||
|
" if looks_like_base64(doc) and is_image_data(doc):\n",
|
||
|
" doc = resize_base64_image(doc, size=(1300, 600))\n",
|
||
|
" b64_images.append(doc)\n",
|
||
|
" else:\n",
|
||
|
" texts.append(doc)\n",
|
||
|
" if len(b64_images) > 0:\n",
|
||
|
" return {\"images\": b64_images[:1], \"texts\": []}\n",
|
||
|
" return {\"images\": b64_images, \"texts\": texts}\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def img_prompt_func(data_dict):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Join the context into a single string\n",
|
||
|
" \"\"\"\n",
|
||
|
" formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n",
|
||
|
" messages = []\n",
|
||
|
"\n",
|
||
|
" # Adding the text for analysis\n",
|
||
|
" text_message = {\n",
|
||
|
" \"type\": \"text\",\n",
|
||
|
" \"text\": (\n",
|
||
|
" \"You are financial analyst tasking with providing investment advice.\\n\"\n",
|
||
|
" \"You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\\n\"\n",
|
||
|
" \"Use this information to provide investment advice related to the user question. \\n\"\n",
|
||
|
" f\"User-provided question: {data_dict['question']}\\n\\n\"\n",
|
||
|
" \"Text and / or tables:\\n\"\n",
|
||
|
" f\"{formatted_texts}\"\n",
|
||
|
" ),\n",
|
||
|
" }\n",
|
||
|
" messages.append(text_message)\n",
|
||
|
" # Adding image(s) to the messages if present\n",
|
||
|
" if data_dict[\"context\"][\"images\"]:\n",
|
||
|
" for image in data_dict[\"context\"][\"images\"]:\n",
|
||
|
" image_message = {\n",
|
||
|
" \"type\": \"image_url\",\n",
|
||
|
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{image}\"},\n",
|
||
|
" }\n",
|
||
|
" messages.append(image_message)\n",
|
||
|
" return [HumanMessage(content=messages)]\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def multi_modal_rag_chain(retriever):\n",
|
||
|
" \"\"\"\n",
|
||
|
" Multi-modal RAG chain\n",
|
||
|
" \"\"\"\n",
|
||
|
"\n",
|
||
|
" # Multi-modal LLM\n",
|
||
|
" model = ChatVertexAI(\n",
|
||
|
" temperature=0, model_name=\"gemini-pro-vision\", max_output_tokens=1024\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" # RAG pipeline\n",
|
||
|
" chain = (\n",
|
||
|
" {\n",
|
||
|
" \"context\": retriever | RunnableLambda(split_image_text_types),\n",
|
||
|
" \"question\": RunnablePassthrough(),\n",
|
||
|
" }\n",
|
||
|
" | RunnableLambda(img_prompt_func)\n",
|
||
|
" | model\n",
|
||
|
" | StrOutputParser()\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" return chain\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# Create RAG chain\n",
|
||
|
"chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "BS4hNKqCCp8u",
|
||
|
"metadata": {
|
||
|
"id": "BS4hNKqCCp8u"
|
||
|
},
|
||
|
"source": [
|
||
|
"Let's check that we get images as documents:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "Q7GrwFC_FGwr",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"4"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"query = \"What are the EV / NTM and NTM rev growth for MongoDB, Cloudflare, and Datadog?\"\n",
|
||
|
"docs = retriever_multi_vector_img.get_relevant_documents(query, limit=1)\n",
|
||
|
"\n",
|
||
|
"# We get 4 docs\n",
|
||
|
"len(docs)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "unnxB5M_FLCD",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<img src=\"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.HTML object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plt_img_base64(docs[0])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "YUkGZXqsCtF6",
|
||
|
"metadata": {
|
||
|
"id": "YUkGZXqsCtF6"
|
||
|
},
|
||
|
"source": [
|
||
|
"And let's run our RAG on the same query:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "LsPTehdK-T-_",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"' | Company | EV / NTM Rev | NTM Rev Growth |\\n|---|---|---|\\n| MongoDB | 14.6x | 17% |\\n| Cloudflare | 13.4x | 28% |\\n| Datadog | 13.1x | 19% |'"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"chain_multimodal_rag.invoke(query)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "XpLQB6dEfQX-",
|
||
|
"metadata": {
|
||
|
"id": "XpLQB6dEfQX-"
|
||
|
},
|
||
|
"source": [
|
||
|
"As we can see, the model was able to figure out the right values that are relevant to answer the question."
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.11.2"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|