{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-Tuning OpenAI Models for Retrieval Augmented Generation (RAG) with Qdrant and Few-Shot Learning\n",
"\n",
"The aim of this notebook is to walk through a comprehensive example of how to fine-tune OpenAI models for Retrieval Augmented Generation (RAG). \n",
"\n",
"We will also be integrating Qdrant and Few-Shot Learning to boost the model's performance and reduce hallucinations. This could serve as a practical guide for ML practitioners, data scientists, and AI Engineers interested in leveraging the power of OpenAI models for specific use-cases. 🤩\n",
"\n",
"## Why should you read this blog?\n",
"\n",
"You want to learn how to \n",
"- [Fine-tune OpenAI models](https://platform.openai.com/docs/guides/fine-tuning/) for specific use-cases\n",
"- Use [Qdrant](https://qdrant.tech/documentation/) to improve the performance of your RAG model\n",
"- Use fine-tuning to improve the correctness of your RAG model and reduce hallucinations\n",
"\n",
"To begin, we've selected a dataset where retrieval is guaranteed to be perfect: a subset of the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) dataset, which is a collection of questions and answers about Wikipedia articles. We've also included samples where the answer is not present in the context, to demonstrate how RAG handles this case.\n",
"\n",
"## Table of Contents\n",
"1. Setting up the Environment\n",
"\n",
"### Section A: Zero-Shot Learning\n",
"2. Data Preparation: SQuADv2 Dataset\n",
"3. Answering using Base gpt-3.5-turbo-0613 model\n",
"4. Fine-tuning and Answering using Fine-tuned model\n",
"5. **Evaluation**: How well does the model perform?\n",
"\n",
"### Section B: Few-Shot Learning\n",
"\n",
"6. Fine-Tuning OpenAI Model with Qdrant\n",
"7. Using Qdrant to Improve RAG Prompt\n",
"8. Evaluation\n",
"\n",
"9. **Conclusion**\n",
" - Aggregate Results\n",
" - Observations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Terms, Definitions, and References\n",
"\n",
"**What is Retrieval Augmented Generation (RAG)?**\n",
"The phrase Retrieval Augmented Generation (RAG) comes from a [2020 paper](https://arxiv.org/abs/2005.11401) by Lewis et al. from Facebook AI. The idea is to use a pre-trained language model (LM) to generate text, and a separate retrieval system to find relevant documents to condition the LM on. \n",
"\n",
"**What is Qdrant?**\n",
"Qdrant is an open-source vector search engine that allows you to search for similar vectors in a large dataset. It is built in Rust and here we'll use the Python client to interact with it. This is the Retrieval part of RAG.\n",
"\n",
"**What is Few-Shot Learning?**\n",
"Few-shot learning is a setting in which the model is \"improved\" using only a small number of examples. In this case, we'll use it to fine-tune the RAG model on a small number of examples from the SQuAD dataset. This is the Augmented part of RAG.\n",
"\n",
"**What is Zero-Shot Learning?**\n",
"Zero-shot learning is a setting in which the model is asked to perform a task without being shown any dataset-specific examples. \n",
"\n",
"**What is Fine-Tuning?**\n",
"Fine-tuning continues the training of a pre-trained model on a small, task-specific dataset. In this case, we'll use it to fine-tune the model on a small number of examples from the SQuAD dataset. The LLM is what makes the Generation part of RAG."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setting Up the Environment\n",
"\n",
"### Install and Import Dependencies"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"!pip install pandas openai tqdm tenacity scikit-learn tiktoken python-dotenv seaborn qdrant-client --upgrade --quiet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import time\n",
"\n",
"import pandas as pd\n",
"from openai import OpenAI\n",
"import tiktoken\n",
"import seaborn as sns\n",
"from tenacity import retry, wait_exponential\n",
"from tqdm import tqdm\n",
"from collections import defaultdict\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"tqdm.pandas()\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set your keys\n",
"Get your OpenAI keys [here](https://platform.openai.com/account/api-keys) and Qdrant keys after making a free cluster [here](https://cloud.qdrant.io/login)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"QDRANT_URL\"] = \"https://xxx.cloud.qdrant.io:6333\"\n",
"os.environ[\"QDRANT_API_KEY\"] = \"xxx\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Section A\n",
"\n",
"## 2. Data Preparation: SQuADv2 Data Subsets\n",
"\n",
"For the purpose of demonstration, we'll make small slices from the train and validation splits of the [SQuADv2](https://rajpurkar.github.io/SQuAD-explorer/) dataset. This dataset includes questions whose answers are not present in the context, which helps us evaluate how the LLM handles this case.\n",
"\n",
"We'll read the data from the JSON files and create a dataframe with the following columns: `title`, `question`, `context`, `is_impossible`, `answers`.\n",
"\n",
"### Download the Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# !mkdir -p local_cache\n",
"# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json -O local_cache/train.json\n",
"# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O local_cache/dev.json"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Read JSON to DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def json_to_dataframe_with_titles(json_data):\n",
"    qas = []\n",
"    context = []\n",
"    is_impossible = []\n",
"    answers = []\n",
"    titles = []\n",
"\n",
"    for article in json_data['data']:\n",
"        title = article['title']\n",
"        for paragraph in article['paragraphs']:\n",
"            for qa in paragraph['qas']:\n",
"                qas.append(qa['question'].strip())\n",
"                context.append(paragraph['context'])\n",
"                is_impossible.append(qa['is_impossible'])\n",
"\n",
"                ans_list = []\n",
"                for ans in qa['answers']:\n",
"                    ans_list.append(ans['text'])\n",
"                answers.append(ans_list)\n",
"                titles.append(title)\n",
"\n",
"    df = pd.DataFrame({'title': titles, 'question': qas, 'context': context, 'is_impossible': is_impossible, 'answers': answers})\n",
"    return df\n",
"\n",
"def get_diverse_sample(df, sample_size=100, random_state=42):\n",
"    \"\"\"\n",
"    Get a diverse sample of the dataframe by sampling from each (title, is_impossible) group\n",
"    \"\"\"\n",
"    # group_keys=False keeps the original row index, which the drop() below\n",
"    # relies on to exclude rows that were already sampled\n",
"    sample_df = df.groupby(['title', 'is_impossible'], group_keys=False).apply(lambda x: x.sample(min(len(x), max(1, sample_size // 50)), random_state=random_state))\n",
"\n",
"    if len(sample_df) < sample_size:\n",
"        remaining_sample_size = sample_size - len(sample_df)\n",
"        remaining_df = df.drop(sample_df.index).sample(remaining_sample_size, random_state=random_state)\n",
"        sample_df = pd.concat([sample_df, remaining_df]).sample(frac=1, random_state=random_state)\n",
"\n",
"    return sample_df.sample(min(sample_size, len(sample_df)), random_state=random_state).reset_index(drop=True)\n",
"\n",
"with open('local_cache/train.json') as f:\n",
"    train_df = json_to_dataframe_with_titles(json.load(f))\n",
"with open('local_cache/dev.json') as f:\n",
"    val_df = json_to_dataframe_with_titles(json.load(f))\n",
"\n",
"df = get_diverse_sample(val_df, sample_size=100, random_state=42)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Answering using Base gpt-3.5-turbo-0613 model\n",
"\n",
"### 3.1 Zero Shot Prompt\n",
"\n",
"Let's start by using the base gpt-3.5-turbo-0613 model to answer the questions. This prompt is a simple concatenation of the question and context, with a separator token in between: `\\n\\n`. The instruction part of the prompt is short: \n",
"\n",
"> Answer the following Question based on the Context only. Only answer from the Context. If you don't know the answer, say 'I don't know'.\n",
"\n",
"Other prompts are possible, but this is a good starting point. We'll use this prompt to answer the questions in the validation set. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Function to get prompt messages\n",
"def get_prompt(row):\n",
"    return [\n",
"        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
"        {\n",
"            \"role\": \"user\",\n",
"            \"content\": f\"\"\"Answer the following Question based on the Context only. Only answer from the Context. If you don't know the answer, say 'I don't know'.\n",
"    Question: {row.question}\\n\\n\n",
"    Context: {row.context}\\n\\n\n",
"    Answer:\\n\"\"\",\n",
"        },\n",
"    ]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2 Answering using Zero Shot Prompt\n",
"\n",
"Next, you'll need some re-usable functions which make an OpenAI API call and return the answer. You'll use the `client.chat.completions.create` method of the Python client, which takes a list of messages and returns the model's reply."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Function with tenacity for retries\n",
"@retry(wait=wait_exponential(multiplier=1, min=2, max=6))\n",
"def api_call(messages, model):\n",
"    return client.chat.completions.create(\n",
"        model=model,\n",
"        messages=messages,\n",
"        stop=[\"\\n\\n\"],\n",
"        max_tokens=100,\n",
"        temperature=0.0,\n",
"    )\n",
"\n",
"\n",
"# Main function to answer question\n",
"def answer_question(row, prompt_func=get_prompt, model=\"gpt-3.5-turbo\"):\n",
"    messages = prompt_func(row)\n",
"    response = api_call(messages, model)\n",
"    return response.choices[0].message.content"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"⏰ **Time to run: ~3 min**, 🛜 Needs Internet Connection"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Use progress_apply with tqdm for progress bar\n",
"df[\"generated_answer\"] = df.progress_apply(answer_question, axis=1)\n",
"df.to_json(\"local_cache/100_val.json\", orient=\"records\", lines=True)\n",
"df = pd.read_json(\"local_cache/100_val.json\", orient=\"records\", lines=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>title</th>\n",
" <th>question</th>\n",
" <th>context</th>\n",
" <th>is_impossible</th>\n",
" <th>answers</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Scottish_Parliament</td>\n",
" <td>What consequence of establishing the Scottish ...</td>\n",
" <td>A procedural consequence of the establishment ...</td>\n",
" <td>False</td>\n",
" <td>[able to vote on domestic legislation that app...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Imperialism</td>\n",
" <td>Imperialism is less often associated with whic...</td>\n",
" <td>The principles of imperialism are often genera...</td>\n",
" <td>True</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Economic_inequality</td>\n",
" <td>What issues can't prevent women from working o...</td>\n",
" <td>When a persons capabilities are lowered, they...</td>\n",
" <td>True</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Southern_California</td>\n",
" <td>What county are Los Angeles, Orange, San Diego...</td>\n",
" <td>Its counties of Los Angeles, Orange, San Diego...</td>\n",
" <td>True</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>French_and_Indian_War</td>\n",
" <td>When was the deportation of Canadians?</td>\n",
" <td>Britain gained control of French Canada and Ac...</td>\n",
" <td>True</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>Geology</td>\n",
" <td>In the layered Earth model, what is the inner ...</td>\n",
" <td>Seismologists can use the arrival times of sei...</td>\n",
" <td>True</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>Prime_number</td>\n",
" <td>What type of value would the Basel function ha...</td>\n",
" <td>The zeta function is closely related to prime ...</td>\n",
" <td>True</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>Fresno,_California</td>\n",
" <td>What does the San Joaquin Valley Railroad cros...</td>\n",
" <td>Passenger rail service is provided by Amtrak S...</td>\n",
" <td>True</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>Victoria_(Australia)</td>\n",
" <td>What party rules in Melbourne's inner regions?</td>\n",
" <td>The centre-left Australian Labor Party (ALP), ...</td>\n",
" <td>False</td>\n",
" <td>[The Greens, Australian Greens, Greens]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>Immune_system</td>\n",
" <td>The speed of the killing response of the human...</td>\n",
" <td>In humans, this response is activated by compl...</td>\n",
" <td>False</td>\n",
" <td>[signal amplification, signal amplification, s...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>100 rows × 5 columns</p>\n",
"</div>"
],
"text/plain": [
" title question \\\n",
"0 Scottish_Parliament What consequence of establishing the Scottish ... \n",
"1 Imperialism Imperialism is less often associated with whic... \n",
"2 Economic_inequality What issues can't prevent women from working o... \n",
"3 Southern_California What county are Los Angeles, Orange, San Diego... \n",
"4 French_and_Indian_War When was the deportation of Canadians? \n",
".. ... ... \n",
"95 Geology In the layered Earth model, what is the inner ... \n",
"96 Prime_number What type of value would the Basel function ha... \n",
"97 Fresno,_California What does the San Joaquin Valley Railroad cros... \n",
"98 Victoria_(Australia) What party rules in Melbourne's inner regions? \n",
"99 Immune_system The speed of the killing response of the human... \n",
"\n",
" context is_impossible \\\n",
"0 A procedural consequence of the establishment ... False \n",
"1 The principles of imperialism are often genera... True \n",
"2 When a persons capabilities are lowered, they... True \n",
"3 Its counties of Los Angeles, Orange, San Diego... True \n",
"4 Britain gained control of French Canada and Ac... True \n",
".. ... ... \n",
"95 Seismologists can use the arrival times of sei... True \n",
"96 The zeta function is closely related to prime ... True \n",
"97 Passenger rail service is provided by Amtrak S... True \n",
"98 The centre-left Australian Labor Party (ALP), ... False \n",
"99 In humans, this response is activated by compl... False \n",
"\n",
" answers \n",
"0 [able to vote on domestic legislation that app... \n",
"1 [] \n",
"2 [] \n",
"3 [] \n",
"4 [] \n",
".. ... \n",
"95 [] \n",
"96 [] \n",
"97 [] \n",
"98 [The Greens, Australian Greens, Greens] \n",
"99 [signal amplification, signal amplification, s... \n",
"\n",
"[100 rows x 5 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Fine-tuning and Answering using Fine-tuned model\n",
"\n",
"For the complete fine-tuning process, please refer to the [OpenAI Fine-Tuning Docs](https://platform.openai.com/docs/guides/fine-tuning/use-a-fine-tuned-model).\n",
"\n",
"### 4.1 Prepare the Fine-Tuning Data\n",
"\n",
"We need to prepare the data for fine-tuning. We'll use a few samples from the train split of the same dataset as before, and add the expected answer as the assistant message after each question and context. This will help the model learn to retrieve the answer from the context. \n",
"\n",
"Our instruction prompt is the same as before, and so is the system prompt. "
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def dataframe_to_jsonl(df):\n",
"    def create_jsonl_entry(row):\n",
"        answer = row[\"answers\"][0] if row[\"answers\"] else \"I don't know\"\n",
"        messages = [\n",
"            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
"            {\n",
"                \"role\": \"user\",\n",
"                \"content\": f\"\"\"Answer the following Question based on the Context only. Only answer from the Context. If you don't know the answer, say 'I don't know'.\n",
"    Question: {row.question}\\n\\n\n",
"    Context: {row.context}\\n\\n\n",
"    Answer:\\n\"\"\",\n",
"            },\n",
"            {\"role\": \"assistant\", \"content\": answer},\n",
"        ]\n",
"        return json.dumps({\"messages\": messages})\n",
"\n",
"    jsonl_output = df.apply(create_jsonl_entry, axis=1)\n",
"    return \"\\n\".join(jsonl_output)\n",
"\n",
"train_sample = get_diverse_sample(train_df, sample_size=100, random_state=42)\n",
"\n",
"with open(\"local_cache/100_train.jsonl\", \"w\") as f:\n",
"    f.write(dataframe_to_jsonl(train_sample))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Tip: 💡 Verify the Fine-Tuning Data**\n",
"\n",
"You can see this [cookbook](https://github.com/openai/openai-cookbook/blob/main/examples/Chat_finetuning_data_prep.ipynb) for more details on how to prepare the data for fine-tuning.\n",
"\n",
"### 4.2 Fine-Tune OpenAI Model\n",
"\n",
"If you're new to OpenAI Model Fine-Tuning, please refer to the [How to finetune Chat models](https://github.com/openai/openai-cookbook/blob/448a0595b84ced3bebc9a1568b625e748f9c1d60/examples/How_to_finetune_chat_models.ipynb) notebook. You can also refer to the [OpenAI Fine-Tuning Docs](https://platform.openai.com/docs/guides/fine-tuning/use-a-fine-tuned-model) for more details."
]
},
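{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of what verifying the fine-tuning data can look like, the cell below checks that every line of a JSONL string parses and ends with an assistant message. This is only a sketch: the helper name `validate_chat_jsonl` and the particular checks are assumptions made for this example, not the linked cookbook's implementation, and the sample records are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"ALLOWED_ROLES = {'system', 'user', 'assistant'}\n",
"\n",
"def validate_chat_jsonl(text):\n",
"    # Hypothetical helper: returns (line_number, problem) tuples for a JSONL string\n",
"    problems = []\n",
"    for line_no, line in enumerate(text.splitlines(), start=1):\n",
"        try:\n",
"            record = json.loads(line)\n",
"        except json.JSONDecodeError:\n",
"            problems.append((line_no, 'not valid JSON'))\n",
"            continue\n",
"        messages = record.get('messages') if isinstance(record, dict) else None\n",
"        if not isinstance(messages, list) or not messages:\n",
"            problems.append((line_no, 'missing or empty messages list'))\n",
"        elif any(not isinstance(m, dict) or m.get('role') not in ALLOWED_ROLES for m in messages):\n",
"            problems.append((line_no, 'unexpected role'))\n",
"        elif messages[-1]['role'] != 'assistant':\n",
"            problems.append((line_no, 'last message is not from the assistant'))\n",
"    return problems\n",
"\n",
"# Two made-up records: the second one lacks the assistant reply\n",
"sample = '\\n'.join([\n",
"    json.dumps({'messages': [{'role': 'user', 'content': 'Q'}, {'role': 'assistant', 'content': 'A'}]}),\n",
"    json.dumps({'messages': [{'role': 'user', 'content': 'Q'}]}),\n",
"])\n",
"print(validate_chat_jsonl(sample))"
]
},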
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class OpenAIFineTuner:\n",
"    \"\"\"\n",
"    Class to fine tune OpenAI models\n",
"    \"\"\"\n",
"    def __init__(self, training_file_path, model_name, suffix):\n",
"        self.training_file_path = training_file_path\n",
"        self.model_name = model_name\n",
"        self.suffix = suffix\n",
"        self.file_object = None\n",
"        self.fine_tuning_job = None\n",
"        self.model_id = None\n",
"\n",
"    def create_openai_file(self):\n",
"        # Upload the training file in binary mode\n",
"        with open(self.training_file_path, \"rb\") as f:\n",
"            self.file_object = client.files.create(\n",
"                file=f,\n",
"                purpose=\"fine-tune\",\n",
"            )\n",
"\n",
"    def wait_for_file_processing(self, sleep_time=20):\n",
"        while self.file_object.status != 'processed':\n",
"            time.sleep(sleep_time)\n",
"            # Response objects of the v1 client have no refresh(); fetch the file again\n",
"            self.file_object = client.files.retrieve(self.file_object.id)\n",
"            print(\"File Status: \", self.file_object.status)\n",
"\n",
"    def create_fine_tuning_job(self):\n",
"        self.fine_tuning_job = client.fine_tuning.jobs.create(\n",
"            training_file=self.file_object.id,\n",
"            model=self.model_name,\n",
"            suffix=self.suffix,\n",
"        )\n",
"\n",
"    def wait_for_fine_tuning(self, sleep_time=45):\n",
"        while self.fine_tuning_job.status != 'succeeded':\n",
"            time.sleep(sleep_time)\n",
"            self.fine_tuning_job = client.fine_tuning.jobs.retrieve(self.fine_tuning_job.id)\n",
"            print(\"Job Status: \", self.fine_tuning_job.status)\n",
"            # Stop polling if the job can no longer succeed\n",
"            if self.fine_tuning_job.status in ('failed', 'cancelled'):\n",
"                raise RuntimeError(f\"Fine-tuning job ended with status: {self.fine_tuning_job.status}\")\n",
"\n",
"    def retrieve_fine_tuned_model(self):\n",
"        self.model_id = client.fine_tuning.jobs.retrieve(self.fine_tuning_job.id).fine_tuned_model\n",
"        return self.model_id\n",
"\n",
"    def fine_tune_model(self):\n",
"        self.create_openai_file()\n",
"        self.wait_for_file_processing()\n",
"        self.create_fine_tuning_job()\n",
"        self.wait_for_fine_tuning()\n",
"        return self.retrieve_fine_tuned_model()\n",
"\n",
"fine_tuner = OpenAIFineTuner(\n",
"    training_file_path=\"local_cache/100_train.jsonl\",\n",
"    model_name=\"gpt-3.5-turbo\",\n",
"    suffix=\"100trn20230907\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"⏰ **Time to run: ~10-20 minutes**, 🛜 Needs Internet Connection"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"model_id = fine_tuner.fine_tune_model()\n",
"model_id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 4.2.1 Try out the Fine-Tuned Model\n",
"\n",
"Let's try out the fine-tuned model on the same validation set as before. You'll use the same prompt as before, but you will use the fine-tuned model instead of the base model. Before you do that, you can make a simple call to get a sense of how the fine-tuned model is doing."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"completion = client.chat.completions.create(\n",
"    model=model_id,\n",
"    messages=[\n",
"        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
"        {\"role\": \"user\", \"content\": \"Hello!\"},\n",
"        {\"role\": \"assistant\", \"content\": \"Hi, how can I help you today?\"},\n",
"        {\n",
"            \"role\": \"user\",\n",
"            \"content\": \"Can you answer the following question based on the given context? If not, say, I don't know:\\n\\nQuestion: What is the capital of France?\\n\\nContext: The capital of Mars is Gaia. Answer:\",\n",
"        },\n",
"    ],\n",
")\n",
"\n",
"print(completion.choices[0].message)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.3 Answer Using the Fine-Tuned Model\n",
"\n",
"This is the same as before, but you'll use the fine-tuned model instead of the base model.\n",
"\n",
"⏰ **Time to run: ~5 min**, 🛜 Needs Internet Connection"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"df[\"ft_generated_answer\"] = df.progress_apply(answer_question, model=model_id, axis=1)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Evaluation: How well does the model perform?\n",
"\n",
"To evaluate the model's performance, compare the predicted answer to the actual answers -- if any of the actual answers are present in the predicted answer, then it's a match. We've also created error categories to help you understand where the model is struggling.\n",
"\n",
"When we know that a correct answer exists in the context, there are 3 possible outcomes:\n",
"\n",
"1. ✅ **Answered Correctly**: The model responded with the correct answer. It may have also included other answers that were not in the context.\n",
"2. ❎ **Skipped**: The model responded with \"I don't know\" (IDK) while the answer was present in the context. It's better for the model to say \"I don't know\" than to give the wrong answer. In our design, we know that a true answer exists and hence we're able to measure it -- this is not always the case. *This is a model error*, but we exclude it from the overall error rate. \n",
"3. ❌ **Wrong**: The model responded with an incorrect answer. **This is a model ERROR.**\n",
"\n",
"When we know that a correct answer does not exist in the context, there are 2 possible outcomes:\n",
"\n",
"4. ❌ **Hallucination**: The model responded with an answer, when \"I don't know\" was expected. **This is a model ERROR.** \n",
"5. ✅ **I don't know**: The model responded with \"I don't know\" (IDK) and the answer was not present in the context. **This is a model WIN.**"
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"class Evaluator:\n",
"    def __init__(self, df):\n",
"        self.df = df\n",
"        self.y_pred = pd.Series(dtype=object)  # Initialize as empty Series\n",
"        self.labels_answer_expected = [\"✅ Answered Correctly\", \"❎ Skipped\", \"❌ Wrong Answer\"]\n",
"        self.labels_idk_expected = [\"❌ Hallucination\", \"✅ I don't know\"]\n",
"\n",
"    def _evaluate_answer_expected(self, row, answers_column):\n",
"        generated_answer = row[answers_column].lower()\n",
"        actual_answers = [ans.lower() for ans in row[\"answers\"]]\n",
"        return (\n",
"            \"✅ Answered Correctly\" if any(ans in generated_answer for ans in actual_answers)\n",
"            else \"❎ Skipped\" if generated_answer == \"i don't know\"\n",
"            else \"❌ Wrong Answer\"\n",
"        )\n",
"\n",
"    def _evaluate_idk_expected(self, row, answers_column):\n",
"        generated_answer = row[answers_column].lower()\n",
"        return (\n",
"            \"❌ Hallucination\" if generated_answer != \"i don't know\"\n",
"            else \"✅ I don't know\"\n",
"        )\n",
"\n",
"    def _evaluate_single_row(self, row, answers_column):\n",
"        is_impossible = row[\"is_impossible\"]\n",
"        return (\n",
"            self._evaluate_answer_expected(row, answers_column) if not is_impossible\n",
"            else self._evaluate_idk_expected(row, answers_column)\n",
"        )\n",
"\n",
"    def evaluate_model(self, answers_column=\"generated_answer\"):\n",
"        self.y_pred = pd.Series(self.df.apply(self._evaluate_single_row, answers_column=answers_column, axis=1))\n",
"        freq_series = self.y_pred.value_counts()\n",
"\n",
"        # Counting rows for each scenario\n",
"        total_answer_expected = len(self.df[self.df['is_impossible'] == False])\n",
"        total_idk_expected = len(self.df[self.df['is_impossible'] == True])\n",
"\n",
"        freq_answer_expected = (freq_series / total_answer_expected * 100).round(2).reindex(self.labels_answer_expected, fill_value=0)\n",
"        freq_idk_expected = (freq_series / total_idk_expected * 100).round(2).reindex(self.labels_idk_expected, fill_value=0)\n",
"        return freq_answer_expected.to_dict(), freq_idk_expected.to_dict()\n",
"\n",
"    def print_eval(self):\n",
"        answer_columns = [\"generated_answer\", \"ft_generated_answer\"]\n",
"        baseline_correctness, baseline_idk = self.evaluate_model(answer_columns[0])\n",
"        ft_correctness, ft_idk = self.evaluate_model(answer_columns[1])\n",
"        print(\"When the model should answer correctly:\")\n",
"        # evaluate_model returns plain dicts, so build the comparison tables directly\n",
"        eval_df = pd.DataFrame({\"Baseline\": baseline_correctness, \"Fine-Tuned\": ft_correctness})\n",
"        print(eval_df)\n",
"        print(\"\\n\\n\\nWhen the model should say 'I don't know':\")\n",
"        eval_df = pd.DataFrame({\"Baseline\": baseline_idk, \"Fine-Tuned\": ft_idk})\n",
"        print(eval_df)\n",
"\n",
"    def plot_model_comparison(self, answer_columns=[\"generated_answer\", \"ft_generated_answer\"], scenario=\"answer_expected\", nice_names=[\"Baseline\", \"Fine-Tuned\"]):\n",
"\n",
"        results = []\n",
"        for col in answer_columns:\n",
"            answer_expected, idk_expected = self.evaluate_model(col)\n",
"            if scenario == \"answer_expected\":\n",
"                results.append(answer_expected)\n",
"            elif scenario == \"idk_expected\":\n",
"                results.append(idk_expected)\n",
"            else:\n",
"                raise ValueError(\"Invalid scenario\")\n",
"\n",
"\n",
"        results_df = pd.DataFrame(results, index=nice_names)\n",
"        if scenario == \"answer_expected\":\n",
"            results_df = results_df.reindex(self.labels_answer_expected, axis=1)\n",
"        elif scenario == \"idk_expected\":\n",
"            results_df = results_df.reindex(self.labels_idk_expected, axis=1)\n",
"\n",
"        melted_df = results_df.reset_index().melt(id_vars='index', var_name='Status', value_name='Frequency')\n",
"        sns.set_theme(style=\"whitegrid\", palette=\"icefire\")\n",
"        g = sns.catplot(data=melted_df, x='Frequency', y='index', hue='Status', kind='bar', height=5, aspect=2)\n",
"\n",
"        # Annotating each bar\n",
"        for p in g.ax.patches:\n",
"            g.ax.annotate(f\"{p.get_width():.0f}%\", (p.get_width()+5, p.get_y() + p.get_height() / 2),\n",
"                          textcoords=\"offset points\",\n",
"                          xytext=(0, 0),\n",
"                          ha='center', va='center')\n",
"        plt.ylabel(\"Model\")\n",
"        plt.xlabel(\"Percentage\")\n",
"        plt.xlim(0, 100)\n",
"        plt.tight_layout()\n",
"        plt.title(scenario.replace(\"_\", \" \").title())\n",
"        plt.show()\n",
"\n",
"\n",
"# Compare the results by merging into one dataframe\n",
"evaluator = Evaluator(df)\n",
"# evaluator.evaluate_model(answers_column=\"ft_generated_answer\")\n",
"# evaluator.plot_model_comparison([\"generated_answer\", \"ft_generated_answer\"], scenario=\"answer_expected\", nice_names=[\"Baseline\", \"Fine-Tuned\"])"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"# Optionally, save the results to a JSON file\n",
"df.to_json(\"local_cache/100_val_ft.json\", orient=\"records\", lines=True)\n",
"df = pd.read_json(\"local_cache/100_val_ft.json\", orient=\"records\", lines=True)"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1203.25x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluator.plot_model_comparison([\"generated_answer\", \"ft_generated_answer\"], scenario=\"answer_expected\", nice_names=[\"Baseline\", \"Fine-Tuned\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that the fine-tuned model skips questions more often -- and makes fewer mistakes. This is because the fine-tuned model is more conservative and skips questions when it's not sure."
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1158.25x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluator.plot_model_comparison([\"generated_answer\", \"ft_generated_answer\"], scenario=\"idk_expected\", nice_names=[\"Baseline\", \"Fine-Tuned\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that the fine-tuned model has learnt to say \"I don't know\" far more reliably than the base model does with the prompt alone. In other words, the model has gotten good at skipping questions.\n",
"\n",
"### Observations\n",
"\n",
"1. The fine-tuned model is better at saying \"I don't know\"\n",
"2. Hallucinations drop from 100% to 15% with fine-tuning\n",
"3. Wrong answers drop from 17% to 6% with fine-tuning\n",
"\n",
"**Correct answers also drop from 83% to 60% with fine-tuning** - this is because the fine-tuned model is **more conservative** and says \"I don't know\" more often. This is a good thing because it's better to say \"I don't know\" than to give a wrong answer.\n",
"\n",
"That said, we want to improve the correctness of the model, even if that increases the hallucinations. We're looking for a model that is both correct and conservative, striking a balance between the two. We'll use Qdrant and Few-Shot Learning to achieve this."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"**💪 You're 2/3rds of the way there! Keep reading!**\n",
"\n",
"## Section B: Few-Shot Learning\n",
"\n",
"We'll select a few examples from the dataset, including cases where the answer is not present in the context. We'll then use these examples to create a prompt that we can use to fine-tune the model. We'll then measure the performance of the fine-tuned model.\n",
"\n",
"**What is next?**\n",
"\n",
"6. Fine-Tuning OpenAI Model with Qdrant\n",
" 6.1 Embed the Fine-Tuning Data\n",
" 6.2 Embedding the Questions\n",
"7. Using Qdrant to Improve RAG Prompt\n",
"8. Evaluation\n",
"\n",
"\n",
"## 6. Fine-Tuning OpenAI Model with Qdrant\n",
"\n",
"So far, we've been using the OpenAI model to answer questions without showing it any examples of the answer. The previous step made it work better on in-context examples, while this one helps it generalize to unseen data and learn when to say \"I don't know\" and when to give an answer.\n",
"\n",
"This is where few-shot learning comes in!\n",
"\n",
"Few-shot learning is a type of transfer learning in which we provide a few examples of the answers we're looking for, including questions where the answer is not present in the context. From these examples, the model learns both how to answer and when to say \"I don't know\"."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.1 Embed the Training Data\n",
"\n",
"Embeddings are a way to represent sentences as an array of floats. We'll use the embeddings to find the most similar questions to the ones we're looking for."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from qdrant_client import QdrantClient\n",
"from qdrant_client.http import models\n",
"from qdrant_client.http.models import PointStruct\n",
"from qdrant_client.http.models import Distance, VectorParams"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we've the Qdrant imports in place, "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"qdrant_client = QdrantClient(\n",
" url=os.getenv(\"QDRANT_URL\"), api_key=os.getenv(\"QDRANT_API_KEY\"), timeout=6000, prefer_grpc=True\n",
")\n",
"\n",
"collection_name = \"squadv2-cookbook\"\n",
"\n",
"# # Create the collection, run this only once\n",
"# qdrant_client.recreate_collection(\n",
"# collection_name=collection_name,\n",
"# vectors_config=VectorParams(size=384, distance=Distance.COSINE),\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastembed.embedding import DefaultEmbedding\n",
"from typing import List\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm.notebook import tqdm\n",
"\n",
"tqdm.pandas()\n",
"\n",
"embedding_model = DefaultEmbedding()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.2 Embedding the Questions\n",
"\n",
"Next, you'll embed the entire training set questions. You'll use the question to question similarity to find the most similar questions to the question we're looking for. This is a workflow which is used in RAG to leverage the OpenAI model ability of incontext learning with more examples. This is what we call Few Shot Learning here.\n",
"\n",
"**❗️⏰ Important Note: This step can take up to 3 hours to complete. Please be patient. If you see Out of Memory errors or Kernel Crashes, please reduce the batch size to 32, restart the kernel and run the notebook again. This code needs to be run only ONCE.**\n",
"\n",
"## Function Breakdown for `generate_points_from_dataframe`\n",
"\n",
"1. **Initialization**: `batch_size = 512` and `total_batches` set the stage for how many questions will be processed in one go. This is to prevent memory issues. If your machine can handle more, feel free to increase the batch size. If your kernel crashes, reduce the batch size to 32 and try again.\n",
"2. **Progress Bar**: `tqdm` gives you a nice progress bar so you don't fall asleep.\n",
"3. **Batch Loop**: The for-loop iterates through batches. `start_idx` and `end_idx` define the slice of the DataFrame to process.\n",
"4. **Generate Embeddings**: `batch_embeddings = embedding_model.embed(batch, batch_size=batch_size)` - This is where the magic happens. Your questions get turned into embeddings.\n",
"5. **PointStruct Generation**: Using `.progress_apply`, it turns each row into a `PointStruct` object. This includes an ID, the embedding vector, and other metadata.\n",
"\n",
"Returns the list of `PointStruct` objects, which can be used to create a collection in Qdrant."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_points_from_dataframe(df: pd.DataFrame) -> List[PointStruct]:\n",
" batch_size = 512\n",
" questions = df[\"question\"].tolist()\n",
" total_batches = len(questions) // batch_size + 1\n",
" \n",
" pbar = tqdm(total=len(questions), desc=\"Generating embeddings\")\n",
" \n",
" # Generate embeddings in batches to improve performance\n",
" embeddings = []\n",
" for i in range(total_batches):\n",
" start_idx = i * batch_size\n",
" end_idx = min((i + 1) * batch_size, len(questions))\n",
" batch = questions[start_idx:end_idx]\n",
" \n",
" batch_embeddings = embedding_model.embed(batch, batch_size=batch_size)\n",
" embeddings.extend(batch_embeddings)\n",
" pbar.update(len(batch))\n",
" \n",
" pbar.close()\n",
" \n",
" # Convert embeddings to list of lists\n",
" embeddings_list = [embedding.tolist() for embedding in embeddings]\n",
" \n",
" # Create a temporary DataFrame to hold the embeddings and existing DataFrame columns\n",
" temp_df = df.copy()\n",
" temp_df[\"embeddings\"] = embeddings_list\n",
" temp_df[\"id\"] = temp_df.index\n",
" \n",
" # Generate PointStruct objects using DataFrame apply method\n",
" points = temp_df.progress_apply(\n",
" lambda row: PointStruct(\n",
" id=row[\"id\"],\n",
" vector=row[\"embeddings\"],\n",
" payload={\n",
" \"question\": row[\"question\"],\n",
" \"title\": row[\"title\"],\n",
" \"context\": row[\"context\"],\n",
" \"is_impossible\": row[\"is_impossible\"],\n",
" \"answers\": row[\"answers\"],\n",
" },\n",
" ),\n",
" axis=1,\n",
" ).tolist()\n",
"\n",
" return points\n",
"\n",
"points = generate_points_from_dataframe(train_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Upload the Embeddings to Qdrant\n",
"\n",
"Note that configuring Qdrant is outside the scope of this notebook. Please refer to the [Qdrant](https://qdrant.tech) for more information. We used a timeout of 600 seconds for the upload, and grpc compression to speed up the upload."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"operation_info = qdrant_client.upsert(\n",
" collection_name=collection_name, wait=True, points=points\n",
")\n",
"print(operation_info)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Using Qdrant to Improve RAG Prompt\n",
"\n",
"Now that we've uploaded the embeddings to Qdrant, we can use Qdrant to find the most similar questions to the question we're looking for. We'll use the top 5 most similar questions to create a prompt that we can use to fine-tune the model. We'll then measure the performance of the fine-tuned model on the same validation set, but with few shot prompting!\n",
"\n",
"Our main function `get_few_shot_prompt` serves as the workhorse for generating prompts for few-shot learning. It does this by retrieving similar questions from Qdrant - a vector search engine, using an embeddings model. Here is the high-level workflow:\n",
"\n",
"1. Retrieve similar questions from Qdrant where the **answer is present** in the context\n",
"2. Retrieve similar questions from Qdrant where the **answer is IMPOSSIBLE** i.e. the expected answer is \"I don't know\" to find in the context\n",
"3. Create a prompt using the retrieved questions\n",
"4. Fine-tune the model using the prompt\n",
"5. Evaluate the fine-tuned model on the validation set with the same prompting technique"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_few_shot_prompt(row):\n",
"\n",
" query, row_context = row[\"question\"], row[\"context\"]\n",
"\n",
" embeddings = list(embedding_model.embed([query]))\n",
" query_embedding = embeddings[0].tolist()\n",
"\n",
" num_of_qa_to_retrieve = 5\n",
"\n",
" # Query Qdrant for similar questions that have an answer\n",
" q1 = qdrant_client.search(\n",
" collection_name=collection_name,\n",
" query_vector=query_embedding,\n",
" with_payload=True,\n",
" limit=num_of_qa_to_retrieve,\n",
" query_filter=models.Filter(\n",
" must=[\n",
" models.FieldCondition(\n",
" key=\"is_impossible\",\n",
" match=models.MatchValue(\n",
" value=False,\n",
" ),\n",
" ),\n",
" ],\n",
" )\n",
" )\n",
"\n",
" # Query Qdrant for similar questions that are IMPOSSIBLE to answer\n",
" q2 = qdrant_client.search(\n",
" collection_name=collection_name,\n",
" query_vector=query_embedding,\n",
" query_filter=models.Filter(\n",
" must=[\n",
" models.FieldCondition(\n",
" key=\"is_impossible\",\n",
" match=models.MatchValue(\n",
" value=True,\n",
" ),\n",
" ),\n",
" ]\n",
" ),\n",
" with_payload=True,\n",
" limit=num_of_qa_to_retrieve,\n",
" )\n",
"\n",
"\n",
" instruction = \"\"\"Answer the following Question based on the Context only. Only answer from the Context. If you don't know the answer, say 'I don't know'.\\n\\n\"\"\"\n",
" # If there is a next best question, add it to the prompt\n",
" \n",
" def q_to_prompt(q):\n",
" question, context = q.payload[\"question\"], q.payload[\"context\"]\n",
" answer = q.payload[\"answers\"][0] if len(q.payload[\"answers\"]) > 0 else \"I don't know\"\n",
" return [\n",
" {\n",
" \"role\": \"user\", \n",
" \"content\": f\"\"\"Question: {question}\\n\\nContext: {context}\\n\\nAnswer:\"\"\"\n",
" },\n",
" {\"role\": \"assistant\", \"content\": answer},\n",
" ]\n",
"\n",
" rag_prompt = []\n",
" \n",
" if len(q1) >= 1:\n",
" rag_prompt += q_to_prompt(q1[1])\n",
" if len(q2) >= 1:\n",
" rag_prompt += q_to_prompt(q2[1])\n",
" if len(q1) >= 1:\n",
" rag_prompt += q_to_prompt(q1[2])\n",
" \n",
" \n",
"\n",
" rag_prompt += [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": f\"\"\"Question: {query}\\n\\nContext: {row_context}\\n\\nAnswer:\"\"\"\n",
" },\n",
" ]\n",
"\n",
" rag_prompt = [{\"role\": \"system\", \"content\": instruction}] + rag_prompt\n",
" return rag_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ⏰ Time: 2 min\n",
"train_sample[\"few_shot_prompt\"] = train_sample.progress_apply(get_few_shot_prompt, axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Fine-Tuning OpenAI Model with Qdrant\n",
"\n",
"### 7.1 Upload the Fine-Tuning Data to OpenAI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prepare the OpenAI File format i.e. JSONL from train_sample\n",
"def dataframe_to_jsonl(df):\n",
" def create_jsonl_entry(row):\n",
" messages = row[\"few_shot_prompt\"]\n",
" return json.dumps({\"messages\": messages})\n",
"\n",
" jsonl_output = df.progress_apply(create_jsonl_entry, axis=1)\n",
" return \"\\n\".join(jsonl_output)\n",
"\n",
"with open(\"local_cache/100_train_few_shot.jsonl\", \"w\") as f:\n",
" f.write(dataframe_to_jsonl(train_sample))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 7.2 Fine-Tune the Model\n",
"\n",
"⏰ **Time to run: ~15-30 minutes**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fine_tuner = OpenAIFineTuner(\n",
" training_file_path=\"local_cache/100_train_few_shot.jsonl\",\n",
" model_name=\"gpt-3.5-turbo\",\n",
" suffix=\"trnfewshot20230907\"\n",
" )\n",
"\n",
"model_id = fine_tuner.fine_tune_model()\n",
"model_id"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Let's try this out\n",
"completion = client.chat.completions.create(\n",
" model=model_id,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Can you answer the following question based on the given context? If not, say, I don't know:\\n\\nQuestion: What is the capital of France?\\n\\nContext: The capital of Mars is Gaia. Answer:\",\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"I don't know\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Question: Where did Maharana Pratap die?\\n\\nContext: Rana Pratap's defiance of the mighty Mughal empire, almost alone and unaided by the other Rajput states, constitute a glorious saga of Rajput valour and the spirit of self sacrifice for cherished principles. Rana Pratap's methods of guerrilla warfare was later elaborated further by Malik Ambar, the Deccani general, and by Emperor Shivaji.\\nAnswer:\",\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"I don't know\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Question: Who did Rana Pratap fight against?\\n\\nContext: In stark contrast to other Rajput rulers who accommodated and formed alliances with the various Muslim dynasties in the subcontinent, by the time Pratap ascended to the throne, Mewar was going through a long standing conflict with the Mughals which started with the defeat of his grandfather Rana Sanga in the Battle of Khanwa in 1527 and continued with the defeat of his father Udai Singh II in Siege of Chittorgarh in 1568. Pratap Singh, gained distinction for his refusal to form any political alliance with the Mughal Empire and his resistance to Muslim domination. The conflicts between Pratap Singh and Akbar led to the Battle of Haldighati. Answer:\",\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"Akbar\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Question: Which state is Chittorgarh in?\\n\\nContext: Chittorgarh, located in the southern part of the state of Rajasthan, 233 km (144.8 mi) from Ajmer, midway between Delhi and Mumbai on the National Highway 8 (India) in the road network of Golden Quadrilateral. Chittorgarh is situated where National Highways No. 76 & 79 intersect. Answer:\",\n",
" },\n",
" ],\n",
")\n",
"print(\"Correct Answer: Rajasthan\\nModel Answer:\")\n",
"print(completion.choices[0].message)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"⏰ **Time to run: 5-15 min**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df[\"ft_generated_answer_few_shot\"] = df.progress_apply(answer_question, model=model_id, prompt_func=get_few_shot_prompt, axis=1)\n",
"df.to_json(\"local_cache/100_val_ft_few_shot.json\", orient=\"records\", lines=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Evaluation\n",
"\n",
"But how well does the model perform? Let's compare the results from the 3 different models we've looked at so far:"
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1203.25x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluator = Evaluator(df)\n",
"evaluator.plot_model_comparison([\"generated_answer\", \"ft_generated_answer\", \"ft_generated_answer_few_shot\"], scenario=\"answer_expected\", nice_names=[\"Baseline\", \"Fine-Tuned\", \"Fine-Tuned with Few-Shot\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is quite amazing -- we're able to get the best of both worlds! We're able to get the model to be both correct and conservative: \n",
"\n",
"1. The model is correct 83% of the time -- this is the same as the base model\n",
"2. The model gives the wrong answer only 8% of the time -- down from 17% with the base model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's look at the hallucinations. We want to reduce the hallucinations, but not at the cost of correctness. We want to strike a balance between the two. We've struck a good balance here:\n",
"\n",
"1. The model hallucinates 53% of the time -- down from 100% with the base model\n",
"2. The model says \"I don't know\" 47% of the time -- up from NEVER with the base model"
]
},
{
"cell_type": "code",
"execution_count": 202,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1158.25x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluator.plot_model_comparison([\"generated_answer\", \"ft_generated_answer\", \"ft_generated_answer_few_shot\"], scenario=\"idk_expected\", nice_names=[\"Baseline\", \"Fine-Tuned\", \"Fine-Tuned with Few-Shot\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Few Shot Fine-Tuning with Qdrant is a great way to control and steer the performance of your RAG system. Here, we made the model less conservative compared to zero shot and more confident by using Qdrant to find similar questions. \n",
"\n",
"You can also use Qdrant to make the model more conservative. We did this by giving examples of questions where the answer is not present in the context. \n",
"This is biasing the model to say \"I don't know\" more often. \n",
"\n",
"Similarly, one can also use Qdrant to make the model more confident by giving examples of questions where the answer is present in the context. This biases the model to give an answer more often. The trade-off is that the model will also hallucinate more often.\n",
"\n",
"You can make this trade off by adjusting the training data: distribution of questions and examples, as well as the kind and number of examples you retrieve from Qdrant.\n",
"\n",
"## 9. Conclusion\n",
"\n",
"In this notebook, we've demonstrated how to fine-tune OpenAI models for specific use-cases. We've also demonstrated how to use Qdrant and Few-Shot Learning to improve the performance of the model.\n",
"\n",
"### Aggregate Results\n",
"\n",
"So far, we've looked at the results for each scenario separately, i.e. each scenario summed to 100. Let's look at the results as an aggregate to get a broader sense of how the model is performing:\n",
"\n",
"| Category | Base | Fine-Tuned | Fine-Tuned with Qdrant |\n",
"| --- | --- | --- | --- |\n",
"| Correct | 44% | 32% | 44% |\n",
"| Skipped | 0% | 18% | 5% |\n",
"| Wrong | 9% | 3% | 4% |\n",
"| Hallucination | 47% | 7% | 25% |\n",
"| I don't know | 0% | 40% | 22% |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Observations\n",
"\n",
"#### Compared to base model\n",
"1. The few shot fine-tuned with Qdrant model is as good as the base model at answering questions where the answer is present in the context. \n",
"2. The few shot fine-tuned with Qdrant model is better at saying \"I don't know\" when the answer is not present in the context.\n",
"3. The few shot fine-tuned with Qdrant model is better at reducing hallucinations.\n",
"\n",
"\n",
"#### Compared to fine-tuned model\n",
"1. The few shot fine-tuned with Qdrant model gets more correct answers than the fine-tuned model: **83% of the questions are answered correctly vs 60%** for the fine-tuned model\n",
"2. The few shot fine-tuned with Qdrant model is better at deciding when to say \"I don't know\" when the answer is not present in the context. **34% skip rate for the plain fine-tuning mode, vs 9% for the few shot fine-tuned with Qdrant model**\n",
"\n",
"\n",
"Now, you should be able to:\n",
"\n",
"1. Notice the trade-offs between number of correct answers and hallucinations -- and how training dataset choice influences that!\n",
"2. Fine-tune OpenAI models for specific use-cases and use Qdrant to improve the performance of your RAG model\n",
"3. Get started on how to evaluate the performance of your RAG model"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "fst",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}