{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-Tuning OpenAI Models for Retrieval Augmented Generation (RAG) with Qdrant and Few-Shot Learning\n", "\n", "The aim of this notebook is to walk through a comprehensive example of how to fine-tune OpenAI models for Retrieval Augmented Generation (RAG). \n", "\n", "We will also be integrating Qdrant and Few-Shot Learning to boost the model's performance and reduce hallucinations. This could serve as a practical guide for ML practitioners, data scientists, and AI engineers interested in leveraging the power of OpenAI models for specific use-cases.\n", "\n", "## Why should you read this blog?\n", "\n", "You want to learn how to:\n", "- [Fine-tune OpenAI models](https://platform.openai.com/docs/guides/fine-tuning/) for specific use-cases\n", "- Use [Qdrant](https://qdrant.tech/documentation/) to improve the performance of your RAG model\n", "- Use fine-tuning to improve the correctness of your RAG model and reduce hallucinations\n", "\n", "To begin, we've selected a dataset where the retrieval is guaranteed to be perfect. We've selected a subset of the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) dataset, which is a collection of questions and answers about Wikipedia articles. We've also included samples where the answer is not present in the context, to demonstrate how RAG handles this case.\n", "\n", "## Table of Contents\n", "1. Setting Up the Environment\n", "\n", "### Section A: Zero-Shot Learning\n", "2. Data Preparation: SQuADv2 Dataset\n", "3. Answering using Base gpt-3.5-turbo-0613 model\n", "4. Fine-tuning and Answering using Fine-tuned model\n", "5. **Evaluation**: How well does the model perform?\n", "\n", "### Section B: Few-Shot Learning\n", "\n", "6. Using Qdrant to Improve RAG Prompt\n", "7. Fine-Tuning OpenAI Model with Qdrant\n", "8. Evaluation\n", "\n", "9. 
**Conclusion**\n", " - Aggregate Results\n", " - Observations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Terms, Definitions, and References\n", "\n", "**What is Retrieval Augmented Generation (RAG)?**\n", "The phrase Retrieval Augmented Generation (RAG) comes from a [2020 paper](https://arxiv.org/abs/2005.11401) by Lewis et al. from Facebook AI. The idea is to use a pre-trained language model (LM) to generate text, but to use a separate retrieval system to find relevant documents to condition the LM on. \n", "\n", "**What is Qdrant?**\n", "Qdrant is an open-source vector search engine that allows you to search for similar vectors in a large dataset. It is built in Rust, and here we'll use its Python client to interact with it. This is the Retrieval part of RAG.\n", "\n", "**What is Few-Shot Learning?**\n", "Few-shot learning is a technique where the model is shown a small number of worked examples of the task at hand. In this notebook, we'll retrieve a few similar examples from the SQuAD dataset and include them in the prompt and in the fine-tuning data. This is the Augmented part of RAG.\n", "\n", "**What is Zero-Shot Learning?**\n", "Zero-shot learning is when the model is asked to perform a task without being shown any task-specific examples: the prompt contains only instructions and the input itself.\n", "\n", "**What is Fine-Tuning?**\n", "Fine-tuning takes a pre-trained model and continues training it on a smaller, task-specific dataset. In this case, we'll fine-tune the model on a small number of examples from the SQuAD dataset. The LLM is what makes the Generation part of RAG." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 
Setting Up the Environment\n", "\n", "### Install and Import Dependencies" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip install pandas openai tqdm tenacity scikit-learn tiktoken python-dotenv seaborn --upgrade --quiet" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "import os\n", "import time\n", "\n", "import pandas as pd\n", "import openai\n", "import tiktoken\n", "import seaborn as sns\n", "from tenacity import retry, wait_exponential\n", "from tqdm import tqdm\n", "from collections import defaultdict\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.metrics import confusion_matrix\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "tqdm.pandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set your keys\n", "Get your OpenAI API key [here](https://platform.openai.com/account/api-keys), and your Qdrant API key after making a free cluster [here](https://cloud.qdrant.io/login)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "openai.api_key = \"sk-xxx\"\n", "os.environ[\"QDRANT_URL\"] = \"https://xxx.cloud.qdrant.io:6333\"\n", "os.environ[\"QDRANT_API_KEY\"] = \"xxx\"" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Section A\n", "\n", "## 2. Data Preparation: SQuADv2 Data Subsets\n", "\n", "For the purpose of demonstration, we'll make small slices from the train and validation splits of the [SQuADv2](https://rajpurkar.github.io/SQuAD-explorer/) dataset. 
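\n", "\n", "The raw SQuAD files nest question-answer pairs under articles and paragraphs. As a rough sketch (the field names follow the SQuAD v2 format, but the sample values here are invented), a single record looks like this:\n", "\n", "```python\n", "# Illustrative only: the nesting of one SQuAD v2 record, with made-up values.\n", "squad_record = {\n", "    'data': [{\n", "        'title': 'Some_Article',\n", "        'paragraphs': [{\n", "            'context': 'Text of the source paragraph ...',\n", "            'qas': [{\n", "                'question': 'A question about the paragraph?',\n", "                'is_impossible': False,\n", "                'answers': [{'text': 'an answer span', 'answer_start': 0}],\n", "            }],\n", "        }],\n", "    }],\n", "}\n", "```\n", "\n", "The `json_to_dataframe_with_titles` function below walks exactly this nesting and flattens it into one row per question.\n", "\n", "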
This dataset includes questions whose answer is not present in the context, which helps us evaluate how the LLM handles that case.\n", "\n", "We'll read the data from the JSON files and create a dataframe with the following columns: `title`, `question`, `context`, `is_impossible`, `answers`.\n", "\n", "### Download the Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# !mkdir -p local_cache\n", "# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json -O local_cache/train.json\n", "# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O local_cache/dev.json" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Read JSON to DataFrame" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def json_to_dataframe_with_titles(json_data):\n", "    qas = []\n", "    context = []\n", "    is_impossible = []\n", "    answers = []\n", "    titles = []\n", "\n", "    for article in json_data['data']:\n", "        title = article['title']\n", "        for paragraph in article['paragraphs']:\n", "            for qa in paragraph['qas']:\n", "                qas.append(qa['question'].strip())\n", "                context.append(paragraph['context'])\n", "                is_impossible.append(qa['is_impossible'])\n", "\n", "                ans_list = []\n", "                for ans in qa['answers']:\n", "                    ans_list.append(ans['text'])\n", "                answers.append(ans_list)\n", "                titles.append(title)\n", "\n", "    df = pd.DataFrame({'title': titles, 'question': qas, 'context': context, 'is_impossible': is_impossible, 'answers': answers})\n", "    return df\n", "\n", "def get_diverse_sample(df, sample_size=100, random_state=42):\n", "    \"\"\"\n", "    Get a diverse sample of the dataframe by sampling from each title\n", "    \"\"\"\n", "    sample_df = df.groupby(['title', 'is_impossible']).apply(lambda x: x.sample(min(len(x), max(1, sample_size // 50)), random_state=random_state)).reset_index(drop=True)\n", "\n", "    if len(sample_df) < sample_size:\n", "        remaining_sample_size = sample_size 
- len(sample_df)\n", "        remaining_df = df.drop(sample_df.index).sample(remaining_sample_size, random_state=random_state)\n", "        sample_df = pd.concat([sample_df, remaining_df]).sample(frac=1, random_state=random_state).reset_index(drop=True)\n", "\n", "    return sample_df.sample(min(sample_size, len(sample_df)), random_state=random_state).reset_index(drop=True)\n", "\n", "train_df = json_to_dataframe_with_titles(json.load(open('local_cache/train.json')))\n", "val_df = json_to_dataframe_with_titles(json.load(open('local_cache/dev.json')))\n", "\n", "df = get_diverse_sample(val_df, sample_size=100, random_state=42)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Answering using Base gpt-3.5-turbo-0613 model\n", "\n", "### 3.1 Zero Shot Prompt\n", "\n", "Let's start by using the base gpt-3.5-turbo-0613 model to answer the questions. The prompt is a simple concatenation of the question and context, with a separator token in between: `\\n\\n`. The instruction part of the prompt is:\n", "\n", "> Answer the following Question based on the Context only. Only answer from the Context. If you don't know the answer, say 'I don't know'.\n", "\n", "Other prompts are possible, but this is a good starting point. We'll use this prompt to answer the questions in the validation set." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Function to get prompt messages\n", "def get_prompt(row):\n", "    return [\n", "        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", "        {\n", "            \"role\": \"user\",\n", "            \"content\": f\"\"\"Answer the following Question based on the Context only. Only answer from the Context. 
If you don't know the answer, say 'I don't know'.\n", "    Question: {row.question}\\n\\n\n", "    Context: {row.context}\\n\\n\n", "    Answer:\\n\"\"\",\n", "        },\n", "    ]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2 Answering using Zero Shot Prompt\n", "\n", "Next, you'll need some reusable functions that make an OpenAI API call and return the answer. You'll use the `ChatCompletion.create` endpoint of the API, which takes a prompt and returns the completed text." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Function with tenacity for retries\n", "@retry(wait=wait_exponential(multiplier=1, min=2, max=6))\n", "def api_call(messages, model):\n", "    return openai.ChatCompletion.create(\n", "        model=model,\n", "        messages=messages,\n", "        stop=[\"\\n\\n\"],\n", "        max_tokens=100,\n", "        temperature=0.0,\n", "    )\n", "\n", "\n", "# Main function to answer question\n", "def answer_question(row, prompt_func=get_prompt, model=\"gpt-3.5-turbo-0613\"):\n", "    messages = prompt_func(row)\n", "    response = api_call(messages, model)\n", "    return response[\"choices\"][0][\"message\"][\"content\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "⏰ **Time to run: ~3 min**, needs an internet connection" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Use progress_apply with tqdm for progress bar\n", "df[\"generated_answer\"] = df.progress_apply(answer_question, axis=1)\n", "df.to_json(\"local_cache/100_val.json\", orient=\"records\", lines=True)\n", "df = pd.read_json(\"local_cache/100_val.json\", orient=\"records\", lines=True)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | title | \n", "question | \n", "context | \n", "is_impossible | \n", "answers | \n", "
---|---|---|---|---|---|
0 | \n", "Scottish_Parliament | \n", "What consequence of establishing the Scottish ... | \n", "A procedural consequence of the establishment ... | \n", "False | \n", "[able to vote on domestic legislation that app... | \n", "
1 | \n", "Imperialism | \n", "Imperialism is less often associated with whic... | \n", "The principles of imperialism are often genera... | \n", "True | \n", "[] | \n", "
2 | \n", "Economic_inequality | \n", "What issues can't prevent women from working o... | \n", "When a person’s capabilities are lowered, they... | \n", "True | \n", "[] | \n", "
3 | \n", "Southern_California | \n", "What county are Los Angeles, Orange, San Diego... | \n", "Its counties of Los Angeles, Orange, San Diego... | \n", "True | \n", "[] | \n", "
4 | \n", "French_and_Indian_War | \n", "When was the deportation of Canadians? | \n", "Britain gained control of French Canada and Ac... | \n", "True | \n", "[] | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
95 | \n", "Geology | \n", "In the layered Earth model, what is the inner ... | \n", "Seismologists can use the arrival times of sei... | \n", "True | \n", "[] | \n", "
96 | \n", "Prime_number | \n", "What type of value would the Basel function ha... | \n", "The zeta function is closely related to prime ... | \n", "True | \n", "[] | \n", "
97 | \n", "Fresno,_California | \n", "What does the San Joaquin Valley Railroad cros... | \n", "Passenger rail service is provided by Amtrak S... | \n", "True | \n", "[] | \n", "
98 | \n", "Victoria_(Australia) | \n", "What party rules in Melbourne's inner regions? | \n", "The centre-left Australian Labor Party (ALP), ... | \n", "False | \n", "[The Greens, Australian Greens, Greens] | \n", "
99 | \n", "Immune_system | \n", "The speed of the killing response of the human... | \n", "In humans, this response is activated by compl... | \n", "False | \n", "[signal amplification, signal amplification, s... | \n", "
100 rows × 5 columns
\n", "