{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tuning classification example\n",
"\n",
"We will fine-tune a `babbage-002` classifier (replacement for the `ada` models) to distinguish between the two sports: Baseball and Hockey."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_20newsgroups\n",
"import pandas as pd\n",
"import openai\n",
"import os\n",
"\n",
"client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"categories = ['rec.sport.baseball', 'rec.sport.hockey']\n",
"sports_dataset = fetch_20newsgroups(subset='train', shuffle=True, random_state=42, categories=categories)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" ## Data exploration\n",
" The newsgroup dataset can be loaded using sklearn. First we will look at the data itself:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"From: dougb@comm.mot.com (Doug Bank)\n",
"Subject: Re: Info needed for Cleveland tickets\n",
"Reply-To: dougb@ecs.comm.mot.com\n",
"Organization: Motorola Land Mobile Products Sector\n",
"Distribution: usa\n",
"Nntp-Posting-Host: 145.1.146.35\n",
"Lines: 17\n",
"\n",
"In article <1993Apr1.234031.4950@leland.Stanford.EDU>, bohnert@leland.Stanford.EDU (matthew bohnert) writes:\n",
"\n",
"|> I'm going to be in Cleveland Thursday, April 15 to Sunday, April 18.\n",
"|> Does anybody know if the Tribe will be in town on those dates, and\n",
"|> if so, who're they playing and if tickets are available?\n",
"\n",
"The tribe will be in town from April 16 to the 19th.\n",
"There are ALWAYS tickets available! (Though they are playing Toronto,\n",
"and many Toronto fans make the trip to Cleveland as it is easier to\n",
"get tickets in Cleveland than in Toronto. Either way, I seriously\n",
"doubt they will sell out until the end of the season.)\n",
"\n",
"-- \n",
"Doug Bank Private Systems Division\n",
"dougb@ecs.comm.mot.com Motorola Communications Sector\n",
"dougb@nwu.edu Schaumburg, Illinois\n",
"dougb@casbah.acns.nwu.edu 708-576-8207 \n",
"\n"
]
}
],
"source": [
"print(sports_dataset['data'][0])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'rec.sport.baseball'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sports_dataset.target_names[sports_dataset['target'][0]]\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total examples: 1197, Baseball examples: 597, Hockey examples: 600\n"
]
}
],
"source": [
"len_all, len_baseball, len_hockey = len(sports_dataset.data), len([e for e in sports_dataset.target if e == 0]), len([e for e in sports_dataset.target if e == 1])\n",
"print(f\"Total examples: {len_all}, Baseball examples: {len_baseball}, Hockey examples: {len_hockey}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One sample from the baseball category can be seen above. It is an email to a mailing list. We have 1197 examples in total, split almost evenly between the two sports (597 baseball, 600 hockey)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Preparation\n",
"We transform the dataset into a pandas dataframe, with a column for prompt and completion. The prompt contains the email from the mailing list, and the completion is the name of the sport, either hockey or baseball. For demonstration purposes and to speed up fine-tuning you can keep only 300 examples by uncommenting the `[:300]` slice in the next cell; the run shown here uses all 1197. In a real use case, more examples generally give better performance."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prompt</th>\n",
" <th>completion</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>From: dougb@comm.mot.com (Doug Bank)\\nSubject:...</td>\n",
" <td>baseball</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>From: gld@cunixb.cc.columbia.edu (Gary L Dare)...</td>\n",
" <td>hockey</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>From: rudy@netcom.com (Rudy Wade)\\nSubject: Re...</td>\n",
" <td>baseball</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>From: monack@helium.gas.uug.arizona.edu (david...</td>\n",
" <td>hockey</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Subject: Let it be Known\\nFrom: &lt;ISSBTL@BYUVM....</td>\n",
" <td>baseball</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" prompt completion\n",
"0 From: dougb@comm.mot.com (Doug Bank)\\nSubject:... baseball\n",
"1 From: gld@cunixb.cc.columbia.edu (Gary L Dare)... hockey\n",
"2 From: rudy@netcom.com (Rudy Wade)\\nSubject: Re... baseball\n",
"3 From: monack@helium.gas.uug.arizona.edu (david... hockey\n",
"4 Subject: Let it be Known\\nFrom: <ISSBTL@BYUVM.... baseball"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"labels = [sports_dataset.target_names[x].split('.')[-1] for x in sports_dataset['target']]\n",
"texts = [text.strip() for text in sports_dataset['data']]\n",
"df = pd.DataFrame(zip(texts, labels), columns = ['prompt','completion']) #[:300]\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both baseball and hockey are single tokens. We save the dataset as a jsonl file."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"df.to_json(\"sport2.jsonl\", orient='records', lines=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Preparation tool\n",
"We can now use a data preparation tool which will suggest a few improvements to our dataset before fine-tuning. Before launching the tool, make sure the openai library is up to date so that you are using the latest version of the data preparation tool. We additionally specify `-q`, which auto-accepts all suggestions."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Analyzing...\n",
"\n",
"- Your file contains 1197 prompt-completion pairs\n",
"- Based on your data it seems like you're trying to fine-tune a model for classification\n",
"- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\n",
"- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training\n",
"- There are 11 examples that are very long. These are rows: [134, 200, 281, 320, 404, 595, 704, 838, 1113, 1139, 1174]\n",
"For conditional generation, and for classification the examples shouldn't be longer than 2048 tokens.\n",
"- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty\n",
"- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details\n",
"\n",
"Based on the analysis we will perform the following actions:\n",
"- [Recommended] Remove 11 long examples [Y/n]: Y\n",
"- [Recommended] Add a suffix separator `\\n\\n###\\n\\n` to all prompts [Y/n]: Y\n",
"- [Recommended] Add a whitespace character to the beginning of the completion [Y/n]: Y\n",
"- [Recommended] Would you like to split into training and validation set? [Y/n]: Y\n",
"\n",
"\n",
"Your data will be written to a new JSONL file. Proceed [Y/n]: Y\n",
"\n",
"Wrote modified files to `sport2_prepared_train (1).jsonl` and `sport2_prepared_valid (1).jsonl`\n",
"Feel free to take a look!\n",
"\n",
"Now use that file when fine-tuning:\n",
"> openai api fine_tunes.create -t \"sport2_prepared_train (1).jsonl\" -v \"sport2_prepared_valid (1).jsonl\" --compute_classification_metrics --classification_positive_class \" baseball\"\n",
"\n",
"After you've fine-tuned a model, remember that your prompt has to end with the indicator string `\\n\\n###\\n\\n` for the model to start generating completions, rather than continuing with the prompt.\n",
"Once your model starts training, it'll approximately take 30.8 minutes to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
]
}
],
"source": [
"!openai tools fine_tunes.prepare_data -f sport2.jsonl -q"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tool helpfully suggests a few improvements to the dataset and splits it into training and validation sets.\n",
"\n",
"A suffix between a prompt and a completion is necessary to tell the model that the input text has stopped, and that it now needs to predict the class. Since we use the same separator in each example, the model is able to learn that it is meant to predict either baseball or hockey following the separator.\n",
"A whitespace prefix in completions is useful, as most word tokens are tokenized with a space prefix.\n",
"The tool also recognized that this is likely a classification task, so it suggested splitting the dataset into training and validation sets. This will allow us to easily measure expected performance on new data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fine-tuning\n",
"The tool suggests a command to train a model on the dataset. Since this is a classification task, we would like to know what the generalization performance on the provided validation set is for our classification use case.\n",
"\n",
"The suggested command uses the older `fine_tunes` CLI, so instead of copying it we upload the two prepared files and create the job with the Python client. We specify `babbage-002`, a cheaper and faster model that is usually comparable in performance to slower and more expensive models on classification use cases."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FineTuningJob(id='ftjob-REo0uLpriEAm08CBRNDlPJZC', created_at=1704413736, error=None, fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs='auto', batch_size='auto', learning_rate_multiplier='auto'), model='babbage-002', object='fine_tuning.job', organization_id='org-9HXYFy8ux4r6aboFyec2OLRf', result_files=[], status='validating_files', trained_tokens=None, training_file='file-82XooA2AUDBAUbN5z2DuKRMs', validation_file='file-wTOcQF8vxQ0Z6fNY2GSm0z4P')\n"
]
}
],
"source": [
"train_file = client.files.create(file=open(\"sport2_prepared_train.jsonl\", \"rb\"), purpose=\"fine-tune\")\n",
"valid_file = client.files.create(file=open(\"sport2_prepared_valid.jsonl\", \"rb\"), purpose=\"fine-tune\")\n",
"\n",
"fine_tuning_job = client.fine_tuning.jobs.create(training_file=train_file.id, validation_file=valid_file.id, model=\"babbage-002\")\n",
"\n",
"print(fine_tuning_job)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model is successfully trained in about ten minutes. You can watch the fine-tuning job progress on [https://platform.openai.com/finetune/](https://platform.openai.com/finetune/)\n",
"\n",
"You can also check on its status programmatically:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1704414393\n"
]
}
],
"source": [
"fine_tune_results = client.fine_tuning.jobs.retrieve(fine_tuning_job.id)\n",
"print(fine_tune_results.finished_at)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Advanced] Results and expected model performance\n",
"We can now download the results file to observe the expected performance on a held out validation set."
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"fine_tune_results = client.fine_tuning.jobs.retrieve(fine_tuning_job.id).result_files\n",
"result_file = client.files.retrieve(fine_tune_results[0])\n",
"content = client.files.content(result_file.id)\n",
"# save content to file\n",
"with open(\"result.csv\", \"wb\") as f:\n",
"    f.write(content.text.encode(\"utf-8\"))"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>step</th>\n",
" <th>train_loss</th>\n",
" <th>train_accuracy</th>\n",
" <th>valid_loss</th>\n",
" <th>valid_mean_token_accuracy</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2843</th>\n",
" <td>2844</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" step train_loss train_accuracy valid_loss valid_mean_token_accuracy\n",
"2843 2844 0.0 1.0 NaN NaN"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results = pd.read_csv('result.csv')\n",
"results[results['train_accuracy'].notnull()].tail(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the last logged step shown above, the training accuracy is 1.0; the validation columns are `NaN` on that row. The plot below shows how the training accuracy changes during the training run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results[results['train_accuracy'].notnull()]['train_accuracy'].plot()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the model\n",
"We can now call the model to get the predictions."
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prompt</th>\n",
" <th>completion</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>From: gld@cunixb.cc.columbia.edu (Gary L Dare)...</td>\n",
" <td>hockey</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>From: smorris@venus.lerc.nasa.gov (Ron Morris ...</td>\n",
" <td>hockey</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>From: golchowy@alchemy.chem.utoronto.ca (Geral...</td>\n",
" <td>hockey</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>From: krattige@hpcc01.corp.hp.com (Kim Krattig...</td>\n",
" <td>baseball</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>From: warped@cs.montana.edu (Doug Dolven)\\nSub...</td>\n",
" <td>baseball</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" prompt completion\n",
"0 From: gld@cunixb.cc.columbia.edu (Gary L Dare)... hockey\n",
"1 From: smorris@venus.lerc.nasa.gov (Ron Morris ... hockey\n",
"2 From: golchowy@alchemy.chem.utoronto.ca (Geral... hockey\n",
"3 From: krattige@hpcc01.corp.hp.com (Kim Krattig... baseball\n",
"4 From: warped@cs.montana.edu (Doug Dolven)\\nSub... baseball"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test = pd.read_json('sport2_prepared_valid.jsonl', lines=True)\n",
"test.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to use the same separator following the prompt which we used during fine-tuning. In this case it is `\\n\\n###\\n\\n`. Since we're concerned with classification, we want the temperature to be as low as possible, and we only need a single-token completion to determine the model's prediction."
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' hockey'"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
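"# re-retrieve the finished job so this cell does not depend on which earlier cell last assigned `fine_tune_results`\n",
"ft_model = client.fine_tuning.jobs.retrieve(fine_tuning_job.id).fine_tuned_model\n",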
"\n",
"# note that this calls the legacy completions api - https://platform.openai.com/docs/api-reference/completions\n",
"res = client.completions.create(model=ft_model, prompt=test['prompt'][0] + '\\n\\n###\\n\\n', max_tokens=1, temperature=0)\n",
"res.choices[0].text\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get the log probabilities, we can specify the `logprobs` parameter on the completion request."
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{' hockey': 0.0, ' Hockey': -22.504879}]"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res = client.completions.create(model=ft_model, prompt=test['prompt'][0] + '\\n\\n###\\n\\n', max_tokens=1, temperature=0, logprobs=2)\n",
"res.choices[0].logprobs.top_logprobs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
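"We can see that the model assigns ` hockey` a log probability of about 0.0 (a probability close to 1), which is the correct prediction. The runner-up token, ` Hockey`, is far behind at about -22.5, and ` baseball` does not appear among the top two tokens at all. By requesting `logprobs`, we can see the log probability of the most likely tokens."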
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generalization\n",
"Interestingly, our fine-tuned classifier is quite versatile. Despite being trained on emails to different mailing lists, it also successfully classifies tweets."
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' hockey'"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sample_hockey_tweet = \"\"\"Thank you to the \n",
"@Canes\n",
" and all you amazing Caniacs that have been so supportive! You guys are some of the best fans in the NHL without a doubt! Really excited to start this new chapter in my career with the \n",
"@DetroitRedWings\n",
" !!\"\"\"\n",
"res = client.completions.create(model=ft_model, prompt=sample_hockey_tweet + '\\n\\n###\\n\\n', max_tokens=1, temperature=0, logprobs=2)\n",
"res.choices[0].text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sample_baseball_tweet=\"\"\"BREAKING: The Tampa Bay Rays are finalizing a deal to acquire slugger Nelson Cruz from the Minnesota Twins, sources tell ESPN.\"\"\"\n",
"res = client.completions.create(model=ft_model, prompt=sample_baseball_tweet + '\\n\\n###\\n\\n', max_tokens=1, temperature=0, logprobs=2)\n",
"res.choices[0].text"
]
}
],
"metadata": {
"interpreter": {
"hash": "3b138a8faad971cc852f62bcf00f59ea0e31721743ea2c5a866ca26adf572e75"
},
"kernelspec": {
"display_name": "Python 3.7.3 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}