mirror of
https://github.com/openai/openai-cookbook
synced 2024-11-08 01:10:29 +00:00
18bd018d27
Fix typo with duplicated "the the"
524 lines
34 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"<span style=\"color:orange; font-weight:bold\">Note: To answer questions based on text documents, we recommend the procedure in <a href=\"https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb\">Question Answering using Embeddings</a>. Some of the code below may rely on <a href=\"https://github.com/openai/openai-cookbook/tree/main/transition_guides_for_deprecated_API_endpoints\">deprecated API endpoints</a>.</span>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 1. Collect Wikipedia data about Olympic Games 2020\n",
|
|
"\n",
|
|
"The idea of this project is to create a question answering model based on a few paragraphs of provided text. Base GPT-3 models do a good job of answering questions when the answer is contained within the paragraph; however, if the answer isn't contained, the base models tend to try their best to answer anyway, often leading to confabulated answers. \n",
|
|
"\n",
|
|
"To create a model which answers questions only if there is sufficient context for doing so, we first create a dataset of questions and answers based on paragraphs of text. In order to train the model to answer only when the answer is present, we also add adversarial examples, where the question doesn't match the context. In those cases, we ask the model to output \"No sufficient context for answering the question\". \n",
|
|
"\n",
|
|
"We will perform this task in three notebooks:\n",
|
|
"1. The first (this) notebook focuses on collecting recent data, which GPT-3 didn't see during its pre-training. We picked the topic of Olympic Games 2020 (which actually took place in the summer of 2021), and downloaded 713 unique pages. We organized the dataset by individual sections, which will serve as context for asking and answering the questions.\n",
|
|
"2. The [second notebook](olympics-2-create-qa.ipynb) will utilize Davinci-instruct to ask a few questions based on a Wikipedia section, as well as answer those questions, based on that section.\n",
|
|
"3. The [third notebook](olympics-3-train-qa.ipynb) will utilize the dataset of contexts, questions, and answers to additionally create adversarial question and context pairs, where the question was not generated from that context. In those cases the model will be prompted to answer \"No sufficient context for answering the question\". We will also train a discriminator model, which predicts whether the question can be answered based on the context or not."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 1.1 Data extraction using the Wikipedia API\n",
|
|
"Extracting the data will take about half an hour, and processing will likely take about as much."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"909"
|
|
]
|
|
},
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
"import pandas as pd\n",
"import wikipedia\n",
"\n",
"\n",
"def filter_olympic_2020_titles(titles):\n",
"    \"\"\"\n",
"    Get the titles which are related to Olympic games hosted in 2020, given a list of titles\n",
"    \"\"\"\n",
"    titles = [title for title in titles if '2020' in title and 'olympi' in title.lower()]\n",
"    \n",
"    return titles\n",
"\n",
"def get_wiki_page(title):\n",
"    \"\"\"\n",
"    Get the wikipedia page given a title\n",
"    \"\"\"\n",
"    try:\n",
"        return wikipedia.page(title)\n",
"    except wikipedia.exceptions.DisambiguationError as e:\n",
"        # fall back to the first suggested option for ambiguous titles\n",
"        return wikipedia.page(e.options[0])\n",
"    except wikipedia.exceptions.PageError:\n",
"        return None\n",
"\n",
"def recursively_find_all_pages(titles, titles_so_far=None):\n",
"    \"\"\"\n",
"    Recursively find all the pages that are linked to the Wikipedia titles in the list\n",
"    \"\"\"\n",
"    # use a fresh set per top-level call instead of a shared mutable default argument\n",
"    if titles_so_far is None:\n",
"        titles_so_far = set()\n",
"    all_pages = []\n",
"    \n",
"    titles = list(set(titles) - titles_so_far)\n",
"    titles = filter_olympic_2020_titles(titles)\n",
"    titles_so_far.update(titles)\n",
"    for title in titles:\n",
"        page = get_wiki_page(title)\n",
"        if page is None:\n",
"            continue\n",
"        all_pages.append(page)\n",
"\n",
"        new_pages = recursively_find_all_pages(page.links, titles_so_far)\n",
"        for pg in new_pages:\n",
"            if pg.title not in [p.title for p in all_pages]:\n",
"                all_pages.append(pg)\n",
"        titles_so_far.update(page.links)\n",
"    return all_pages\n",
"\n",
"\n",
"pages = recursively_find_all_pages([\"2020 Summer Olympics\"])\n",
"len(pages)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 1.2 Filtering the Wikipedia pages and splitting them into sections by headings\n",
|
|
"We remove sections unlikely to contain textual information, and ensure that each section is not longer than the token limit."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('Bermuda at the 2020 Summer Olympics',\n",
|
|
" 'Equestrian',\n",
|
|
" \"Bermuda entered one dressage rider into the Olympic competition by finishing in the top four, outside the group selection, of the individual FEI Olympic Rankings for Groups D and E (North, Central, and South America), marking the country's recurrence to the sport after an eight-year absence. The quota was later withdrawn, following an injury of Annabelle Collins' main horse Joyero and a failure to obtain minimum eligibility requirements (MER) aboard a new horse Chuppy Checker.\",\n",
|
|
" 104)"
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
"\n",
"import re\n",
"from typing import List, Optional, Tuple\n",
"from transformers import GPT2TokenizerFast\n",
"\n",
"import numpy as np\n",
"from nltk.tokenize import sent_tokenize\n",
"\n",
"tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
"\n",
"def count_tokens(text: str) -> int:\n",
"    \"\"\"count the number of tokens in a string\"\"\"\n",
"    return len(tokenizer.encode(text))\n",
"\n",
"def reduce_long(\n",
"    long_text: str, long_text_tokens: Optional[int] = None, max_len: int = 590\n",
") -> str:\n",
"    \"\"\"\n",
"    Reduce a long text to a maximum of `max_len` tokens by potentially cutting at a sentence end\n",
"    \"\"\"\n",
"    if not long_text_tokens:\n",
"        long_text_tokens = count_tokens(long_text)\n",
"    if long_text_tokens > max_len:\n",
"        sentences = sent_tokenize(long_text.replace(\"\\n\", \" \"))\n",
"        ntokens = 0\n",
"        for i, sentence in enumerate(sentences):\n",
"            ntokens += 1 + count_tokens(sentence)\n",
"            if ntokens > max_len:\n",
"                return \". \".join(sentences[:i][:-1]) + \".\"\n",
"\n",
"    return long_text\n",
"\n",
"discard_categories = ['See also', 'References', 'External links', 'Further reading', \"Footnotes\",\n",
"    \"Bibliography\", \"Sources\", \"Citations\", \"Literature\", \"Notes and references\",\n",
"    \"Photo gallery\", \"Works cited\", \"Photos\", \"Gallery\", \"Notes\", \"References and sources\",\n",
"    \"References and notes\",]\n",
"\n",
"\n",
"def extract_sections(\n",
"    wiki_text: str,\n",
"    title: str,\n",
"    max_len: int = 1500,\n",
"    discard_categories: List[str] = discard_categories,\n",
") -> List[Tuple[str, str, str, int]]:\n",
"    \"\"\"\n",
"    Extract the sections of a Wikipedia page, discarding the references and other low information sections\n",
"    \"\"\"\n",
"    if len(wiki_text) == 0:\n",
"        return []\n",
"\n",
"    # find all headings and the corresponding contents\n",
"    headings = re.findall(\"==+ .* ==+\", wiki_text)\n",
"    for heading in headings:\n",
"        wiki_text = wiki_text.replace(heading, \"==+ !! ==+\")\n",
"    contents = wiki_text.split(\"==+ !! ==+\")\n",
"    contents = [c.strip() for c in contents]\n",
"    assert len(headings) == len(contents) - 1\n",
"\n",
"    cont = contents.pop(0).strip()\n",
"    outputs = [(title, \"Summary\", cont, count_tokens(cont)+4)]\n",
"    \n",
"    # discard the discard categories, accounting for a tree structure\n",
"    max_level = 100\n",
"    keep_group_level = max_level\n",
"    remove_group_level = max_level\n",
"    nheadings, ncontents = [], []\n",
"    for heading, content in zip(headings, contents):\n",
"        plain_heading = \" \".join(heading.split(\" \")[1:-1])\n",
"        num_equals = len(heading.split(\" \")[0])\n",
"        if num_equals <= keep_group_level:\n",
"            keep_group_level = max_level\n",
"\n",
"        if num_equals > remove_group_level:\n",
"            if (\n",
"                num_equals <= keep_group_level\n",
"            ):\n",
"                continue\n",
"        keep_group_level = max_level\n",
"        if plain_heading in discard_categories:\n",
"            remove_group_level = num_equals\n",
"            keep_group_level = max_level\n",
"            continue\n",
"        nheadings.append(heading.replace(\"=\", \"\").strip())\n",
"        ncontents.append(content)\n",
"        remove_group_level = max_level\n",
"\n",
"    # count the tokens of each section\n",
"    ncontent_ntokens = [\n",
"        count_tokens(c)\n",
"        + 3\n",
"        + count_tokens(\" \".join(h.split(\" \")[1:-1]))\n",
"        - (1 if len(c) == 0 else 0)\n",
"        for h, c in zip(nheadings, ncontents)\n",
"    ]\n",
"\n",
"    # Create a tuple of (title, section_name, content, number of tokens)\n",
"    # max_len is passed by keyword; positionally it would bind to long_text_tokens\n",
"    outputs += [(title, h, c, t) if t<max_len \n",
"                else (title, h, reduce_long(c, max_len=max_len), count_tokens(reduce_long(c, max_len=max_len))) \n",
"                    for h, c, t in zip(nheadings, ncontents, ncontent_ntokens)]\n",
"    \n",
"    return outputs\n",
"\n",
"# Example page being processed into sections\n",
"bermuda_page = get_wiki_page('Bermuda at the 2020 Summer Olympics')\n",
"ber = extract_sections(bermuda_page.content, bermuda_page.title)\n",
"\n",
"# Example section\n",
"ber[-1]\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 1.2.1 We create a dataset and filter out any sections with 40 or fewer tokens, as those are unlikely to contain enough context to ask a good question."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Token indices sequence length is longer than the specified maximum sequence length for this model (1060 > 1024). Running this sequence through the model will result in indexing errors\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>title</th>\n",
|
|
" <th>heading</th>\n",
|
|
" <th>content</th>\n",
|
|
" <th>tokens</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>2020 Summer Olympics</td>\n",
|
|
" <td>Summary</td>\n",
|
|
" <td>The 2020 Summer Olympics (Japanese: 2020年夏季オリン...</td>\n",
|
|
" <td>713</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>2020 Summer Olympics</td>\n",
|
|
" <td>Host city selection</td>\n",
|
|
" <td>The International Olympic Committee (IOC) vote...</td>\n",
|
|
" <td>126</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>2020 Summer Olympics</td>\n",
|
|
" <td>Impact of the COVID-19 pandemic</td>\n",
|
|
" <td>In January 2020, concerns were raised about th...</td>\n",
|
|
" <td>369</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>2020 Summer Olympics</td>\n",
|
|
" <td>Qualifying event cancellation and postponement</td>\n",
|
|
" <td>Concerns about the pandemic began to affect qu...</td>\n",
|
|
" <td>298</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>2020 Summer Olympics</td>\n",
|
|
" <td>Effect on doping tests</td>\n",
|
|
" <td>Mandatory doping tests were being severely res...</td>\n",
|
|
" <td>163</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" title heading \\\n",
|
|
"0 2020 Summer Olympics Summary \n",
|
|
"1 2020 Summer Olympics Host city selection \n",
|
|
"2 2020 Summer Olympics Impact of the COVID-19 pandemic \n",
|
|
"3 2020 Summer Olympics Qualifying event cancellation and postponement \n",
|
|
"4 2020 Summer Olympics Effect on doping tests \n",
|
|
"\n",
|
|
" content tokens \n",
|
|
"0 The 2020 Summer Olympics (Japanese: 2020年夏季オリン... 713 \n",
|
|
"1 The International Olympic Committee (IOC) vote... 126 \n",
|
|
"2 In January 2020, concerns were raised about th... 369 \n",
|
|
"3 Concerns about the pandemic began to affect qu... 298 \n",
|
|
"4 Mandatory doping tests were being severely res... 163 "
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"res = []\n",
|
|
"for page in pages:\n",
|
|
"    res += extract_sections(page.content, page.title)\n",
|
|
"df = pd.DataFrame(res, columns=[\"title\", \"heading\", \"content\", \"tokens\"])\n",
|
|
"df = df[df.tokens>40]\n",
|
|
"df = df.drop_duplicates(['title','heading'])\n",
|
|
"df = df.reset_index().drop('index',axis=1) # reset index\n",
|
|
"df.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Save the section dataset\n",
|
|
"We will save the section dataset for the [next notebook](olympics-2-create-qa.ipynb)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"df.to_csv('olympics-data/olympics_sections.csv', index=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 1.3 (Optional) Exploring the data "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Concerns and controversies at the 2020 Summer Olympics 51\n",
|
|
"United States at the 2020 Summer Olympics 46\n",
|
|
"Great Britain at the 2020 Summer Olympics 42\n",
|
|
"Canada at the 2020 Summer Olympics 39\n",
|
|
"Olympic Games 39\n",
|
|
"Name: title, dtype: int64"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"df.title.value_counts().head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The dataset appears to contain pages about both Winter and Summer Olympics. We chose to leave a little ambiguity and noise in the dataset, even though we were interested only in the 2020 Summer Olympics."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"True 3567\n",
|
|
"False 305\n",
|
|
"Name: title, dtype: int64"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"df.title.str.contains('Summer').value_counts()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"False 3774\n",
|
|
"True 98\n",
|
|
"Name: title, dtype: int64"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"df.title.str.contains('Winter').value_counts()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"from matplotlib import pyplot as plt\n",
|
|
"\n",
|
|
"df = pd.read_csv('olympics-data/olympics_sections.csv')\n",
|
|
"df[['tokens']].hist()\n",
|
|
"# add axis descriptions and title\n",
|
|
"plt.xlabel('Number of tokens')\n",
|
|
"plt.ylabel('Number of Wikipedia sections')\n",
|
|
"plt.title('Distribution of number of tokens in Wikipedia sections')\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can see that the majority of sections are fairly short (less than 500 tokens)."
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3.9.9 64-bit ('3.9.9')",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.9"
|
|
},
|
|
"orig_nbformat": 4,
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "cb9817b186a29e4e9713184d901f26c1ee05ad25243d878baff7f31bb1fef480"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|
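The heading-based splitting described in section 1.2 can be illustrated in isolation. The following is a minimal sketch, not the notebook's `extract_sections`: it assumes the page text marks headings as `== Heading ==` on their own line (which is what the notebook's regular expression implies), and the names `HEADING_RE`, `split_by_headings`, and `sample` are hypothetical, introduced only for this example. It does no token counting and no filtering of discard categories.

```python
import re

# Hypothetical pattern (an assumption, not the notebook's regex): a heading is
# a line of the form "== Title ==", where the number of "=" gives the level.
HEADING_RE = re.compile(r"^(=+) (.*?) =+$", re.MULTILINE)


def split_by_headings(wiki_text):
    """Return (heading, level, body) tuples; text before the first heading is 'Summary'."""
    sections = []
    heading, level, start = "Summary", 1, 0
    for match in HEADING_RE.finditer(wiki_text):
        # close the section that was open before this heading
        sections.append((heading, level, wiki_text[start:match.start()].strip()))
        heading, level, start = match.group(2), len(match.group(1)), match.end()
    # close the final section, which runs to the end of the text
    sections.append((heading, level, wiki_text[start:].strip()))
    return sections


# Made-up sample text, used only to exercise the function
sample = (
    "Intro paragraph.\n\n"
    "== Background ==\nSome history.\n\n"
    "=== Details ===\nFiner points.\n\n"
    "== References ==\n"
)

for heading, level, body in split_by_headings(sample):
    print(level, heading, repr(body))
```

Keeping the heading level alongside each section is what makes it possible to drop a discarded heading such as "References" together with its sub-sections, which is the tree handling that `extract_sections` performs with `remove_group_level`.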