openai-cookbook/examples/fine-tuned_qa/olympics-3-train-qa.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<span style=\"color:orange; font-weight:bold\">Note: To answer questions based on text documents, we recommend the procedure in <a href=\"https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb\">Question Answering using Embeddings</a>. Some of the code below may rely on <a href=\"https://github.com/openai/openai-cookbook/tree/main/transition_guides_for_deprecated_API_endpoints\">deprecated API endpoints</a>.</span>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Train a fine-tuned model specialized for Q&A\n",
"This notebook uses the dataset of context, question and answer pairs to additionally create adversarial question and context pairs, where the question was not generated from that context. In those cases the model is trained to answer \"No appropriate context found to answer the question.\" We will also train a discriminator model, which predicts whether the question can be answered based on the context or not.\n",
"\n",
"We will also add hard adversarial examples, based either on semantically similar sections or on other sections of the same article."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>title</th>\n",
" <th>heading</th>\n",
" <th>content</th>\n",
" <th>tokens</th>\n",
" <th>context</th>\n",
" <th>questions</th>\n",
" <th>answers</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2020 Summer Olympics</td>\n",
" <td>Summary</td>\n",
" <td>The 2020 Summer Olympics (Japanese: 2020年夏季オリン...</td>\n",
" <td>713</td>\n",
" <td>2020 Summer Olympics\\nSummary\\n\\nThe 2020 Summ...</td>\n",
" <td>1. What is the 2020 Summer Olympics?\\n2. When ...</td>\n",
" <td>1. The 2020 Summer Olympics is an internationa...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2020 Summer Olympics</td>\n",
" <td>Host city selection</td>\n",
" <td>The International Olympic Committee (IOC) vote...</td>\n",
" <td>126</td>\n",
" <td>2020 Summer Olympics\\nHost city selection\\n\\nT...</td>\n",
" <td>1. \\n2. \\n3. \\n4.</td>\n",
" <td>1. What is the International Olympic Committee...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2020 Summer Olympics</td>\n",
" <td>Impact of the COVID-19 pandemic</td>\n",
" <td>In January 2020, concerns were raised about th...</td>\n",
" <td>369</td>\n",
" <td>2020 Summer Olympics\\nImpact of the COVID-19 p...</td>\n",
" <td>1. What was the COVID-19 pandemic?\\n2. How did...</td>\n",
" <td>1. The COVID-19 pandemic was a pandemic that o...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2020 Summer Olympics</td>\n",
" <td>Qualifying event cancellation and postponement</td>\n",
" <td>Concerns about the pandemic began to affect qu...</td>\n",
" <td>298</td>\n",
" <td>2020 Summer Olympics\\nQualifying event cancell...</td>\n",
" <td>1. What was the original location of the Asia ...</td>\n",
" <td>1. The original location of the Asia &amp; Oceania...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2020 Summer Olympics</td>\n",
" <td>Effect on doping tests</td>\n",
" <td>Mandatory doping tests were being severely res...</td>\n",
" <td>163</td>\n",
" <td>2020 Summer Olympics\\nEffect on doping tests\\n...</td>\n",
" <td>1. What was the COVID-19 pandemic?\\n2. What di...</td>\n",
" <td>1. The COVID-19 pandemic was a pandemic that o...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" title heading \\\n",
"0 2020 Summer Olympics Summary \n",
"1 2020 Summer Olympics Host city selection \n",
"2 2020 Summer Olympics Impact of the COVID-19 pandemic \n",
"3 2020 Summer Olympics Qualifying event cancellation and postponement \n",
"4 2020 Summer Olympics Effect on doping tests \n",
"\n",
" content tokens \\\n",
"0 The 2020 Summer Olympics (Japanese: 2020年夏季オリン... 713 \n",
"1 The International Olympic Committee (IOC) vote... 126 \n",
"2 In January 2020, concerns were raised about th... 369 \n",
"3 Concerns about the pandemic began to affect qu... 298 \n",
"4 Mandatory doping tests were being severely res... 163 \n",
"\n",
" context \\\n",
"0 2020 Summer Olympics\\nSummary\\n\\nThe 2020 Summ... \n",
"1 2020 Summer Olympics\\nHost city selection\\n\\nT... \n",
"2 2020 Summer Olympics\\nImpact of the COVID-19 p... \n",
"3 2020 Summer Olympics\\nQualifying event cancell... \n",
"4 2020 Summer Olympics\\nEffect on doping tests\\n... \n",
"\n",
" questions \\\n",
"0 1. What is the 2020 Summer Olympics?\\n2. When ... \n",
"1 1. \\n2. \\n3. \\n4. \n",
"2 1. What was the COVID-19 pandemic?\\n2. How did... \n",
"3 1. What was the original location of the Asia ... \n",
"4 1. What was the COVID-19 pandemic?\\n2. What di... \n",
"\n",
" answers \n",
"0 1. The 2020 Summer Olympics is an internationa... \n",
"1 1. What is the International Olympic Committee... \n",
"2 1. The COVID-19 pandemic was a pandemic that o... \n",
"3 1. The original location of the Asia & Oceania... \n",
"4 1. The COVID-19 pandemic was a pandemic that o... "
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import openai\n",
"import pandas as pd\n",
"df = pd.read_csv('olympics-data/olympics_qa.csv')\n",
"olympics_search_fileid = \"file-c3shd8wqF3vSCKaukW4Jr1TT\"\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split the sections into a training set and a test set."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3014, 754)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)\n",
"len(train_df), len(test_df)"
]
},
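{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split sizes follow from how scikit-learn rounds a fractional `test_size`: the test set gets `ceil(test_size * n)` rows and the training set gets the remainder. Here there are 3014 + 754 = 3768 sections, and 0.2 × 3768 = 753.6 rounds up to 754. A minimal sketch of that arithmetic; `split_sizes` is a hypothetical helper, not part of scikit-learn, and it only covers the case where `train_size` is left unset:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def split_sizes(n_samples, test_size):\n",
"    # Hypothetical helper: round the test fraction up, give the rest to training\n",
"    n_test = math.ceil(test_size * n_samples)\n",
"    n_train = n_samples - n_test\n",
"    return n_train, n_test\n",
"\n",
"print(split_sizes(3768, 0.2))  # (3014, 754)\n",
"```"
]
},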
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We check that the separator we intend to use (`->`) isn't present within any of the contexts."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.context.str.contains('->').sum()"
]
},
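{
"cell_type": "markdown",
"metadata": {},
"source": [
"The prompts built below separate their parts with `Question:`, `Answer:` and `Related:` markers, each starting on a new line, rather than with `->`. A similar check for those markers can catch contexts that already contain them. A minimal sketch, assuming the contexts are available as a list of strings (for example `df.context.tolist()`); `find_marker_collisions` is a hypothetical helper:\n",
"\n",
"```python\n",
"def find_marker_collisions(contexts, markers):\n",
"    # For each marker, count how many contexts already contain it.\n",
"    # A non-zero count means context text could be mistaken for prompt structure.\n",
"    return {m: sum(m in c for c in contexts) for m in markers}\n",
"\n",
"contexts = [\n",
"    'Tokyo hosted the 2020 Summer Olympics.',\n",
"    'A section that happens to contain a line break.\\nQuestion: is this a collision?',\n",
"]\n",
"markers = ['->', '\\nQuestion:', '\\nAnswer:', '\\n Related:']\n",
"print(find_marker_collisions(contexts, markers))\n",
"```"
]
},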
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.1 Create the fine-tuning datasets for Q&A and discriminator models\n",
"The fine-tuning dataset is created in the following way. For every question, answer and context triple we create:\n",
"- Positive example: the correct question, answer and context\n",
"- Negative examples:\n",
"  - a random negative example, where a random context is paired with the question\n",
"  - two hard negative examples\n",
"    - one drawn from another section of the same Wikipedia article\n",
"    - another drawn from the contexts that search ranks as most relevant to the question\n",
"\n",
"This process is noisy, as a question might sometimes be answerable from a different context too, but on average we hope this won't affect performance much.\n",
"\n",
"We apply the same dataset-creation process for both the discriminator and the Q&A model. We apply it separately to the training and test sets, so that examples from the training set don't appear in the test set."
]
},
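{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the record layout concrete, here is a minimal, API-free sketch of the examples produced for a single question. It assumes a toy list of `(title, context)` tuples and takes the search-based hard negative as a caller-supplied `similar_context` string, because the search endpoint used below is deprecated. `qa_record` and `build_examples` are hypothetical helpers, not part of this notebook's code:\n",
"\n",
"```python\n",
"import random\n",
"\n",
"NO_CONTEXT = ' No appropriate context found to answer the question.'\n",
"\n",
"def qa_record(context, question, completion):\n",
"    # Same prompt layout as the Q&A dataset: context, question, then the answer cue\n",
"    return {'prompt': f'{context}\\nQuestion: {question}\\nAnswer:', 'completion': completion}\n",
"\n",
"def build_examples(title, context, question, answer, sections, similar_context, rng):\n",
"    # Positive example: the question with the context it was generated from\n",
"    rows = [qa_record(context, question, f' {answer}')]\n",
"    # Hard negative: another section of the same article, if there is one\n",
"    same_article = [c for t, c in sections if t == title and c != context]\n",
"    if same_article:\n",
"        rows.append(qa_record(rng.choice(same_article), question, NO_CONTEXT))\n",
"    # Hard negative: a similar context (supplied by the caller in this sketch)\n",
"    if similar_context and similar_context != context:\n",
"        rows.append(qa_record(similar_context, question, NO_CONTEXT))\n",
"    # Random negative: any context other than the correct one\n",
"    others = [c for _, c in sections if c != context]\n",
"    rows.append(qa_record(rng.choice(others), question, NO_CONTEXT))\n",
"    return rows\n",
"\n",
"sections = [\n",
"    ('2020 Summer Olympics', 'The 2020 Summer Olympics were held in Tokyo.'),\n",
"    ('2020 Summer Olympics', 'The Games were postponed to 2021.'),\n",
"    ('Sputnik 1', 'Sputnik 1 was launched on 4 October 1957.'),\n",
"]\n",
"rows = build_examples('2020 Summer Olympics', sections[0][1],\n",
"                      'Where were the 2020 Summer Olympics held?', 'Tokyo',\n",
"                      sections, 'The 2016 Summer Olympics were held in Rio de Janeiro.',\n",
"                      random.Random(0))\n",
"for r in rows:\n",
"    print(repr(r['completion']))\n",
"```\n",
"\n",
"With these inputs the sketch yields one positive example and three negatives, mirroring `n_negative=1, add_related=True` in the notebook."
]
},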
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import re\n",
"\n",
"def get_random_similar_contexts(question, context, file_id=olympics_search_fileid, search_model='ada', max_rerank=10):\n",
"    \"\"\"\n",
"    Return a random context from the top search results for the question, excluding the correct context.\n",
"    These serve as hard negative examples. Returns an empty string if no candidate is found.\n",
"    \"\"\"\n",
"    try:\n",
"        # TODO: openai.Engine(search_model) is deprecated\n",
"        results = openai.Engine(search_model).search(\n",
"            search_model=search_model,\n",
"            query=question,\n",
"            max_rerank=max_rerank,\n",
"            file=file_id\n",
"        )\n",
"        candidates = []\n",
"        # Consider only the top 3 results, skipping the context the question was generated from\n",
"        for result in results['data'][:3]:\n",
"            if result['text'] == context:\n",
"                continue\n",
"            candidates.append(result['text'])\n",
"        if not candidates:\n",
"            return \"\"\n",
"        return random.choice(candidates)\n",
"    except Exception as e:\n",
"        print(e)\n",
"        return \"\"\n",
"\n",
"def strip_numbering(line):\n",
"    \"\"\"\n",
"    Remove the leading list number (such as 1. or 12.) from a generated question or answer.\n",
"    Also handles a doubled prefix such as 1.1. that results from prepending 1. below.\n",
"    \"\"\"\n",
"    return re.sub(r'^\\d+\\.(\\d+\\.)?\\s*', '', line).strip()\n",
"\n",
"def create_fine_tuning_dataset(df, discriminator=False, n_negative=1, add_related=False):\n",
"    \"\"\"\n",
"    Create a dataset for fine tuning the OpenAI model; either for a discriminator model, \n",
"    or a model specializing in Q&A, where it says if no relevant context is found.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df: pd.DataFrame\n",
"        The dataframe containing the question, answer and context pairs\n",
"    discriminator: bool\n",
"        Whether to create a dataset for the discriminator\n",
"    n_negative: int\n",
"        The number of random negative samples to add (using a random context)\n",
"    add_related: bool\n",
"        Whether to add the related contexts to the correct context. These are hard negative examples\n",
"\n",
"    Returns\n",
"    -------\n",
"    pd.DataFrame\n",
"        The dataframe containing the prompts and completions, ready for fine-tuning\n",
"    \"\"\"\n",
"    rows = []\n",
"    # Positive examples: each question paired with the context it was generated from\n",
"    for i, row in df.iterrows():\n",
"        # The first generated item may lack its number, so prepend 1. and let strip_numbering remove it\n",
"        for q, a in zip((\"1.\" + row.questions).split('\\n'), (\"1.\" + row.answers).split('\\n')):\n",
"            if len(q) > 10 and len(a) > 10:\n",
"                question, answer = strip_numbering(q), strip_numbering(a)\n",
"                if discriminator:\n",
"                    rows.append({\"prompt\": f\"{row.context}\\nQuestion: {question}\\n Related:\", \"completion\": \" yes\"})\n",
"                else:\n",
"                    rows.append({\"prompt\": f\"{row.context}\\nQuestion: {question}\\nAnswer:\", \"completion\": f\" {answer}\"})\n",
"\n",
"    # Negative examples: each question paired with a context it was not generated from\n",
"    for i, row in df.iterrows():\n",
"        for q in (\"1.\" + row.questions).split('\\n'):\n",
"            if len(q) > 10:\n",
"                question = strip_numbering(q)\n",
"                for j in range(n_negative + (2 if add_related else 0)):\n",
"                    random_context = \"\"\n",
"                    if j == 0 and add_related:\n",
"                        # Hard negative: another section from the same Wikipedia page\n",
"                        subset = df[(df.title == row.title) & (df.context != row.context)]\n",
"                        if len(subset) < 1:\n",
"                            continue\n",
"                        random_context = subset.sample(1).iloc[0].context\n",
"                    elif j == 1 and add_related:\n",
"                        # Hard negative: one of the most similar contexts according to search\n",
"                        random_context = get_random_similar_contexts(question, row.context, search_model='ada', max_rerank=10)\n",
"                        if not random_context:\n",
"                            continue\n",
"                    else:\n",
"                        # Random negative: any context other than the correct one\n",
"                        while True:\n",
"                            random_context = df.sample(1).iloc[0].context\n",
"                            if random_context != row.context:\n",
"                                break\n",
"                    if discriminator:\n",
"                        rows.append({\"prompt\": f\"{random_context}\\nQuestion: {question}\\n Related:\", \"completion\": \" no\"})\n",
"                    else:\n",
"                        rows.append({\"prompt\": f\"{random_context}\\nQuestion: {question}\\nAnswer:\", \"completion\": \" No appropriate context found to answer the question.\"})\n",
"\n",
"    return pd.DataFrame(rows)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": []
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"for name, is_disc in [('discriminator', True), ('qa', False)]:\n",
"    for train_test, dt in [('train', train_df), ('test', test_df)]:\n",
"        ft = create_fine_tuning_dataset(dt, discriminator=is_disc, n_negative=1, add_related=True)\n",
"        # Write into olympics-data/ so the paths match the fine-tuning commands below\n",
"        ft.to_json(f'olympics-data/{name}_{train_test}.jsonl', orient='records', lines=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We formatted the data according to the recommendations of the fine-tuning data preparation tool, which you can run with\n",
"> openai tools fine_tunes.prepare_data -f olympics-data/qa_train.jsonl\n",
"\n",
"We highly recommend using this tool, as it suggests improvements to your data formatting for fine-tuning.\n"
]
},
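{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tool's exact checks are its own; as a rough illustration of the kind of conventions it looks for, here is a minimal sketch that checks each record has exactly the `prompt` and `completion` keys, that completions start with a space, and that all prompts end with the same separator. `check_jsonl` is a hypothetical helper, not the tool's implementation; it accepts any iterable of JSONL lines, such as an open file:\n",
"\n",
"```python\n",
"import io\n",
"import json\n",
"\n",
"def check_jsonl(lines):\n",
"    # Parse one JSON object per non-empty line\n",
"    records = [json.loads(line) for line in lines if line.strip()]\n",
"    problems = []\n",
"    for i, r in enumerate(records):\n",
"        if set(r) != {'prompt', 'completion'}:\n",
"            problems.append(f'record {i}: unexpected keys {sorted(r)}')\n",
"        elif not r['completion'].startswith(' '):\n",
"            problems.append(f'record {i}: completion does not start with a space')\n",
"    # The text after the last newline of each prompt (e.g. Answer:) should be identical\n",
"    endings = {r.get('prompt', '').rsplit('\\n', 1)[-1] for r in records}\n",
"    if len(endings) > 1:\n",
"        problems.append(f'prompts end with different separators: {sorted(endings)}')\n",
"    return problems\n",
"\n",
"rows = [{'prompt': 'ctx\\nQuestion: q\\nAnswer:', 'completion': ' an answer'},\n",
"        {'prompt': 'ctx\\nQuestion: q\\nAnswer:', 'completion': 'no leading space'}]\n",
"sample = io.StringIO(''.join(json.dumps(r) + '\\n' for r in rows))\n",
"problems = check_jsonl(sample)\n",
"print(problems)\n",
"```\n",
"\n",
"On the real data you could run `with open('olympics-data/qa_train.jsonl') as f: print(check_jsonl(f))`."
]
},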
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.2 Submit the datasets for fine-tuning"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": []
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"!openai api fine_tunes.create -t \"olympics-data/discriminator_train.jsonl\" -v \"olympics-data/discriminator_test.jsonl\" --batch_size 16 --compute_classification_metrics --classification_positive_class \" yes\" --model ada"
]
},
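{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `--compute_classification_metrics` flag, together with `--classification_positive_class`, asks the fine-tuning job to report classification metrics on the validation file, treating the ` yes` completion as the positive class. As a reminder of what precision and recall mean here, a minimal sketch that computes them from true and predicted completions; `binary_metrics` is a hypothetical helper and the labels are made up for illustration:\n",
"\n",
"```python\n",
"def binary_metrics(y_true, y_pred, positive=' yes'):\n",
"    # Count outcomes with ' yes' (answerable) as the positive class\n",
"    pairs = list(zip(y_true, y_pred))\n",
"    tp = sum(t == positive and p == positive for t, p in pairs)\n",
"    fp = sum(t != positive and p == positive for t, p in pairs)\n",
"    fn = sum(t == positive and p != positive for t, p in pairs)\n",
"    correct = sum(t == p for t, p in pairs)\n",
"    precision = tp / (tp + fp) if tp + fp else 0.0\n",
"    recall = tp / (tp + fn) if tp + fn else 0.0\n",
"    return {'accuracy': correct / len(pairs), 'precision': precision, 'recall': recall}\n",
"\n",
"# Illustrative labels only, not results from the fine-tuning job\n",
"y_true = [' yes', ' no', ' no', ' no', ' yes', ' no']\n",
"y_pred = [' yes', ' no', ' yes', ' no', ' no', ' no']\n",
"print(binary_metrics(y_true, y_pred))\n",
"```\n",
"\n",
"Since each question in this dataset contributes one ` yes` example and up to three ` no` examples, the classes are imbalanced, so precision and recall on ` yes` say more than accuracy alone."
]
},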
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": []
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"!openai api fine_tunes.create -t \"olympics-data/qa_train.jsonl\" -v \"olympics-data/qa_test.jsonl\" --batch_size 16"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.3 Using the fine-tuned models\n",
"\n",
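"We will now use the fine-tuned discriminator and the fine-tuned Q&A model. By requesting logprobs, we can see how certain the discriminator is of a ` yes` vs ` no` answer. The model IDs below refer to the models fine-tuned when this notebook was written; replace them with the IDs of your own fine-tuned models."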
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<OpenAIObject at 0x7fe812e602b0> JSON: {\n",
" \" no\": -10.819577,\n",
" \" yes\": -2.045765e-05\n",
" }]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ft_discriminator = \"curie:ft-openai-internal-2021-08-23-23-58-57\"\n",
"ft_qa = \"curie:ft-openai-internal-2021-08-23-17-54-10\"\n",
"\n",
"def apply_ft_discriminator(context, question, discriminator_model):\n",
"    \"\"\"\n",
"    Apply the fine-tuned discriminator to a question, to assess whether it can be answered from the context.\n",
"    Returns the top log-probabilities per generated token: a list holding one dict, since max_tokens=1.\n",
"    \"\"\"\n",
"    prompt = f\"{context}\\nQuestion: {question}\\n Related:\"\n",
"    # Legacy fine-tuned models are completion models, so call the Completions endpoint (openai<1.0 SDK)\n",
"    result = openai.Completion.create(model=discriminator_model, prompt=prompt, max_tokens=1, temperature=0, top_p=1, n=1, logprobs=2)\n",
"    return result['choices'][0]['logprobs']['top_logprobs']\n",
"\n",
"apply_ft_discriminator('The first human-made object in space was the Soviet Union satellite Sputnik 1 on 4 October 1957.', \n",
"                       'What was the first human-made object in space?', ft_discriminator)"
]
},
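{
"cell_type": "markdown",
"metadata": {},
"source": [
"The values returned are log-probabilities of the most likely first tokens. Exponentiating them and renormalizing over ` yes` and ` no` gives the probability that the question is answerable from the context; for the output above, that is about 0.99998. A minimal sketch, assuming `top_logprobs` is a plain dict like the single element of the list printed above; `p_yes` is a hypothetical helper, and it reuses the -100 fallback for a missing token that `answer_question_conditionally` below also uses:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def p_yes(top_logprobs, floor=-100.0):\n",
"    # Log-probabilities of the two labels; a missing label gets a very low value\n",
"    yes = top_logprobs.get(' yes', floor)\n",
"    no = top_logprobs.get(' no', floor)\n",
"    # Softmax over the two labels: e^yes / (e^yes + e^no) = 1 / (1 + e^(no - yes))\n",
"    return 1.0 / (1.0 + math.exp(no - yes))\n",
"\n",
"print(round(p_yes({' no': -10.819577, ' yes': -2.045765e-05}), 5))  # 0.99998\n",
"```"
]
},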
{
"cell_type": "markdown",
"metadata": {},
"source": [
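"The discriminator assigns almost all of the probability to ` yes` (a log-probability of about -0.00002, versus about -10.8 for ` no`), even though this context and question have nothing to do with the Olympics data it was trained on. This suggests the model generalizes to different contexts and questions."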
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' The first human-made object in space was the Soviet Union satellite Sputnik 1 on 4 October 1957'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def apply_ft_qa_answer(context, question, answering_model):\n",
"    \"\"\"\n",
"    Apply the fine-tuned Q&A model to answer a question given the context\n",
"    \"\"\"\n",
"    prompt = f\"{context}\\nQuestion: {question}\\nAnswer:\"\n",
"    # Completions endpoint (openai<1.0 SDK); stop at the first period or newline for a one-sentence answer\n",
"    result = openai.Completion.create(model=answering_model, prompt=prompt, max_tokens=30, temperature=0, top_p=1, n=1, stop=['.','\\n'])\n",
"    return result['choices'][0]['text']\n",
"\n",
"apply_ft_qa_answer('The first human-made object in space was the Soviet Union satellite Sputnik 1 on 4 October 1957.', \n",
"                   'What was the first human-made object in space?', ft_qa)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model answers the question when the context is appropriate."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' The Soviet Union was the first country to successfully launch a satellite into space'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"apply_ft_qa_answer('The first human-made object in space was the Soviet Union satellite Sputnik 1 on 4 October 1957.',\n",
" 'What is impressive about the Soviet Union?', ft_qa)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' No appropriate context found to answer the question'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"apply_ft_qa_answer('The first human-made object in space was the Soviet Union satellite Sputnik 1 on 4 October 1957.',\n",
" 'How many cars were produced in the Soviet Union in 1970?', ft_qa)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These examples show that the model answers when the context supports the question, and otherwise says that no appropriate context was found."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also combine the discriminator with a base model or with the fine-tuned Q&A model. The discriminator then acts as a gate, deciding whether the question can be answered from the given context at all."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' Weather could cause a sport event to have no crowd'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
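"def answer_question_conditionally(answering_model, discriminator_model, context, question, discriminator_logprob_yes_modifier=0):\n",
"    \"\"\"\n",
"    Answer the question only if the discriminator judges it answerable from the context.\n",
"    A positive discriminator_logprob_yes_modifier makes answering more likely; a negative one, less likely.\n",
"    \"\"\"\n",
"    # top_logprobs is a list with one dict per generated token; the discriminator generates a single token\n",
"    logprobs = apply_ft_discriminator(context, question, discriminator_model)[0]\n",
"    yes_logprob = logprobs[' yes'] if ' yes' in logprobs else -100\n",
"    no_logprob = logprobs[' no'] if ' no' in logprobs else -100\n",
"    if yes_logprob + discriminator_logprob_yes_modifier < no_logprob:\n",
"        return \" No appropriate context found to answer the question based on the discriminator.\"\n",
"    return apply_ft_qa_answer(context, question, answering_model)\n",
"\n",
"answer_question_conditionally(ft_qa, ft_discriminator, \n",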
" \"Crowdless games are a rare although not unheard-of occurrence in sports. \\\n",
" When they do occur, it is usually the result of events beyond the control \\\n",
" of the teams or fans, such as weather-related concerns, public health concerns, \\\n",
" or wider civil disturbances unrelated to the game. For instance, \\\n",
" the COVID-19 pandemic caused many sports leagues around the world \\\n",
" to be played behind closed doors.\",\n",
" \"Could weather cause a sport event to have no crowd?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The function above illustrates one way to combine a discriminator with a fine-tuned Q&A model. It gives finer-grained control over how certain the discriminator must be before the model answers the question.\n",
"\n",
"Next, we look at how the answers endpoint works: it uses search to retrieve the relevant context from a knowledge base, and then uses the fine-tuned Q&A model to answer the question."
]
},
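{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `discriminator_logprob_yes_modifier` argument shifts the decision threshold. The function answers when `yes + modifier >= no`, i.e. when the log-odds `yes - no` are at least `-modifier`, so a positive modifier makes it more willing to answer and a negative one more cautious. If you prefer to think in terms of a minimum probability `t` for ` yes` (renormalized over ` yes` and ` no`), the matching modifier is `-log(t / (1 - t))`. A minimal offline sketch of that decision rule; `should_answer` and `modifier_for_min_p_yes` are hypothetical helpers:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def should_answer(top_logprobs, yes_modifier=0.0, floor=-100.0):\n",
"    # Same rule as answer_question_conditionally: answer unless yes + modifier < no\n",
"    yes = top_logprobs.get(' yes', floor)\n",
"    no = top_logprobs.get(' no', floor)\n",
"    return yes + yes_modifier >= no\n",
"\n",
"def modifier_for_min_p_yes(min_p_yes):\n",
"    # p_yes >= t  <=>  yes - no >= log(t / (1 - t)), so the modifier is the negated log-odds\n",
"    return -math.log(min_p_yes / (1.0 - min_p_yes))\n",
"\n",
"uncertain = {' yes': math.log(0.7), ' no': math.log(0.3)}\n",
"print(should_answer(uncertain))                               # True\n",
"print(should_answer(uncertain, modifier_for_min_p_yes(0.9)))  # False\n",
"```"
]
},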
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.4 Answering the question based on a knowledge base\n",
"Finally, we can use logic similar to the [/answers](https://beta.openai.com/docs/api-reference/answers) endpoint, where we first search for the relevant context and then ask a Q&A model to answer the question given that context. If you'd like to see the implementation details, check out the [`answers_with_ft.py`](answers_with_ft.py) file."
]
},
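{
"cell_type": "markdown",
"metadata": {},
"source": [
"The general flow is: rank the knowledge-base sections against the question, take the best match as the context, and pass it to the Q&A model. The sketch below is a minimal offline stand-in, not the implementation in `answers_with_ft.py`: it replaces the search endpoint with a simple word-overlap score and takes the answering step as a function argument. `tokens`, `retrieve` and `answer_with_retrieval` are hypothetical helpers; in this notebook `answer_fn` could be `lambda c, q: apply_ft_qa_answer(c, q, ft_qa)`.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"def tokens(text):\n",
"    # Lowercased set of words; a crude stand-in for semantic search\n",
"    return set(re.findall(r'[a-z0-9]+', text.lower()))\n",
"\n",
"def retrieve(question, sections, k=1):\n",
"    # Rank sections by how many words they share with the question, best first\n",
"    q = tokens(question)\n",
"    return sorted(sections, key=lambda s: len(q & tokens(s)), reverse=True)[:k]\n",
"\n",
"def answer_with_retrieval(question, sections, answer_fn):\n",
"    # Use the top-ranked section as the context for the answering model\n",
"    context = retrieve(question, sections)[0]\n",
"    return answer_fn(context, question)\n",
"\n",
"sections = ['Sputnik 1 was launched on 4 October 1957.',\n",
"            'The 2020 Summer Olympics were held in Tokyo.']\n",
"# Toy answering step that simply echoes the retrieved context\n",
"echo = lambda context, question: ' ' + context\n",
"print(answer_with_retrieval('Where were the 2020 Summer Olympics held?', sections, echo))\n",
"```"
]
},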
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" Canada won the Women's football tournament at the 2020 Olympic games\""
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from answers_with_ft import answer_question\n",
"answer_question(olympics_search_fileid, ft_qa, \"Which country won the Women's football tournament at the 2020 Olympic games?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.9 64-bit ('3.9.9')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "cb9817b186a29e4e9713184d901f26c1ee05ad25243d878baff7f31bb1fef480"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}