mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
834 lines
53 KiB
Plaintext
834 lines
53 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f1571abe-8e84-44d1-b222-e4121fdbb4be",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Advanced RAG Eval\n",
|
|
"\n",
|
|
"This cookbook walks through running evaluations on several advanced RAG approaches (text-only, multi-vector, and multi-modal) and comparing their answers.\n",
|
|
"\n",
|
|
"This can be very useful to determine the best RAG approach for your application."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0d8415ee-709c-407f-9ac2-f03a9d697aaf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"! pip install -U langchain openai chromadb langchain-experimental # (newest versions required for multi-modal)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "191f8465-fd6b-4017-8f0e-d284971b45ae",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# lock to 0.10.19 due to a persistent bug in more recent versions\n",
|
|
"! pip install \"unstructured[all-docs]==0.10.19\" pillow pydantic lxml pillow matplotlib tiktoken open_clip_torch torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "45949db5-d9b6-44a9-85f8-96d83a288616",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Data Loading\n",
|
|
"\n",
|
|
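"Let's look at an example PDF about the Consumer Price Index (CPI), `cpi.pdf`, that provides a mixture of tables, text, and images. (The same workflow applies to other mixed-content documents, such as this [example whitepaper](https://sgp.fas.org/crs/misc/IF10244.pdf) on wildfires in the US.)"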
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "961a42b9-c16b-472e-b994-3c3f73afbbcb",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Option 1: Load text"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "12f24fc0-c176-4201-982b-8a84b278ff1b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Path\n",
|
|
"path = \"/Users/rlm/Desktop/cpi/\"\n",
|
|
"\n",
|
|
"# Load\n",
|
|
"from langchain_community.document_loaders import PyPDFLoader\n",
|
|
"\n",
|
|
"loader = PyPDFLoader(path + \"cpi.pdf\")\n",
|
|
"pdf_pages = loader.load()\n",
|
|
"\n",
|
|
"# Split\n",
|
|
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
|
|
"\n",
|
|
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
|
|
"all_splits_pypdf = text_splitter.split_documents(pdf_pages)\n",
|
|
"all_splits_pypdf_texts = [d.page_content for d in all_splits_pypdf]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "92fc1870-1836-4bc3-945a-78e2c16ad823",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Option 2: Load text, tables, images \n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "7d863632-f894-4471-b4cc-a1d9aa834d29",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from unstructured.partition.pdf import partition_pdf\n",
|
|
"\n",
|
|
"# Extract images, tables, and chunk text\n",
|
|
"raw_pdf_elements = partition_pdf(\n",
|
|
" filename=path + \"cpi.pdf\",\n",
|
|
" extract_images_in_pdf=True,\n",
|
|
" infer_table_structure=True,\n",
|
|
" chunking_strategy=\"by_title\",\n",
|
|
" max_characters=4000,\n",
|
|
" new_after_n_chars=3800,\n",
|
|
" combine_text_under_n_chars=2000,\n",
|
|
" image_output_dir_path=path,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Categorize by type\n",
|
|
"tables = []\n",
|
|
"texts = []\n",
|
|
"for element in raw_pdf_elements:\n",
|
|
" if \"unstructured.documents.elements.Table\" in str(type(element)):\n",
|
|
" tables.append(str(element))\n",
|
|
" elif \"unstructured.documents.elements.CompositeElement\" in str(type(element)):\n",
|
|
" texts.append(str(element))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "65f399c5-bd91-4ed4-89c6-c89d2e17466e",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Store\n",
|
|
"\n",
|
|
"### Option 1: Embed, store text chunks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "7d7ecdb2-0bb5-46b8-bcff-af8fc272e88e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_community.vectorstores import Chroma\n",
|
|
"from langchain_openai import OpenAIEmbeddings\n",
|
|
"\n",
|
|
"baseline = Chroma.from_texts(\n",
|
|
" texts=all_splits_pypdf_texts,\n",
|
|
" collection_name=\"baseline\",\n",
|
|
" embedding=OpenAIEmbeddings(),\n",
|
|
")\n",
|
|
"retriever_baseline = baseline.as_retriever()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6a0eaefe-5e4b-4853-94c7-5abd6f7fbeac",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Option 2: Multi-vector retriever\n",
|
|
"\n",
|
|
"#### Text Summary"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "3d4b4b43-e96e-48ab-899d-c39d0430562e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_core.output_parsers import StrOutputParser\n",
|
|
"from langchain_core.prompts import ChatPromptTemplate\n",
|
|
"from langchain_openai import ChatOpenAI\n",
|
|
"\n",
|
|
"# Prompt\n",
|
|
"prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text for retrieval. \\\n",
|
|
"These summaries will be embedded and used to retrieve the raw text or table elements. \\\n",
|
|
"Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} \"\"\"\n",
|
|
"prompt = ChatPromptTemplate.from_template(prompt_text)\n",
|
|
"\n",
|
|
"# Text summary chain\n",
|
|
"model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
|
"summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()\n",
|
|
"\n",
|
|
"# Apply to text\n",
|
|
"text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 5})\n",
|
|
"\n",
|
|
"# Apply to tables\n",
|
|
"table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 5})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "bdb5c903-5b4c-4ddb-8f9a-e20f5155dfb9",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Image Summary"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "4570578c-531b-422c-bedd-cc519d9b7887",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Image summary chain\n",
|
|
"import base64\n",
|
|
"import io\n",
|
|
"import os\n",
|
|
"from io import BytesIO\n",
|
|
"\n",
|
|
"from langchain_core.messages import HumanMessage\n",
|
|
"from PIL import Image\n",
|
|
"\n",
|
|
"\n",
|
|
"def encode_image(image_path):\n",
|
|
" \"\"\"Getting the base64 string\"\"\"\n",
|
|
" with open(image_path, \"rb\") as image_file:\n",
|
|
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
|
|
"\n",
|
|
"\n",
|
|
"def image_summarize(img_base64, prompt):\n",
|
|
" \"\"\"Image summary\"\"\"\n",
|
|
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
|
|
"\n",
|
|
" msg = chat.invoke(\n",
|
|
" [\n",
|
|
" HumanMessage(\n",
|
|
" content=[\n",
|
|
" {\"type\": \"text\", \"text\": prompt},\n",
|
|
" {\n",
|
|
" \"type\": \"image_url\",\n",
|
|
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
|
|
" },\n",
|
|
" ]\n",
|
|
" )\n",
|
|
" ]\n",
|
|
" )\n",
|
|
" return msg.content\n",
|
|
"\n",
|
|
"\n",
|
|
"# Store base64 encoded images\n",
|
|
"img_base64_list = []\n",
|
|
"\n",
|
|
"# Store image summaries\n",
|
|
"image_summaries = []\n",
|
|
"\n",
|
|
"# Prompt\n",
|
|
"prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n",
|
|
"These summaries will be embedded and used to retrieve the raw image. \\\n",
|
|
"Give a concise summary of the image that is well optimized for retrieval.\"\"\"\n",
|
|
"\n",
|
|
"# Apply to images\n",
|
|
"for img_file in sorted(os.listdir(path)):\n",
|
|
" if img_file.endswith(\".jpg\"):\n",
|
|
" img_path = os.path.join(path, img_file)\n",
|
|
" base64_image = encode_image(img_path)\n",
|
|
" img_base64_list.append(base64_image)\n",
|
|
" image_summaries.append(image_summarize(base64_image, prompt))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "87e03f07-4c82-4743-a3c6-d0597fb55107",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Option 2a: Multi-vector retriever w/ raw images\n",
|
|
"\n",
|
|
"* Return images to LLM for answer synthesis"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "6bf8a07d-203f-4397-8b0b-a84ec4d0adab",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import uuid\n",
|
|
"from base64 import b64decode\n",
|
|
"\n",
|
|
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
|
|
"from langchain.storage import InMemoryStore\n",
|
|
"from langchain_core.documents import Document\n",
|
|
"\n",
|
|
"\n",
|
|
"def create_multi_vector_retriever(\n",
|
|
" vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images\n",
|
|
"):\n",
|
|
" # Initialize the storage layer\n",
|
|
" store = InMemoryStore()\n",
|
|
" id_key = \"doc_id\"\n",
|
|
"\n",
|
|
" # Create the multi-vector retriever\n",
|
|
" retriever = MultiVectorRetriever(\n",
|
|
" vectorstore=vectorstore,\n",
|
|
" docstore=store,\n",
|
|
" id_key=id_key,\n",
|
|
" )\n",
|
|
"\n",
|
|
" # Helper function to add documents to the vectorstore and docstore\n",
|
|
" def add_documents(retriever, doc_summaries, doc_contents):\n",
|
|
" doc_ids = [str(uuid.uuid4()) for _ in doc_contents]\n",
|
|
" summary_docs = [\n",
|
|
" Document(page_content=s, metadata={id_key: doc_ids[i]})\n",
|
|
" for i, s in enumerate(doc_summaries)\n",
|
|
" ]\n",
|
|
" retriever.vectorstore.add_documents(summary_docs)\n",
|
|
" retriever.docstore.mset(list(zip(doc_ids, doc_contents)))\n",
|
|
"\n",
|
|
" # Add texts, tables, and images\n",
|
|
" # Check that text_summaries is not empty before adding\n",
|
|
" if text_summaries:\n",
|
|
" add_documents(retriever, text_summaries, texts)\n",
|
|
" # Check that table_summaries is not empty before adding\n",
|
|
" if table_summaries:\n",
|
|
" add_documents(retriever, table_summaries, tables)\n",
|
|
" # Check that image_summaries is not empty before adding\n",
|
|
" if image_summaries:\n",
|
|
" add_documents(retriever, image_summaries, images)\n",
|
|
"\n",
|
|
" return retriever\n",
|
|
"\n",
|
|
"\n",
|
|
"# The vectorstore to use to index the summaries\n",
|
|
"multi_vector_img = Chroma(\n",
|
|
" collection_name=\"multi_vector_img\", embedding_function=OpenAIEmbeddings()\n",
|
|
")\n",
|
|
"\n",
|
|
"# Create retriever\n",
|
|
"retriever_multi_vector_img = create_multi_vector_retriever(\n",
|
|
" multi_vector_img,\n",
|
|
" text_summaries,\n",
|
|
" texts,\n",
|
|
" table_summaries,\n",
|
|
" tables,\n",
|
|
" image_summaries,\n",
|
|
" img_base64_list,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"id": "84d5b4ea-51b8-49cf-8ad1-db8f7a50e3cf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Testing on retrieval\n",
|
|
"query = \"What percentage of CPI is dedicated to Housing, and how does it compare to the combined percentage of Medical Care, Apparel, and Other Goods and Services?\"\n",
|
|
"suffix_for_images = \" Include any pie charts, graphs, or tables.\"\n",
|
|
"docs = retriever_multi_vector_img.get_relevant_documents(query + suffix_for_images)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "8db51ac6-ec0c-4c5d-a9a7-0316035e139d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<img src=\"\" />"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"from IPython.display import HTML, display\n",
|
|
"\n",
|
|
"\n",
|
|
"def plt_img_base64(img_base64):\n",
|
|
" # Create an HTML img tag with the base64 string as the source\n",
|
|
" image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n",
|
|
"\n",
|
|
" # Display the image by rendering the HTML\n",
|
|
" display(HTML(image_html))\n",
|
|
"\n",
|
|
"\n",
|
|
"plt_img_base64(docs[1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "48b268ec-db04-4107-9833-ea1615f6dbd1",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Option 2b: Multi-vector retriever w/ image summaries\n",
|
|
"\n",
|
|
"* Return text summary of images to LLM for answer synthesis"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "ae57c804-0dd1-4806-b761-a913efc4f173",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# The vectorstore to use to index the summaries\n",
|
|
"multi_vector_text = Chroma(\n",
|
|
" collection_name=\"multi_vector_text\", embedding_function=OpenAIEmbeddings()\n",
|
|
")\n",
|
|
"\n",
|
|
"# Create retriever\n",
|
|
"retriever_multi_vector_img_summary = create_multi_vector_retriever(\n",
|
|
" multi_vector_text,\n",
|
|
" text_summaries,\n",
|
|
" texts,\n",
|
|
" table_summaries,\n",
|
|
" tables,\n",
|
|
" image_summaries,\n",
|
|
" image_summaries,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "580a3d55-5025-472d-9c14-cec7a384379f",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Option 3: Multi-modal embeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "8dbed5dc-f7a3-4324-9436-1c3ebc24f9fd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_experimental.open_clip import OpenCLIPEmbeddings\n",
|
|
"\n",
|
|
"# Create chroma w/ multi-modal embeddings\n",
|
|
"multimodal_embd = Chroma(\n",
|
|
" collection_name=\"multimodal_embd\", embedding_function=OpenCLIPEmbeddings()\n",
|
|
")\n",
|
|
"\n",
|
|
"# Get image URIs\n",
|
|
"image_uris = sorted(\n",
|
|
" [\n",
|
|
" os.path.join(path, image_name)\n",
|
|
" for image_name in os.listdir(path)\n",
|
|
" if image_name.endswith(\".jpg\")\n",
|
|
" ]\n",
|
|
")\n",
|
|
"\n",
|
|
"# Add images and documents\n",
|
|
"if image_uris:\n",
|
|
" multimodal_embd.add_images(uris=image_uris)\n",
|
|
"if texts:\n",
|
|
" multimodal_embd.add_texts(texts=texts)\n",
|
|
"if tables:\n",
|
|
" multimodal_embd.add_texts(texts=tables)\n",
|
|
"\n",
|
|
"# Make retriever\n",
|
|
"retriever_multimodal_embd = multimodal_embd.as_retriever()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "647abb6c-adf3-4d29-acd2-885c4925fa12",
|
|
"metadata": {},
|
|
"source": [
|
|
"## RAG\n",
|
|
"\n",
|
|
"### Text Pipeline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "73440ca0-4330-4c16-9d9d-6f27c249ae58",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from operator import itemgetter\n",
|
|
"\n",
|
|
"from langchain_core.runnables import RunnablePassthrough\n",
|
|
"\n",
|
|
"# Prompt\n",
|
|
"template = \"\"\"Answer the question based only on the following context, which can include text and tables:\n",
|
|
"{context}\n",
|
|
"Question: {question}\n",
|
|
"\"\"\"\n",
|
|
"rag_prompt_text = ChatPromptTemplate.from_template(template)\n",
|
|
"\n",
|
|
"\n",
|
|
"# Build\n",
|
|
"def text_rag_chain(retriever):\n",
|
|
" \"\"\"RAG chain\"\"\"\n",
|
|
"\n",
|
|
" # LLM\n",
|
|
" model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
|
"\n",
|
|
" # RAG pipeline\n",
|
|
" chain = (\n",
|
|
" {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
|
|
" | rag_prompt_text\n",
|
|
" | model\n",
|
|
" | StrOutputParser()\n",
|
|
" )\n",
|
|
"\n",
|
|
" return chain"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "14b358ad-42fd-4c6d-b2c0-215dba135707",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Multi-modal Pipeline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "ae89ce84-283e-4634-8169-9ff16f152807",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import re\n",
|
|
"\n",
|
|
"from langchain_core.documents import Document\n",
|
|
"from langchain_core.runnables import RunnableLambda\n",
|
|
"\n",
|
|
"\n",
|
|
"def looks_like_base64(sb):\n",
|
|
" \"\"\"Check if the string looks like base64.\"\"\"\n",
|
|
" return re.match(\"^[A-Za-z0-9+/]+[=]{0,2}$\", sb) is not None\n",
|
|
"\n",
|
|
"\n",
|
|
"def is_image_data(b64data):\n",
|
|
" \"\"\"Check if the base64 data is an image by looking at the start of the data.\"\"\"\n",
|
|
" image_signatures = {\n",
|
|
" b\"\\xFF\\xD8\\xFF\": \"jpg\",\n",
|
|
" b\"\\x89\\x50\\x4E\\x47\\x0D\\x0A\\x1A\\x0A\": \"png\",\n",
|
|
" b\"\\x47\\x49\\x46\\x38\": \"gif\",\n",
|
|
" b\"\\x52\\x49\\x46\\x46\": \"webp\",\n",
|
|
" }\n",
|
|
" try:\n",
|
|
" header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes\n",
|
|
" for sig, format in image_signatures.items():\n",
|
|
" if header.startswith(sig):\n",
|
|
" return True\n",
|
|
" return False\n",
|
|
" except Exception:\n",
|
|
" return False\n",
|
|
"\n",
|
|
"\n",
|
|
"def split_image_text_types(docs):\n",
|
|
" \"\"\"Split base64-encoded images and texts.\"\"\"\n",
|
|
" b64_images = []\n",
|
|
" texts = []\n",
|
|
" for doc in docs:\n",
|
|
" # Check if the document is of type Document and extract page_content if so\n",
|
|
" if isinstance(doc, Document):\n",
|
|
" doc = doc.page_content\n",
|
|
" if looks_like_base64(doc) and is_image_data(doc):\n",
|
|
" b64_images.append(doc)\n",
|
|
" else:\n",
|
|
" texts.append(doc)\n",
|
|
" return {\"images\": b64_images, \"texts\": texts}\n",
|
|
"\n",
|
|
"\n",
|
|
"def img_prompt_func(data_dict):\n",
|
|
" # Joining the context texts into a single string\n",
|
|
" formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n",
|
|
" messages = []\n",
|
|
"\n",
|
|
" # Adding image(s) to the messages if present\n",
|
|
" if data_dict[\"context\"][\"images\"]:\n",
|
|
" image_message = {\n",
|
|
" \"type\": \"image_url\",\n",
|
|
" \"image_url\": {\n",
|
|
" \"url\": f\"data:image/jpeg;base64,{data_dict['context']['images'][0]}\"\n",
|
|
" },\n",
|
|
" }\n",
|
|
" messages.append(image_message)\n",
|
|
"\n",
|
|
" # Adding the text message for analysis\n",
|
|
" text_message = {\n",
|
|
" \"type\": \"text\",\n",
|
|
" \"text\": (\n",
|
|
" \"Answer the question based only on the provided context, which can include text, tables, and image(s). \"\n",
|
|
" \"If an image is provided, analyze it carefully to help answer the question.\\n\"\n",
|
|
" f\"User-provided question / keywords: {data_dict['question']}\\n\\n\"\n",
|
|
" \"Text and / or tables:\\n\"\n",
|
|
" f\"{formatted_texts}\"\n",
|
|
" ),\n",
|
|
" }\n",
|
|
" messages.append(text_message)\n",
|
|
" return [HumanMessage(content=messages)]\n",
|
|
"\n",
|
|
"\n",
|
|
"def multi_modal_rag_chain(retriever):\n",
|
|
" \"\"\"Multi-modal RAG chain\"\"\"\n",
|
|
"\n",
|
|
" # Multi-modal LLM\n",
|
|
" model = ChatOpenAI(temperature=0, model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
|
|
"\n",
|
|
" # RAG pipeline\n",
|
|
" chain = (\n",
|
|
" {\n",
|
|
" \"context\": retriever | RunnableLambda(split_image_text_types),\n",
|
|
" \"question\": RunnablePassthrough(),\n",
|
|
" }\n",
|
|
" | RunnableLambda(img_prompt_func)\n",
|
|
" | model\n",
|
|
" | StrOutputParser()\n",
|
|
" )\n",
|
|
"\n",
|
|
" return chain"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5e8b0e26-bb7e-420a-a7bd-8512b7eef92f",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Build RAG Pipelines"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "4f1ec8a9-f0fe-4f08-928f-23504803897c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# RAG chains\n",
|
|
"chain_baseline = text_rag_chain(retriever_baseline)\n",
|
|
"chain_mv_text = text_rag_chain(retriever_multi_vector_img_summary)\n",
|
|
"\n",
|
|
"# Multi-modal RAG chains\n",
|
|
"chain_multimodal_mv_img = multi_modal_rag_chain(retriever_multi_vector_img)\n",
|
|
"chain_multimodal_embd = multi_modal_rag_chain(retriever_multimodal_embd)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "448d943c-a1b1-4300-9197-891a03232ee4",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Eval set"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"id": "9aabf72f-26be-437f-9372-b06dc2509235",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>Question</th>\n",
|
|
" <th>Answer</th>\n",
|
|
" <th>Source</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>What percentage of CPI is dedicated to Housing?</td>\n",
|
|
" <td>Housing occupies 42% of CPI.</td>\n",
|
|
" <td>Figure 1</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>Medical Care and Transportation account for wh...</td>\n",
|
|
" <td>Transportation accounts for 18% of CPI. Medica...</td>\n",
|
|
" <td>Figure 1</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>Based on the CPI Owners' Equivalent Rent and t...</td>\n",
|
|
" <td>The FHFA Purchase Only Price Index appears to ...</td>\n",
|
|
" <td>Figure 2</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" Question \\\n",
|
|
"0 What percentage of CPI is dedicated to Housing? \n",
|
|
"1 Medical Care and Transportation account for wh... \n",
|
|
"2 Based on the CPI Owners' Equivalent Rent and t... \n",
|
|
"\n",
|
|
" Answer Source \n",
|
|
"0 Housing occupies 42% of CPI. Figure 1 \n",
|
|
"1 Transportation accounts for 18% of CPI. Medica... Figure 1 \n",
|
|
"2 The FHFA Purchase Only Price Index appears to ... Figure 2 "
|
|
]
|
|
},
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Read\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"eval_set = pd.read_csv(path + \"cpi_eval.csv\")\n",
|
|
"eval_set.head(3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"id": "7fdeb77a-e185-47d2-a93f-822f1fc810a2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langsmith import Client\n",
|
|
"\n",
|
|
"# Dataset\n",
|
|
"client = Client()\n",
|
|
"dataset_name = f\"CPI Eval {str(uuid.uuid4())}\"\n",
|
|
"dataset = client.create_dataset(dataset_name=dataset_name)\n",
|
|
"\n",
|
|
"# Populate dataset\n",
|
|
"for _, row in eval_set.iterrows():\n",
|
|
" # Get Q, A\n",
|
|
" q = row[\"Question\"]\n",
|
|
" a = row[\"Answer\"]\n",
|
|
" # Use the values in your function\n",
|
|
" client.create_example(\n",
|
|
" inputs={\"question\": q}, outputs={\"answer\": a}, dataset_id=dataset.id\n",
|
|
" )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"id": "3c4faf4b-f29f-4a42-9cf2-bfbb5158ab59",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"View the evaluation results for project 'CPI Eval 9648e7fe-5ae2-469f-8701-33c63212d126-baseline' at:\n",
|
|
"https://smith.langchain.com/o/1fa8b1f4-fcb9-4072-9aa9-983e35ad61b8/projects/p/533846be-d907-4d9c-82db-ce2f1a18fdbf?eval=true\n",
|
|
"\n",
|
|
"View all tests for Dataset CPI Eval 9648e7fe-5ae2-469f-8701-33c63212d126 at:\n",
|
|
"https://smith.langchain.com/datasets/d1762232-5e01-40e7-9978-63002a4c95a3\n",
|
|
"[------------------------------------------------->] 4/4View the evaluation results for project 'CPI Eval 9648e7fe-5ae2-469f-8701-33c63212d126-mv_text' at:\n",
|
|
"https://smith.langchain.com/o/1fa8b1f4-fcb9-4072-9aa9-983e35ad61b8/projects/p/f5caeede-6f8e-46f7-b4f2-9f23daa31eda?eval=true\n",
|
|
"\n",
|
|
"View all tests for Dataset CPI Eval 9648e7fe-5ae2-469f-8701-33c63212d126 at:\n",
|
|
"https://smith.langchain.com/datasets/d1762232-5e01-40e7-9978-63002a4c95a3\n",
|
|
"[------------------------------------------------->] 4/4View the evaluation results for project 'CPI Eval 9648e7fe-5ae2-469f-8701-33c63212d126-mv_img' at:\n",
|
|
"https://smith.langchain.com/o/1fa8b1f4-fcb9-4072-9aa9-983e35ad61b8/projects/p/48cf1002-7ae2-451d-a9b1-5bd8088f6a69?eval=true\n",
|
|
"\n",
|
|
"View all tests for Dataset CPI Eval 9648e7fe-5ae2-469f-8701-33c63212d126 at:\n",
|
|
"https://smith.langchain.com/datasets/d1762232-5e01-40e7-9978-63002a4c95a3\n",
|
|
"[------------------------------------------------->] 4/4View the evaluation results for project 'CPI Eval 9648e7fe-5ae2-469f-8701-33c63212d126-mm_embd' at:\n",
|
|
"https://smith.langchain.com/o/1fa8b1f4-fcb9-4072-9aa9-983e35ad61b8/projects/p/aaa1c2e3-79b0-43e0-b5d5-8e3d00a51d50?eval=true\n",
|
|
"\n",
|
|
"View all tests for Dataset CPI Eval 9648e7fe-5ae2-469f-8701-33c63212d126 at:\n",
|
|
"https://smith.langchain.com/datasets/d1762232-5e01-40e7-9978-63002a4c95a3\n",
|
|
"[------------------------------------------------->] 4/4"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from langchain.smith import RunEvalConfig\n",
|
|
"\n",
|
|
"eval_config = RunEvalConfig(\n",
|
|
" evaluators=[\"qa\"],\n",
|
|
")\n",
|
|
"\n",
|
|
"\n",
|
|
"def run_eval(chain, run_name, dataset_name):\n",
|
|
" _ = client.run_on_dataset(\n",
|
|
" dataset_name=dataset_name,\n",
|
|
" llm_or_chain_factory=lambda: (lambda x: x[\"question\"] + suffix_for_images)\n",
|
|
" | chain,\n",
|
|
" evaluation=eval_config,\n",
|
|
" project_name=run_name,\n",
|
|
" )\n",
|
|
"\n",
|
|
"\n",
|
|
"for chain, run in zip(\n",
|
|
" [chain_baseline, chain_mv_text, chain_multimodal_mv_img, chain_multimodal_embd],\n",
|
|
" [\"baseline\", \"mv_text\", \"mv_img\", \"mm_embd\"],\n",
|
|
"):\n",
|
|
" run_eval(chain, dataset_name + \"-\" + run, dataset_name)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|