mirror of
https://github.com/openai/openai-cookbook
synced 2024-11-08 01:10:29 +00:00
776 lines
29 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "9e3839a6-9146-4f60-b74b-19abbc24278d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import openai\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "from transformers import GPT2TokenizerFast\n",
    "\n",
    "ENGINE_NAME = \"curie\"\n",
    "\n",
    "# Search embeddings come in pairs: one model for the documents, one for the queries.\n",
    "DOC_EMBEDDING_MODEL = f\"text-search-{ENGINE_NAME}-doc-001\"\n",
    "QUERY_EMBEDDING_MODEL = f\"text-search-{ENGINE_NAME}-query-001\"\n",
    "\n",
    "COMPLETIONS_MODEL = \"text-davinci-002\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9312f62f-e208-4030-a648-71ad97aee74f",
   "metadata": {},
   "source": [
    "# Question Answering\n",
    "\n",
    "Many use cases require GPT to respond to user questions with insightful answers. For example, a customer support chatbot may need to provide answers to common questions. The GPT models have picked up a lot of general knowledge in training, but we often need to ingest and use a body of more specific information.\n",
    "\n",
    "In this notebook we will demonstrate a method for augmenting GPT with a large body of additional contextual information by using embeddings search and retrieval. We'll be using a dataset of Wikipedia articles about the 2020 Summer Olympic Games. Please see [this notebook](examples/fine-tuned_qa/olympics-1-collect-data.ipynb) to follow the data gathering process.\n",
    "\n",
    "GPT-3 isn't an expert on the Olympics by default:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "id": "a167516c-7c19-4bda-afa5-031aa0ae13bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"The 2020 Summer Olympics men's high jump was won by Mariusz Przybylski of Poland.\""
      ]
     },
     "execution_count": 176,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = \"Who won the 2020 Summer Olympics men's high jump?\"\n",
    "\n",
    "openai.Completion.create(\n",
    "    prompt=prompt,\n",
    "    temperature=0,\n",
    "    max_tokens=300,\n",
    "    top_p=1,\n",
    "    frequency_penalty=0,\n",
    "    presence_penalty=0,\n",
    "    engine=COMPLETIONS_MODEL\n",
    ")[\"choices\"][0][\"text\"].strip(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47204cce-a7d5-4c81-ab6e-53323026e08c",
   "metadata": {},
   "source": [
    "Mariusz Przybylski is a professional footballer from Poland, and not much of a high jumper! Evidently GPT-3 needs some assistance here. (In fact we'd ideally like the model to be more conservative and say \"I don't know\" rather than guess; see [this guide](examples/fine-tuned_qa) for an exploration of that topic.)\n",
    "\n",
    "When the total required context is short, we can include it in the prompt directly. In this case, we can use this information taken from Wikipedia:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "id": "fceaf665-2602-4788-bc44-9eb256a6f955",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Gianmarco Tamberi and Mutaz Essa Barshim won the 2020 Summer Olympics men's high jump.\""
      ]
     },
     "execution_count": 179,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = \"\"\"\n",
    "The men's high jump event at the 2020 Summer Olympics took place between 30 July and 1 August 2021 at the Olympic Stadium.\n",
    "33 athletes from 24 nations competed; the total possible number depended on how many nations would use universality places \n",
    "to enter athletes in addition to the 32 qualifying through mark or ranking (no universality places were used in 2021).\n",
    "Italian athlete Gianmarco Tamberi along with Qatari athlete Mutaz Essa Barshim emerged as joint winners of the event following\n",
    "a tie between both of them as they cleared 2.37m. Both Tamberi and Barshim agreed to share the gold medal in a rare instance\n",
    "where the athletes of different nations had agreed to share the same medal in the history of Olympics. \n",
    "Barshim in particular was heard to ask a competition official \"Can we have two golds?\" in response to being offered a \n",
    "'jump off'. Maksim Nedasekau of Belarus took bronze. The medals were the first ever in the men's high jump for Italy and \n",
    "Belarus, the first gold in the men's high jump for Italy and Qatar, and the third consecutive medal in the men's high jump\n",
    "for Qatar (all by Barshim). Barshim became only the second man to earn three medals in high jump, joining Patrik Sjöberg\n",
    "of Sweden (1984 to 1992).\n",
    "\n",
    "Who won the 2020 Summer Olympics men's high jump?\"\"\"\n",
    "\n",
    "openai.Completion.create(\n",
    "    prompt=prompt,\n",
    "    temperature=0,\n",
    "    max_tokens=300,\n",
    "    top_p=1,\n",
    "    frequency_penalty=0,\n",
    "    presence_penalty=0,\n",
    "    engine=COMPLETIONS_MODEL\n",
    ")[\"choices\"][0][\"text\"].strip(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee85ee77-d8d2-4788-b57e-0785f2d7e2e3",
   "metadata": {},
   "source": [
    "But this technique only works when the dataset of extra content that the model may need to know is small enough to fit in a single prompt. What do we do when we need the model to choose relevant contextual information from within a large body of information?\n",
    "\n",
    "**In this notebook we demonstrate a method for augmenting GPT with a large body of additional contextual information by using embeddings search and retrieval.** This method answers queries in two steps: first it retrieves the information relevant to the query, then it writes an answer tailored to the question based on the retrieved information. The first step uses the Embeddings API and the second step uses the Completions API.\n",
    "\n",
    "The steps are:\n",
    "* Preprocess the contextual information by splitting it into chunks and creating an embedding vector for each chunk.\n",
    "* On receiving a query, embed the query in the same vector space as the context chunks and find the context embeddings which are most similar to the query.\n",
    "* Prepend the text of the most relevant context chunks to the query prompt.\n",
    "* Submit the question along with the most relevant context to GPT, and receive an answer which makes use of the provided contextual information."
   ]
  },
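  {
   "cell_type": "markdown",
   "id": "toy-pipeline-sketch-md",
   "metadata": {},
   "source": [
    "The four steps above can be tried offline before calling the API. The next cell is a minimal sketch, not the method used in the rest of this notebook. It swaps in bag-of-words count vectors as a stand-in for real embeddings, and it stops after building the prompt instead of calling the Completions API. The three sample sentences are short paraphrases of facts from the Olympics dataset. The helper names (`tokenize`, `toy_embed`) are hypothetical. With these sentences, the men's high jump chunk should rank first because it shares the most words with the question."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-pipeline-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Offline sketch of the retrieve-then-answer flow.\n",
    "# Assumptions: bag-of-words count vectors stand in for real embeddings,\n",
    "# and the final Completions API call is omitted.\n",
    "import re\n",
    "from collections import Counter\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "chunks = {\n",
    "    (\"Men's high jump\", \"Summary\"): \"Tamberi and Barshim shared the gold medal in the men's high jump.\",\n",
    "    (\"Men's triple jump\", \"Summary\"): \"Pedro Pichardo of Portugal took gold in the triple jump.\",\n",
    "    (\"Medal table\", \"Summary\"): \"The United States won the most medals overall.\",\n",
    "}\n",
    "\n",
    "def tokenize(text: str) -> list[str]:\n",
    "    return re.findall(r\"[a-z']+\", text.lower())\n",
    "\n",
    "# Step 1: \"embed\" every chunk in a vocabulary space shared by all chunks.\n",
    "vocab = sorted({word for text in chunks.values() for word in tokenize(text)})\n",
    "\n",
    "def toy_embed(text: str) -> np.ndarray:\n",
    "    counts = Counter(tokenize(text))\n",
    "    vector = np.array([counts[word] for word in vocab], dtype=float)\n",
    "    norm = np.linalg.norm(vector)\n",
    "    # Scale to unit length so that a dot product equals cosine similarity.\n",
    "    return vector / norm if norm else vector\n",
    "\n",
    "chunk_vectors = {key: toy_embed(text) for key, text in chunks.items()}\n",
    "\n",
    "# Step 2: embed the query and rank the chunks by similarity to it.\n",
    "query = \"Who won the men's high jump?\"\n",
    "query_vector = toy_embed(query)\n",
    "ranked = sorted(chunk_vectors, key=lambda key: float(query_vector @ chunk_vectors[key]), reverse=True)\n",
    "\n",
    "# Steps 3 and 4: prepend the best chunk to the question; a real run would send this prompt to the model.\n",
    "prompt = \"\\n* \" + chunks[ranked[0]] + \"\\n\\n\" + query + \"\\n\"\n",
    "print(ranked[0])\n",
    "print(prompt)"
   ]
  },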
  {
   "cell_type": "markdown",
   "id": "0c9bfea5-a028-4191-b9f1-f210d76ec4e3",
   "metadata": {},
   "source": [
    "# 1) Preprocess the contextual information\n",
    "\n",
    "We preprocess the contextual information by creating an embedding vector for each chunk of context in the knowledge base. An embedding is a vector of numbers that helps us understand how similar or different things are. The closer two embeddings are to each other, the more similar the things are that they represent.\n",
    "\n",
    "This indexing stage can be executed offline and only runs once to precompute the indexes for the dataset so that each piece of content can be retrieved later. Since this is a small example, we will store and search the embeddings locally. If you have a larger dataset, consider using a vector search engine like Pinecone or Weaviate to power the search.\n",
    "\n",
    "For the purposes of this tutorial we chose to use Curie embeddings, which offer a good balance of price and performance. Since we will be using these embeddings for retrieval, we'll use the search embeddings: `text-search-curie-doc-001` for the context chunks and `text-search-curie-query-001` for the queries."
   ]
  },
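  {
   "cell_type": "markdown",
   "id": "cosine-similarity-sketch-md",
   "metadata": {},
   "source": [
    "To make \"closer\" concrete, similarity between embeddings is commonly measured with cosine similarity, the cosine of the angle between two vectors. The sketch below uses invented 3-dimensional vectors purely for illustration. Real Curie embeddings are much longer, and their values come from the Embeddings API. `cosine_similarity` is a hypothetical helper name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cosine-similarity-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration: the vectors below are invented, not real embeddings.\n",
    "import numpy as np\n",
    "\n",
    "def cosine_similarity(x: np.ndarray, y: np.ndarray) -> float:\n",
    "    \"\"\"Cosine of the angle between x and y: 1.0 means same direction, 0.0 means orthogonal.\"\"\"\n",
    "    return float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))\n",
    "\n",
    "high_jump = np.array([0.9, 0.1, 0.0])\n",
    "long_jump = np.array([0.8, 0.3, 0.1])\n",
    "medal_table = np.array([0.1, 0.2, 0.9])\n",
    "\n",
    "print(round(cosine_similarity(high_jump, long_jump), 3))    # close to 1: similar topics\n",
    "print(round(cosine_similarity(high_jump, medal_table), 3))  # much lower: different topics"
   ]
  },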
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "cc9c8d69-e234-48b4-87e3-935970e1523a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3980 rows in the data.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       " .dataframe tbody tr th:only-of-type {\n",
       " vertical-align: middle;\n",
       " }\n",
       "\n",
       " .dataframe tbody tr th {\n",
       " vertical-align: top;\n",
       " }\n",
       "\n",
       " .dataframe thead th {\n",
       " text-align: right;\n",
       " }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       " <thead>\n",
       " <tr style=\"text-align: right;\">\n",
       " <th></th>\n",
       " <th></th>\n",
       " <th>content</th>\n",
       " <th>tokens</th>\n",
       " </tr>\n",
       " <tr>\n",
       " <th>title</th>\n",
       " <th>heading</th>\n",
       " <th></th>\n",
       " <th></th>\n",
       " </tr>\n",
       " </thead>\n",
       " <tbody>\n",
       " <tr>\n",
       " <th>United States at the 2020 Summer Olympics</th>\n",
       " <th>Diving</th>\n",
       " <td>U.S. divers qualified for the following indivi...</td>\n",
       " <td>89</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <th>Austria at the 2020 Summer Olympics</th>\n",
       " <th>Summary</th>\n",
       " <td>Austria competed at the 2020 Summer Olympics i...</td>\n",
       " <td>115</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <th>2020 Women's Rugby Sevens Final Olympic Qualification Tournament</th>\n",
       " <th>Knockout stage</th>\n",
       " <td>With two Olympic places available, the top eig...</td>\n",
       " <td>49</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <th>Italy at the 2020 Summer Olympics</th>\n",
       " <th>Karate</th>\n",
       " <td>Italy entered five karateka into the inaugural...</td>\n",
       " <td>148</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <th>2020 United States Olympic Team Trials (wrestling)</th>\n",
       " <th>Summary</th>\n",
       " <td>The 2020 United States Olympic Team Trials for...</td>\n",
       " <td>119</td>\n",
       " </tr>\n",
       " </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       " content \\\n",
       "title heading \n",
       "United States at the 2020 Summer Olympics Diving U.S. divers qualified for the following indivi... \n",
       "Austria at the 2020 Summer Olympics Summary Austria competed at the 2020 Summer Olympics i... \n",
       "2020 Women's Rugby Sevens Final Olympic Qualifi... Knockout stage With two Olympic places available, the top eig... \n",
       "Italy at the 2020 Summer Olympics Karate Italy entered five karateka into the inaugural... \n",
       "2020 United States Olympic Team Trials (wrestling) Summary The 2020 United States Olympic Team Trials for... \n",
       "\n",
       " tokens \n",
       "title heading \n",
       "United States at the 2020 Summer Olympics Diving 89 \n",
       "Austria at the 2020 Summer Olympics Summary 115 \n",
       "2020 Women's Rugby Sevens Final Olympic Qualifi... Knockout stage 49 \n",
       "Italy at the 2020 Summer Olympics Karate 148 \n",
       "2020 United States Olympic Team Trials (wrestling) Summary 119 "
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the dataset, obtained from this notebook <LINK>.\n",
    "# This dataset has already been split into sections, one row for each section of the Wikipedia page.\n",
    "\n",
    "df = pd.read_csv('olympics_sections.csv')\n",
    "df = df.set_index([\"title\", \"heading\"])\n",
    "print(f\"{len(df)} rows in the data.\")\n",
    "df.sample(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 215,
   "id": "ba475f30-ef7f-431c-b60d-d5970b62ad09",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embedding(text: str, engine: str) -> np.ndarray:\n",
    "    # Request the embedding vector for a single input string.\n",
    "    result = openai.Engine(engine).embeddings(input=text)\n",
    "    return np.array(result[\"data\"][0][\"embedding\"])\n",
    "\n",
    "def get_doc_embedding(text: str) -> np.ndarray:\n",
    "    return get_embedding(text, DOC_EMBEDDING_MODEL)\n",
    "\n",
    "def get_query_embedding(text: str) -> np.ndarray:\n",
    "    return get_embedding(text, QUERY_EMBEDDING_MODEL)\n",
    "\n",
    "def compute_doc_embeddings(df: pd.DataFrame) -> dict[tuple[str, str], np.ndarray]:\n",
    "    \"\"\"\n",
    "    Create an embedding for each row (section) of the dataframe using the OpenAI Embeddings API.\n",
    "\n",
    "    Return a dictionary that maps the index of each row, a (title, heading) tuple, to its embedding vector.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        idx: get_doc_embedding(r.content) for idx, r in tqdm(df.iterrows(), total=len(df))\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "ab50bfca-cb02-41c6-b338-4400abe1d86e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d2192513f1349febdaef6ccb9ab1046",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       " 0%| | 0/3980 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1min 46s, sys: 7.92 s, total: 1min 54s\n",
      "Wall time: 11min 16s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "context_embeddings = compute_doc_embeddings(df)"
   ]
  },
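  {
   "cell_type": "markdown",
   "id": "batched-embeddings-sketch-md",
   "metadata": {},
   "source": [
    "Embedding the 3980 sections one request at a time took over 11 minutes of wall time above, and less than 2 minutes of that was CPU time, so most of the run was spent waiting on requests. A common speed-up is to send several texts per request. The sketch below assumes the embeddings endpoint accepts a list of strings as `input` and returns one vector per string in the same order. `embed_batch` is a hypothetical callable you would wire to that endpoint; if ordering matters, check the `index` field of each returned item. `batch_size=100` is an arbitrary default. A fake embedder stands in for the API so the cell runs offline. To use it on the real data, you would pass `df[\"content\"].to_dict()`, whose keys are the same `(title, heading)` tuples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "batched-embeddings-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: batch texts into fewer embedding requests.\n",
    "# Assumption: `embed_batch` (hypothetical) takes a list of strings and returns\n",
    "# one vector per string, in the same order.\n",
    "from typing import Callable, Iterator, Sequence\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def batched(items: Sequence, batch_size: int) -> Iterator[Sequence]:\n",
    "    \"\"\"Yield consecutive slices of at most `batch_size` items.\"\"\"\n",
    "    for start in range(0, len(items), batch_size):\n",
    "        yield items[start:start + batch_size]\n",
    "\n",
    "def compute_doc_embeddings_batched(\n",
    "    keyed_texts: dict,\n",
    "    embed_batch: Callable[[list[str]], list[list[float]]],\n",
    "    batch_size: int = 100,\n",
    ") -> dict:\n",
    "    \"\"\"Map each key of `keyed_texts` to the embedding of its text.\"\"\"\n",
    "    keys = list(keyed_texts)\n",
    "    embeddings = {}\n",
    "    for key_batch in batched(keys, batch_size):\n",
    "        vectors = embed_batch([keyed_texts[key] for key in key_batch])\n",
    "        for key, vector in zip(key_batch, vectors):\n",
    "            embeddings[key] = np.array(vector)\n",
    "    return embeddings\n",
    "\n",
    "# Offline demo: a fake embedder returning [character count, word count] per text.\n",
    "def fake_embed_batch(texts: list[str]) -> list[list[float]]:\n",
    "    return [[float(len(text)), float(len(text.split()))] for text in texts]\n",
    "\n",
    "demo = compute_doc_embeddings_batched(\n",
    "    {(\"A\", \"Summary\"): \"one two\", (\"B\", \"Summary\"): \"three\"},\n",
    "    fake_embed_batch,\n",
    "    batch_size=1,\n",
    ")\n",
    "print(demo)"
   ]
  },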
  {
   "cell_type": "code",
   "execution_count": 249,
   "id": "a298f666-f31f-4356-a882-e2170524a637",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "147.552"
      ]
     },
     "execution_count": 249,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sys\n",
    "# Size in KB of the dict object itself; sys.getsizeof does not count the NumPy arrays it references.\n",
    "sys.getsizeof(context_embeddings) / 1000"
   ]
  },
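  {
   "cell_type": "markdown",
   "id": "cache-embeddings-sketch-md",
   "metadata": {},
   "source": [
    "Computing the embeddings is slow, so it is worth saving them to disk and reloading them in later sessions. The notebook already imports `pickle`, which suits this. The sketch below is a minimal version with hypothetical helper names and an illustrative file name. It also adds an `embeddings_nbytes` helper, because `sys.getsizeof` above reports only the dictionary object, not the NumPy arrays it points to. Only unpickle files you created yourself, since loading a pickle can execute arbitrary code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cache-embeddings-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: cache embeddings on disk and measure the memory held by the vectors.\n",
    "# File and function names are illustrative.\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def save_embeddings(embeddings: dict, path: Path) -> None:\n",
    "    with path.open(\"wb\") as f:\n",
    "        pickle.dump(embeddings, f)\n",
    "\n",
    "def load_embeddings(path: Path) -> dict:\n",
    "    # Only load pickles from a trusted source: unpickling can run arbitrary code.\n",
    "    with path.open(\"rb\") as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "def embeddings_nbytes(embeddings: dict) -> int:\n",
    "    \"\"\"Bytes held by the vectors themselves (sys.getsizeof counts only the dict).\"\"\"\n",
    "    return sum(vector.nbytes for vector in embeddings.values())\n",
    "\n",
    "# Round trip on toy data; pass `context_embeddings` instead to cache the real vectors.\n",
    "toy = {(\"Doc\", \"Summary\"): np.zeros(4096), (\"Doc\", \"Background\"): np.ones(4096)}\n",
    "path = Path(\"toy_embeddings.pkl\")\n",
    "save_embeddings(toy, path)\n",
    "restored = load_embeddings(path)\n",
    "path.unlink()  # remove the demo file\n",
    "print(embeddings_nbytes(restored))  # 2 vectors x 4096 float64 values x 8 bytes = 65536"
   ]
  },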
  {
   "cell_type": "markdown",
   "id": "aa32cf88-9edb-4dc6-b4cf-a16a8de7d304",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 2) Find the most similar context embeddings to the question embedding\n",
    "\n",
    "At question-answering time, we embed the user's query with the query embedding model and use that vector to find the most similar context chunks.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 210,
   "id": "dcd680e9-f194-4180-b14f-fc357498eb92",
   "metadata": {},
   "outputs": [],
   "source": [
    "def vector_similarity(x: np.ndarray, y: np.ndarray) -> float:\n",
    "    # OpenAI embeddings are normalized to length 1, so the dot product equals cosine similarity.\n",
    "    return np.dot(x, y)\n",
    "\n",
    "def order_contexts_by_query_similarity(query: str, contexts: dict[tuple[str, str], np.ndarray]) -> list[tuple[float, tuple[str, str]]]:\n",
    "    \"\"\"\n",
    "    Find the query embedding for the supplied query, and compare it against all of the pre-calculated context embeddings\n",
    "    to find the most relevant contexts.\n",
    "\n",
    "    Return the list of contexts, sorted by relevance in descending order.\n",
    "    \"\"\"\n",
    "    query_embedding = get_query_embedding(query)\n",
    "\n",
    "    context_similarities = sorted([\n",
    "        (vector_similarity(query_embedding, doc_embedding), doc_index) for doc_index, doc_embedding in contexts.items()\n",
    "    ], reverse=True)\n",
    "\n",
    "    return context_similarities"
   ]
  },
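  {
   "cell_type": "markdown",
   "id": "vectorized-search-sketch-md",
   "metadata": {},
   "source": [
    "`order_contexts_by_query_similarity` scores every context in a Python loop, which is fine for a few thousand sections. The vectorized variant sketched below assumes all embeddings have the same length. It stacks them into one matrix and scores them with a single matrix-vector product. `top_k_contexts` is a hypothetical helper. With real data you would call it as `top_k_contexts(get_query_embedding(question), context_embeddings)`, and ideally build the matrix once rather than once per query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "vectorized-search-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: vectorized top-k similarity search over a dict of embeddings.\n",
    "import numpy as np\n",
    "\n",
    "def top_k_contexts(query_embedding: np.ndarray, contexts: dict, k: int = 5) -> list:\n",
    "    \"\"\"Return the k (score, key) pairs with the highest dot-product score, best first.\"\"\"\n",
    "    keys = list(contexts)\n",
    "    matrix = np.vstack([contexts[key] for key in keys])  # shape: (number of contexts, dimensions)\n",
    "    scores = matrix @ query_embedding                     # one dot product per context\n",
    "    best = np.argsort(-scores)[:k]                        # indices of the k highest scores\n",
    "    return [(float(scores[i]), keys[i]) for i in best]\n",
    "\n",
    "# Toy demo with hand-made 2-dimensional unit vectors.\n",
    "toy_contexts = {\n",
    "    (\"High jump\", \"Summary\"): np.array([1.0, 0.0]),\n",
    "    (\"Long jump\", \"Summary\"): np.array([0.6, 0.8]),\n",
    "    (\"Medal table\", \"Summary\"): np.array([0.0, 1.0]),\n",
    "}\n",
    "print(top_k_contexts(np.array([0.8, 0.6]), toy_contexts, k=2))"
   ]
  },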
  {
   "cell_type": "code",
   "execution_count": 211,
   "id": "e3a27d73-f47f-480d-b336-079414f749cb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0.4296262686167575,\n",
       " (\"Athletics at the 2020 Summer Olympics – Men's high jump\", 'Summary')),\n",
       " (0.4067051275316437,\n",
       " (\"Athletics at the 2020 Summer Olympics – Women's high jump\", 'Summary')),\n",
       " (0.4046927788902179,\n",
       " (\"Athletics at the 2020 Summer Olympics – Men's high jump\", 'Background')),\n",
       " (0.40424431005550354,\n",
       " (\"Athletics at the 2020 Summer Olympics – Men's triple jump\", 'Summary')),\n",
       " (0.4021923762547752,\n",
       " (\"Athletics at the 2020 Summer Olympics – Women's long jump\", 'Summary'))]"
      ]
     },
     "execution_count": 211,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "order_contexts_by_query_similarity(\"Who won the men's high jump?\", context_embeddings)[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cf71fae-abb1-46b2-a483-c1b2f1a915c2",
   "metadata": {},
   "source": [
    "We can see that the summary sections about the men's and women's high jump competitions are judged to be the most relevant contexts for this question, which is in line with what we would expect."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0efa0f6-4469-457a-89a4-a2f5736a01e0",
   "metadata": {},
   "source": [
    "# 3) Add the most relevant contexts to the query prompt\n",
    "\n",
    "Once we've calculated the most relevant pieces of context, we construct a prompt by simply prepending them to the supplied query. It is helpful to use a context separator to help the model distinguish between separate pieces of text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 214,
   "id": "b763ace2-1946-48e0-8ff1-91ba335d47a0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Context separator contains 3 tokens'"
      ]
     },
     "execution_count": 214,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token budget for the context sections included in the prompt.\n",
    "MAX_CONTEXT_LEN = 500\n",
    "CONTEXT_SEPARATOR = \"\\n* \"\n",
    "\n",
    "# The GPT-2 tokenizer is used to estimate how many tokens each piece of text uses.\n",
    "tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
    "context_separator_len = len(tokenizer.tokenize(CONTEXT_SEPARATOR))\n",
    "\n",
    "f\"Context separator contains {context_separator_len} tokens\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "id": "0c5c0509-eeb9-4552-a5d4-6ace04ef73dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_prompt(question: str, contexts: dict[tuple[str, str], np.ndarray], df: pd.DataFrame) -> str:\n",
    "    \"\"\"\n",
    "    Build a prompt from the most relevant context sections (up to MAX_CONTEXT_LEN tokens),\n",
    "    followed by the question.\n",
    "    \"\"\"\n",
    "    most_relevant_contexts = order_contexts_by_query_similarity(question, contexts)\n",
    "\n",
    "    chosen_contexts = []\n",
    "    chosen_contexts_indexes = []\n",
    "    chosen_contexts_len = 0\n",
    "\n",
    "    for _, context_index in most_relevant_contexts:\n",
    "        # Add contexts until we run out of space.\n",
    "        # In this version, we will not add the final context that overflows our limit;\n",
    "        # you may wish to \"partially add\" that final context.\n",
    "\n",
    "        context = df.loc[context_index]\n",
    "\n",
    "        chosen_contexts_len += context.tokens + context_separator_len\n",
    "        if chosen_contexts_len > MAX_CONTEXT_LEN:\n",
    "            break\n",
    "\n",
    "        chosen_contexts.append(CONTEXT_SEPARATOR + context.content)\n",
    "        chosen_contexts_indexes.append(str(context_index))\n",
    "\n",
    "    print(f\"Selected {len(chosen_contexts)} contexts:\")\n",
    "    print(\"\\n\".join(chosen_contexts_indexes))\n",
    "\n",
    "    return \"\".join(chosen_contexts) + \"\\n\\n\" + question + \"\\n\""
   ]
  },
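  {
   "cell_type": "markdown",
   "id": "partial-context-sketch-md",
   "metadata": {},
   "source": [
    "The comment in `construct_prompt` notes that you may wish to \"partially add\" the final context instead of dropping it. One way to do that is to truncate the overflowing section to whatever budget remains. The sketch below is self-contained and its helper names are hypothetical. It counts whitespace-separated words as a stand-in for GPT-2 tokens so that it runs offline. To match `construct_prompt`, you would count and truncate with `tokenizer` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "partial-context-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: pack ranked context texts into a token budget, truncating the last one to fit.\n",
    "# Assumption: whitespace-separated words stand in for real tokenizer tokens.\n",
    "def count_tokens(text: str) -> int:\n",
    "    return len(text.split())\n",
    "\n",
    "def truncate_to_tokens(text: str, max_tokens: int) -> str:\n",
    "    return \" \".join(text.split()[:max_tokens])\n",
    "\n",
    "def pack_contexts(ranked_texts: list[str], budget: int, separator: str = \"\\n* \") -> str:\n",
    "    separator_len = count_tokens(separator)\n",
    "    chosen, used = [], 0\n",
    "    for text in ranked_texts:\n",
    "        remaining = budget - used - separator_len\n",
    "        if remaining <= 0:\n",
    "            break\n",
    "        length = count_tokens(text)\n",
    "        if length <= remaining:\n",
    "            chosen.append(separator + text)\n",
    "            used += separator_len + length\n",
    "        else:\n",
    "            # Partially add the overflowing context, then stop.\n",
    "            chosen.append(separator + truncate_to_tokens(text, remaining))\n",
    "            break\n",
    "    return \"\".join(chosen)\n",
    "\n",
    "print(repr(pack_contexts([\"a b c d\", \"e f g h i j\"], budget=8)))  # '\\n* a b c d\\n* e f'"
   ]
  },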
  {
   "cell_type": "code",
   "execution_count": 161,
   "id": "f614045a-3917-4b28-9643-7e0c299ec1a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 3 contexts:\n",
      "(\"Athletics at the 2020 Summer Olympics – Women's high jump\", 'Summary')\n",
      "(\"Athletics at the 2020 Summer Olympics – Men's high jump\", 'Summary')\n",
      "(\"Athletics at the 2020 Summer Olympics – Men's triple jump\", 'Summary')\n",
      "===\n",
      " \n",
      "* The women's high jump event at the 2020 Summer Olympics took place on 5 and 7 August 2021 at the Japan National Stadium. Even though 32 athletes qualified through the qualification system for the Games, only 31 took part in the competition. This was the 22nd appearance of the event, having appeared at every Olympics since women's athletics was introduced in 1928.\n",
      "* The men's high jump event at the 2020 Summer Olympics took place between 30 July and 1 August 2021 at the Olympic Stadium. 33 athletes from 24 nations competed; the total possible number depended on how many nations would use universality places to enter athletes in addition to the 32 qualifying through mark or ranking (no universality places were used in 2021). Italian athlete Gianmarco Tamberi along with Qatari athlete Mutaz Essa Barshim emerged as joint winners of the event following a tie between both of them as they cleared 2.37m. Both Tamberi and Barshim agreed to share the gold medal in a rare instance where the athletes of different nations had agreed to share the same medal in the history of Olympics. Barshim in particular was heard to ask a competition official \"Can we have two golds?\" in response to being offered a 'jump off'. Maksim Nedasekau of Belarus took bronze. The medals were the first ever in the men's high jump for Italy and Belarus, the first gold in the men's high jump for Italy and Qatar, and the third consecutive medal in the men's high jump for Qatar (all by Barshim). Barshim became only the second man to earn three medals in high jump, joining Patrik Sjöberg of Sweden (1984 to 1992).\n",
      "* The men's triple jump event at the 2020 Summer Olympics took place between 3 and 5 August 2021 at the Japan National Stadium. Approximately 35 athletes were expected to compete; the exact number was dependent on how many nations use universality places to enter athletes in addition to the 32 qualifying through time or ranking (2 universality places were used in 2016). 32 athletes from 19 nations competed. Pedro Pichardo of Portugal won the gold medal, the nation's second victory in the men's triple jump (after Nelson Évora in 2008). China's Zhu Yaming took silver, while Hugues Fabrice Zango earned Burkina Faso's first Olympic medal in any event.\n",
      "\n",
      "Who won the 2020 Summer Olympics men's high jump?\n",
      "\n"
     ]
    }
   ],
   "source": [
    "prompt = construct_prompt(\n",
    "    \"Who won the 2020 Summer Olympics men's high jump?\",\n",
    "    context_embeddings,\n",
    "    df\n",
    ")\n",
    "\n",
    "print(\"===\\n\", prompt)"
   ]
  },
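  {
   "cell_type": "markdown",
   "id": "conservative-prompt-sketch-md",
   "metadata": {},
   "source": [
    "The prompt above gives the model no instructions, so it may still guess when the retrieved sections do not contain the answer. This is the same problem seen earlier with the Mariusz Przybylski answer. A common mitigation is an instruction header that tells the model to answer only from the context. The wording and the `Q:`/`A:` layout below are assumptions rather than a tested recipe, and `build_prompt` is a hypothetical helper. In practice you would feed it the section texts selected by `construct_prompt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "conservative-prompt-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: a prompt header that asks the model to admit when the context lacks the answer.\n",
    "# The header wording is an assumption, not a tested recipe.\n",
    "HEADER = (\n",
    "    \"Answer the question as truthfully as possible using the provided context, \"\n",
    "    \"and if the answer is not contained within the context below, say 'I don't know.'\\n\\n\"\n",
    "    \"Context:\"\n",
    ")\n",
    "\n",
    "def build_prompt(context_texts: list[str], question: str, separator: str = \"\\n* \") -> str:\n",
    "    context = \"\".join(separator + text for text in context_texts)\n",
    "    return HEADER + context + \"\\n\\nQ: \" + question + \"\\nA:\"\n",
    "\n",
    "# This toy context does not name the winner, so a model that follows the header\n",
    "# should ideally reply that it does not know.\n",
    "print(build_prompt(\n",
    "    [\"Maksim Nedasekau of Belarus took bronze.\"],\n",
    "    \"Who won the 2020 Summer Olympics men's high jump?\",\n",
    "))"
   ]
  },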
  {
   "cell_type": "markdown",
   "id": "1b022fd4-0a3c-4ae1-bed1-4c80e4f0fb56",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 4) Submit the extended query to the Completions API\n",
    "\n",
    "Now that we've retrieved the relevant context, we can finally use the Completions API to answer the user's query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 233,
   "id": "b0edfec7-9243-4573-92e0-253d31c771ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "COMPLETIONS_API_PARAMS = {\n",
    "    \"temperature\": 0.0,\n",
    "    \"max_tokens\": 300,\n",
    "    \"top_p\": 1,\n",
    "    \"frequency_penalty\": 0,\n",
    "    \"presence_penalty\": 0,\n",
    "    \"engine\": COMPLETIONS_MODEL,\n",
    "    \"stop\": \"\\n\\n\"  # stop generating at the first blank line\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "id": "9c1c9a69-848e-4099-a90d-c8da36c153d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def answer_query_with_context(\n",
    "    query: str,\n",
    "    df: pd.DataFrame,\n",
    "    context_embeddings: dict[tuple[str, str], np.ndarray]\n",
    ") -> str:\n",
    "    prompt = construct_prompt(\n",
    "        query,\n",
    "        context_embeddings,\n",
    "        df\n",
    "    )\n",
    "\n",
    "    response = openai.Completion.create(\n",
    "        prompt=prompt,\n",
    "        **COMPLETIONS_API_PARAMS\n",
    "    )\n",
    "\n",
    "    return response[\"choices\"][0][\"text\"].strip(\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 235,
   "id": "c233e449-bf33-4c9e-b095-6a4dd278c8fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 3 contexts:\n",
      "(\"Athletics at the 2020 Summer Olympics – Women's high jump\", 'Summary')\n",
      "(\"Athletics at the 2020 Summer Olympics – Men's high jump\", 'Summary')\n",
      "(\"Athletics at the 2020 Summer Olympics – Men's triple jump\", 'Summary')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"Gianmarco Tamberi and Mutaz Essa Barshim both cleared 2.37m to win the gold medal in the men's high jump event at the 2020 Summer Olympics.\""
      ]
     },
     "execution_count": 235,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "answer_query_with_context(\"Who won the 2020 Summer Olympics men's high jump?\", df, context_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b48d155-d2d4-447c-ab8e-5a5b4722b07c",
   "metadata": {},
   "source": [
    "# More Examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 242,
   "id": "1127867b-2884-44bb-9439-0e8ae171c835",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 1 contexts:\n",
      "('2020 Summer Olympics', 'Postponement to 2021')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The 2020 Summer Olympics were originally postponed due to the COVID-19 pandemic. The outbreak of the pandemic began in China in December 2019, and spread to Japan in January 2020. The outbreak was declared a Public Health Emergency of International Concern by the World Health Organization on 30 January 2020. On 25 February 2020, the IOC announced that it would hold a meeting on the following day to discuss the outbreak and its potential impact on the Games.'"
      ]
     },
     "execution_count": 242,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query = \"Why was the 2020 Summer Olympics originally postponed?\"\n",
    "answer_query_with_context(query, df, context_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "id": "720d9e0b-b189-4101-91ee-babf736199e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 1 contexts:\n",
      "('2020 Summer Olympics medal table', 'Summary')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The United States won the most medals overall, with 113, and the most gold medals, with 39.'"
      ]
     },
     "execution_count": 236,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query = \"In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?\"\n",
    "answer_query_with_context(query, df, context_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 238,
   "id": "4e8e51cc-e4eb-4557-9e09-2929d4df5b7f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 3 contexts:\n",
      "(\"Athletics at the 2020 Summer Olympics – Men's shot put\", 'Summary')\n",
      "(\"Athletics at the 2020 Summer Olympics – Men's shot put\", 'Background')\n",
      "(\"Athletics at the 2020 Summer Olympics – Men's hammer throw\", 'Competition format')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The men’s shotput competition at the 2020 Summer Olympics was unusual because the same three competitors received the same medals in back-to-back editions of the same individual event. Americans Ryan Crouser and Joe Kovacs and New Zealander Tom Walsh repeated their gold, silver, and bronze (respectively) performances from the 2016 Summer Olympics. They became the 15th, 16th, and 17th men to earn multiple medals in the shot put; Crouser was the 4th to repeat as champion.'"
      ]
     },
     "execution_count": 238,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query = \"What was unusual about the men’s shotput competition?\"\n",
    "answer_query_with_context(query, df, context_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 239,
   "id": "37c83519-e3c6-4c44-8b4a-98cbb3a5f5ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 1 contexts:\n",
      "('Italy at the 2020 Summer Olympics', 'Summary')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'10'"
      ]
     },
     "execution_count": 239,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query = \"In the 2020 Summer Olympics, how many silver medals did Italy win?\"\n",
    "answer_query_with_context(query, df, context_embeddings)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}