{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visual RAG using VDMS\n",
"Visual RAG is a framework that retrieves video based on provided user prompt. It uses both video scene description generated by open source vision models (ex. video-llama, video-llava etc.) as text embeddings and frames as image embeddings to perform vector similarity search using VDMS."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start VDMS Server\n",
"Let's start a VDMS docker container using the port 55559.\n",
"Keep note of the port and hostname as this is needed for the vector store as it uses the VDMS Python client to connect to the server."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2e44b44201c8778b462342ac97f5ccf05a4e02aa8a04505ecde97bf20dcc4cbb\n"
]
}
],
"source": [
"! docker run --rm -d -p 55559:55555 --name vdms_rag_nb intellabs/vdms:latest"
]
},
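{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, confirm the container is up before continuing. This is just a quick sanity check and assumes the container name `vdms_rag_nb` used above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: list running containers filtered by the name used above\n",
"! docker ps --filter name=vdms_rag_nb"
]
},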
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import Python Packages\n",
"\n",
"Verify the necessary python packages are available for this visual RAG example."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"! pip install --quiet -U vdms langchain-experimental sentence-transformers opencv-python open_clip_torch torch accelerate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now import the packages."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from pathlib import Path\n",
"from threading import Thread\n",
"from typing import Any, List, Mapping, Optional\n",
"from zipfile import ZipFile\n",
"\n",
"import cv2\n",
"import torch\n",
"from huggingface_hub import hf_hub_download\n",
"from IPython.display import Video\n",
"from langchain.llms.base import LLM\n",
"from langchain_community.embeddings.sentence_transformer import (\n",
" SentenceTransformerEmbeddings,\n",
")\n",
"from langchain_community.vectorstores.vdms import VDMS, VDMS_Client\n",
"from langchain_core.callbacks.manager import CallbackManagerForLLMRun\n",
"from langchain_core.runnables import ConfigurableField\n",
"from langchain_experimental.open_clip import OpenCLIPEmbeddings\n",
"from transformers import (\n",
" AutoModelForCausalLM,\n",
" AutoTokenizer,\n",
" TextIteratorStreamer,\n",
" set_seed,\n",
")\n",
"\n",
"set_seed(22)\n",
"number_of_frames_per_second = 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize Vector Stores\n",
"In this section, we initialize the VDMS vector store for both text and images. The text components use model `all-MiniLM-L12-v2`from `SentenceTransformerEmbeddings` and the images use model `ViT-g-14` from `OpenCLIPEmbeddings`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Create directory to store data\n",
"datapath = Path(\"./data/visual\").resolve()\n",
"datapath.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# Create directory to store frames\n",
"frame_dir = str(datapath / \"frames_from_clips\")\n",
"os.makedirs(frame_dir, exist_ok=True)\n",
"\n",
"# Connect to VDMS server\n",
"vdms_client = VDMS_Client(port=55559)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"# Initialize VDMS Vector Store\n",
"text_collection = \"text-test\"\n",
"text_embedder = SentenceTransformerEmbeddings(model_name=\"all-MiniLM-L12-v2\")\n",
"text_db = VDMS(\n",
" client=vdms_client,\n",
" embedding=text_embedder,\n",
" collection_name=text_collection,\n",
" engine=\"FaissFlat\",\n",
")\n",
"\n",
"text_retriever = text_db.as_retriever().configurable_fields(\n",
" search_kwargs=ConfigurableField(\n",
" id=\"k_text_docs\",\n",
" name=\"Search Kwargs\",\n",
" description=\"The search kwargs to use\",\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"image_collection = \"image-test\"\n",
"image_embedder = OpenCLIPEmbeddings(\n",
" model_name=\"ViT-g-14\", checkpoint=\"laion2b_s34b_b88k\"\n",
")\n",
"image_db = VDMS(\n",
" client=vdms_client,\n",
" embedding=image_embedder,\n",
" collection_name=image_collection,\n",
" engine=\"FaissFlat\",\n",
")\n",
"image_retriever = image_db.as_retriever(search_type=\"mmr\").configurable_fields(\n",
" search_kwargs=ConfigurableField(\n",
" id=\"k_image_docs\",\n",
" name=\"Search Kwargs\",\n",
" description=\"The search kwargs to use\",\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Loading\n",
"\n",
"For this visual RAG example, we need to obtain videos and also video scene descriptions generated by open source vision models (ex. video-llava etc.) as text. \n",
"We have published a [Video Summarization Dataset](https://huggingface.co/datasets/Intel/Video_Summarization_For_Retail) available on Hugging Face which contains short videos of shoppers in a retail setting along with the corresponding textual description of each video."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Download data\n",
"hf_hub_download(\n",
" repo_id=\"Intel/Video_Summarization_For_Retail\",\n",
" filename=\"VideoSumForRetailData.zip\",\n",
" repo_type=\"dataset\",\n",
" local_dir=str(datapath),\n",
")\n",
"with ZipFile(str(datapath / \"VideoSumForRetailData.zip\"), \"r\") as z:\n",
" z.extractall(path=datapath)\n",
"\n",
"with open(str(datapath / \"VideoSumForRetailData/clips_anno.json\"), \"r\") as f:\n",
" scene_info = json.load(f)\n",
"\n",
"video_dir = str(datapath / \"VideoSumForRetailData/clips/\")\n",
"\n",
"# Create dict for data where key is video name and value is scene description\n",
"video_list = {}\n",
"for scene in scene_info:\n",
" video_list[scene[\"video\"].split(\"/\")[-1]] = scene[\"conversations\"][1][\"value\"]"
]
},
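{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check, we can preview one entry of `video_list` to confirm the clip-to-description mapping was built as expected (the clip and description shown depend on the dataset contents)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: preview one clip name and the start of its scene description\n",
"sample_clip = next(iter(video_list))\n",
"print(f\"{len(video_list)} clips loaded\")\n",
"print(sample_clip, \"->\", video_list[sample_clip][:200], \"...\")"
]
},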
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we use OpenCV to extract metadata such as fps and number of frames for each video and also metadata such as frame number and timestamp for each extracted video frame. Once the metadata is extracted, the details are stored in VDMS."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"text_content = []\n",
"video_metadata_list = []\n",
"uris = []\n",
"frame_metadata_list = []\n",
"for video_name, description in video_list.items():\n",
" video_path = os.path.join(video_dir, video_name)\n",
"\n",
" # Obtain Description and Video Metadata\n",
" text_content.append(description)\n",
" cap = cv2.VideoCapture(video_path)\n",
" fps = cap.get(cv2.CAP_PROP_FPS)\n",
" total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)\n",
" metadata = {\"video\": video_name, \"fps\": fps, \"total_frames\": total_frames}\n",
" video_metadata_list.append(metadata)\n",
"\n",
" # Obtain Metadata per Extracted Frame\n",
" mod = int(fps // number_of_frames_per_second)\n",
" if mod == 0:\n",
" mod = 1\n",
" frame_count = 0\n",
" while cap.isOpened():\n",
" ret, frame = cap.read()\n",
" if not ret:\n",
" break\n",
" frame_count += 1\n",
" if frame_count % mod == 0:\n",
" timestamp = (\n",
" cap.get(cv2.CAP_PROP_POS_MSEC) / 1000\n",
" ) # Convert milliseconds to seconds\n",
" frame_path = os.path.join(frame_dir, f\"{video_name}_{frame_count}.jpg\")\n",
" cv2.imwrite(frame_path, frame) # Save the frame as an image\n",
" frame_metadata = {\n",
" \"timestamp\": timestamp,\n",
" \"frame_path\": frame_path,\n",
" \"video\": video_name,\n",
" \"frame_num\": frame_count,\n",
" }\n",
" uris.append(frame_path)\n",
" frame_metadata_list.append(frame_metadata)\n",
" cap.release()\n",
"\n",
"# Add Text and Images\n",
"text_db.add_texts(text_content, video_metadata_list)\n",
"image_db.add_images(uris, frame_metadata_list);"
]
},
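{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, report how much data was just embedded and stored (a minimal check based on the lists built above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: counts of stored scene descriptions and extracted frames\n",
"print(f\"Added {len(text_content)} scene descriptions and {len(uris)} frames to VDMS\")"
]
},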
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run Multimodal Retrieval\n",
"\n",
"Here we define helper functions for retrieving text and image results based on a user query.\n",
"First, we use multi-modal retrieval to retrieve one text and three image documents for the user query. \n",
"Then we return the video name for the video with the most results."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def MultiModalRetrieval(\n",
" query: str,\n",
" n_texts: Optional[int] = 1,\n",
" n_images: Optional[int] = 3,\n",
" print_text_content=False,\n",
"):\n",
" text_config = {\"configurable\": {\"k_text_docs\": {\"k\": n_texts}}}\n",
" image_config = {\"configurable\": {\"k_image_docs\": {\"k\": n_images}}}\n",
"\n",
" print(\"\\tRetrieving 1 text doc and 3 image docs\")\n",
" text_results = text_retriever.invoke(query, config=text_config)\n",
" image_results = image_retriever.invoke(query, config=image_config)\n",
"\n",
" if print_text_content:\n",
" print(\n",
" f\"\\tPage content:\\n\\t\\t{text_results[0].page_content}\\n\\n\\tMetadata:\\n\\t\\t{text_results[0].metadata}\"\n",
" )\n",
"\n",
" return text_results + image_results\n",
"\n",
"\n",
"def get_top_doc(results, qcnt=0):\n",
" hit_score = {}\n",
" for r in results:\n",
" if \"video\" in r.metadata:\n",
" video_name = r.metadata[\"video\"]\n",
" if video_name not in hit_score.keys():\n",
" hit_score[video_name] = 0\n",
" hit_score[video_name] += 1\n",
"\n",
" x = dict(sorted(hit_score.items(), key=lambda item: -item[1]))\n",
"\n",
" if qcnt >= len(x):\n",
" return None\n",
" # print (f'top docs = {x}')\n",
" return {\"video\": list(x)[qcnt]}\n",
"\n",
"\n",
"def Retrieve_top_results(prompt, qcnt=0, print_text_content=False):\n",
" print(\"Querying database . . . \")\n",
" results = MultiModalRetrieval(\n",
" prompt, n_texts=1, n_images=3, print_text_content=print_text_content\n",
" )\n",
" print(\"Retrieved Top matching video!\\n\\n\")\n",
"\n",
" top_doc = get_top_doc(results, qcnt)\n",
" # print('TOP DOC = ', top_doc)\n",
" if top_doc is None:\n",
" return None, None\n",
"\n",
" return top_doc[\"video\"], top_doc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's query for a `man wearing khaki pants` and retrieve the top results."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Querying database . . . \n",
"\tRetrieving 1 text doc and 3 image docs\n",
"\tPage content:\n",
"\t\tThere are 2 shoppers in this video. Shopper 1 is wearing a plaid shirt and a spectacle. Shopper 2 who is not completely captured in the frame seems to wear a black shirt and is moving away with his back turned towards the camera. There is a shelf towards the right of the camera frame. Shopper 2 is hanging an item back to a hanger and then quickly walks away in a similar fashion as shopper 2. Contents of the nearer side of the shelf with respect to camera seems to be camping lanterns and cleansing agents, arranged at the top. In the middle part of the shelf, various tools including grommets, a pocket saw, candles, and other helpful camping items can be observed. Midway through the shelf contains items which appear to be steel containers and items made up of plastic with red, green, orange, and yellow colors, while those at the bottom are packed in cardboard boxes. Contents at the farther part of the shelf are well stocked and organized but are not glaringly visible.\n",
"\n",
"\tMetadata:\n",
"\t\t{'fps': 24.0, 'id': 'c6e5f894-b905-46f5-ac9e-4487a9235561', 'total_frames': 120.0, 'video': 'clip16.mp4'}\n",
"Retrieved Top matching video!\n",
"\n",
"\n"
]
}
],
"source": [
"input_query = \"Find a man wearing khaki pants\"\n",
"video_name, top_doc = Retrieve_top_results(input_query, print_text_content=True)"
]
},
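{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, preview the retrieved clip inline with IPython's `Video` display (imported earlier). This assumes the returned `video_name` exists under `video_dir` as set up above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: preview the top matching clip inline\n",
"Video(os.path.join(video_dir, video_name), embed=True)"
]
},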
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run RAG using LLM\n",
"### Load LLM Model\n",
"In this example, we use Meta's [LLama-2-Chat (7B) model](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) which is optimized for dialogue use cases. \n",
"If you do not have access to this model, feel free to substitute the model with a different LLM.\n",
"In this example, the model is expected to be in `data/visual/llama-2-7b-chat-hf`. "
]
},
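{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is an optional sketch for fetching the weights with `huggingface_hub.snapshot_download`. It assumes you have accepted Meta's license for this gated model and are authenticated with Hugging Face (e.g., via `huggingface-cli login`); skip it if the weights are already in place."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: download the gated Llama-2 weights into the expected directory\n",
"# (requires accepting Meta's license and a Hugging Face token)\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"llama_dir = datapath / \"llama-2-7b-chat-hf\"\n",
"if not llama_dir.exists():\n",
"    snapshot_download(\n",
"        repo_id=\"meta-llama/Llama-2-7b-chat-hf\",\n",
"        local_dir=str(llama_dir),\n",
"    )"
]
},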
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3edf8783e114487ca490d8dec5c46884",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Directory for LLM model\n",
"model_path = str(datapath / \"llama-2-7b-chat-hf\")\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_path,\n",
" torch_dtype=torch.float32,\n",
" device_map=\"auto\",\n",
" trust_remote_code=True,\n",
")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
"tokenizer.padding_size = \"right\"\n",
"streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)\n",
"\n",
"\n",
"class CustomLLM(LLM):\n",
" @torch.inference_mode()\n",
" def _call(\n",
" self,\n",
" prompt: str,\n",
" stop: Optional[List[str]] = None,\n",
" run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
" streamer: Optional[TextIteratorStreamer] = None, # Add streamer as an argument\n",
" ) -> str:\n",
" tokens = tokenizer.encode(prompt, return_tensors=\"pt\")\n",
"\n",
" with torch.no_grad():\n",
" output = model.generate(\n",
" input_ids=tokens.to(model.device.type),\n",
" max_new_tokens=100,\n",
" num_return_sequences=1,\n",
" num_beams=1,\n",
" min_length=1,\n",
" top_p=0.9,\n",
" top_k=50,\n",
" repetition_penalty=1.2,\n",
" length_penalty=1,\n",
" temperature=0.1,\n",
" streamer=streamer,\n",
" # pad_token_id=tokenizer.eos_token_id,\n",
" do_sample=True,\n",
" )\n",
"\n",
" def stream_res(self, prompt):\n",
" thread = Thread(\n",
" target=self._call, args=(prompt, None, None, streamer)\n",
" ) # Pass streamer to _call\n",
" thread.start()\n",
"\n",
" for text in streamer:\n",
" yield text\n",
"\n",
" @property\n",
" def _identifying_params(self) -> Mapping[str, Any]:\n",
"        return {\"name_of_model\": model_path}\n",
"\n",
" @property\n",
" def _llm_type(self) -> str:\n",
" return \"custom\"\n",
"\n",
"\n",
"llm = CustomLLM()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run Chatbot\n",
"\n",
"First, we define the prompt and a simple chatbot for processing the user query."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def get_formatted_prompt(scene, prompt):\n",
"    PROMPT = \"\"\" <<SYS>>\n",
"    You are an Intel assistant who understands visual and textual content.\n",
"    <</SYS>>\n",
" [INST]\n",
"    You will be provided with two things: a scene description and a user's question. You are supposed to understand the scene description \\\n",
"    and answer the user's question.\n",
"\n",
" As an assistant, you need to follow these Rules while answering questions,\n",
"\n",
" Rules:\n",
"    - Don't answer any questions that are not related to the provided scene description.\n",
"    - Don't be toxic and don't include harmful information.\n",
"    - Answer from the provided scene description if you can; otherwise just say you don't have enough information to answer the question.\n",
"\n",
"    Here is the\n",
" Scene Description: {{ scene }}\n",
"\n",
" The user wants to know,\n",
" User: {{ prompt }}\n",
" [/INST]\\n\n",
" Assistant:\n",
" \"\"\"\n",
" return PROMPT.replace(\"{{ scene }}\", scene).replace(\"{{ prompt }}\", prompt)\n",
"\n",
"\n",
"def simple_chatbot(user_query):\n",
" messages = [{\"role\": \"assistant\", \"content\": \"How may I assist you today?\"}]\n",
" messages.append({\"role\": \"user\", \"content\": user_query})\n",
" video_name, top_doc = Retrieve_top_results(user_query)\n",
"\n",
" scene_des = video_list[video_name]\n",
" formatted_prompt = get_formatted_prompt(scene=scene_des, prompt=user_query)\n",
" # print(formatted_prompt)\n",
" full_response = f\"Most relevant retrieved video is **{video_name}** \\n\\n\"\n",
" for new_text in llm.stream_res(formatted_prompt):\n",
" full_response += new_text\n",
" message = {\"role\": \"assistant\", \"content\": full_response}\n",
" messages.append(message)\n",
"\n",
" for message in messages:\n",
" print(message[\"role\"].capitalize(), \": \", message[\"content\"])\n",
"\n",
" video_path = os.path.join(video_dir, top_doc[\"video\"])\n",
" return video_path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's use the simple chatbot to process a query asking for a `man holding a red shopping basket` and display the resulting video."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Querying database . . . \n",
"\tRetrieving 1 text doc and 3 image docs\n",
"\tPage content:\n",
"\t\tA single shopper is seen in this video standing facing the shelf and in the bottom part of the frame. He's wearing a light-colored shirt and a spectacle. The shopper is carrying a red colored basket in his left hand. The entire basket is not clearly visible, but it does seem to contain something in a blue colored package which the shopper has just placed in the basket given his right hand was seen inside the basket. Then the shopper leans towards the shelf and checks out an item in orange package. He picks this single item with his right hand and proceeds to place the item in the basket. The entire shelf looks well stocked except for the top part of the shelf which is empty. The shopper has not picked any item from this part of the shelf. The rest of the shelf looks well stocked and does not need any restocking. The contents on the farther part of the shelf consists of items, majority of which are packed in black, yellow, and green packages. No other details are visible of these items.\n",
"\n",
"\tMetadata:\n",
"\t\t{'fps': 24.0, 'id': '37ddc212-994e-4db0-877f-5ed09965ab90', 'total_frames': 162.0, 'video': 'clip10.mp4'}\n",
"Retrieved Top matching video!\n",
"\n",
"\n"
]
}
],
"source": [
"input_query = \"Find a man holding a red shopping basket\"\n",
"video_name, top_doc = Retrieve_top_results(input_query, print_text_content=True)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Querying database . . . \n",
"\tRetrieving 1 text doc and 3 image docs\n",
"Retrieved Top matching video!\n",
"\n",
"\n",
"Assistant : How may I assist you today?\n",
"User : Find a man holding a red shopping basket\n",
"Assistant : Most relevant retrieved video is **clip9.mp4** \n",
"\n",
"I see a person standing in front of a well-stocked shelf, they are wearing a light-colored shirt and glasses, and they have a red shopping basket in their left hand. They are leaning forward and picking up an item from the shelf with their right hand. The item is packaged in a blue-green box. Based on the scene description, I can confirm that the person is indeed holding a red shopping basket.\n"
]
}
],
"source": [
"input_query = \"Find a man holding a red shopping basket\"\n",
"video_path = simple_chatbot(input_query)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"