more fine-tuning improvements (#652)

* more fine-tuning improvements

* add links to other resources
pull/653/head
Simón Fishman 10 months ago committed by GitHub
parent 8ed84645e8
commit d534c85477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,10 +13,12 @@
"\n",
"1. **Setup:** Loading our dataset and filtering down to one domain to fine-tune on.\n",
"2. **Data preparation:** Preparing your data for fine-tuning by creating training and validation examples, and uploading them to the `Files` endpoint.\n",
"3. **Create the fine-tune:** Creating your fine-tuned model.\n",
"4. **Use model for inference:** Using your fine-tuned model for inference on new inputs.\n",
"3. **Fine-tuning:** Creating your fine-tuned model.\n",
"4. **Inference:** Using your fine-tuned model for inference on new inputs.\n",
"\n",
"By the end of this you should be able to train, evaluate and deploy a fine-tuned `gpt-3.5-turbo` model.\n"
"By the end of this you should be able to train, evaluate and deploy a fine-tuned `gpt-3.5-turbo` model.\n",
"\n",
"For more information on fine-tuning, you can refer to our [documentation guide](https://platform.openai.com/docs/guides/fine-tuning), [API reference](https://platform.openai.com/docs/api-reference/fine-tuning) or [blog post](https://openai.com/blog/gpt-3-5-turbo-fine-tuning-and-api-updates)"
]
},
{
@ -24,11 +26,7 @@
"id": "6f49cb10-f895-41f4-aa97-da606d0084d4",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First we will import any required libraries and prepare our data.\n",
"\n",
"Fine tuning works best when focused on a particular domain. It's important to make sure your dataset is both focused enough for the model to learn, but general enough that unseen examples won't be missed. Having this in mind, we have already extracted a subset from the RecipesNLG dataset to only contain documents from www.cookbooks.com.\n"
"## Setup"
]
},
{
@ -58,6 +56,15 @@
"OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\", \"\")"
]
},
{
"cell_type": "markdown",
"id": "a468d660",
"metadata": {},
"source": [
"\n",
"Fine-tuning works best when focused on a particular domain. It's important to make sure your dataset is both focused enough for the model to learn, but general enough that unseen examples won't be missed. Having this in mind, we have extracted a subset from the RecipesNLG dataset to only contain documents from www.cookbooks.com."
]
},
{
"cell_type": "code",
"execution_count": 2,
@ -173,7 +180,7 @@
"4 [\"peanut butter\", \"graham cracker crumbs\", \"bu... "
]
},
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@ -208,62 +215,112 @@
"\n",
"During the training process this conversation will be split, with the final entry being the `completion` that the model will produce, and the remainder of the `messages` acting as the prompt. Consider this when building your training examples - if your model will act on multi-turn conversations, then please provide representative examples so it doesn't perform poorly when the conversation starts to expand.\n",
"\n",
"For fine-tuning with `ChatCompletion` you can begin with even 30-50 well-pruned examples. You should see performance continue to scale linearly as you increase the size of the training set, but your jobs will also take longer.\n",
"\n",
"Please note that currently there is a 4096 token limit for each training example. Anything longer than this will be truncated at 4096 tokens.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "9a8216b0-d1dc-472d-b07d-1be03acd70a5",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'messages': [{'content': 'You are a helpful recipe assistant. You are to '\n",
" 'extract the generic ingredients from each of the '\n",
" 'recipes provided.',\n",
" 'role': 'system'},\n",
" {'content': 'Title: No-Bake Nut Cookies\\n'\n",
" '\\n'\n",
" 'Ingredients: [\"1 c. firmly packed brown sugar\", '\n",
" '\"1/2 c. evaporated milk\", \"1/2 tsp. vanilla\", \"1/2 '\n",
" 'c. broken nuts (pecans)\", \"2 Tbsp. butter or '\n",
" 'margarine\", \"3 1/2 c. bite size shredded rice '\n",
" 'biscuits\"]\\n'\n",
" '\\n'\n",
" 'Generic ingredients: ',\n",
" 'role': 'user'},\n",
" {'content': '[\"brown sugar\", \"milk\", \"vanilla\", \"nuts\", '\n",
" '\"butter\", \"bite size shredded rice biscuits\"]',\n",
" 'role': 'assistant'}]}\n"
]
}
],
"source": [
"training_data = []\n",
"\n",
"system_message = \"You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.\"\n",
"\n",
"def prepare_example_conversation(row):\n",
" messages = []\n",
" messages.append({\"role\": \"system\", \"content\": system_message})\n",
"\n",
" user_message = f\"\"\"Title: {row['title']}\\n\\nIngredients: {row['ingredients']}\\n\\nGeneric ingredients: \"\"\"\n",
" messages.append({\"role\": \"user\", \"content\": user_message})\n",
"\n",
"def create_user_message(row):\n",
" return f\"\"\"Title: {row['title']}\\n\\nIngredients: {row['ingredients']}\\n\\nGeneric ingredients: \"\"\"\n",
" messages.append({\"role\": \"assistant\", \"content\": row[\"NER\"]})\n",
"\n",
" return {\"messages\": messages}\n",
"\n",
"# Take first 100 records for training\n",
"for x, y in recipe_df.head(100).iterrows():\n",
" training_message = []\n",
" training_message.append({\"role\": \"system\", \"content\": system_message})\n",
"pprint(prepare_example_conversation(recipe_df.iloc[0]))"
]
},
{
"cell_type": "markdown",
"id": "82fa4fae",
"metadata": {},
"source": [
"Let's now do this for a subset of the dataset to use as our training data. You can begin with even 30-50 well-pruned examples. You should see performance continue to scale linearly as you increase the size of the training set, but your jobs will also take longer."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8f37aff9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'messages': [{'role': 'system', 'content': 'You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.'}, {'role': 'user', 'content': 'Title: No-Bake Nut Cookies\\n\\nIngredients: [\"1 c. firmly packed brown sugar\", \"1/2 c. evaporated milk\", \"1/2 tsp. vanilla\", \"1/2 c. broken nuts (pecans)\", \"2 Tbsp. butter or margarine\", \"3 1/2 c. bite size shredded rice biscuits\"]\\n\\nGeneric ingredients: '}, {'role': 'assistant', 'content': '[\"brown sugar\", \"milk\", \"vanilla\", \"nuts\", \"butter\", \"bite size shredded rice biscuits\"]'}]}\n",
"{'messages': [{'role': 'system', 'content': 'You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.'}, {'role': 'user', 'content': 'Title: Jewell Ball\\'S Chicken\\n\\nIngredients: [\"1 small jar chipped beef, cut up\", \"4 boned chicken breasts\", \"1 can cream of mushroom soup\", \"1 carton sour cream\"]\\n\\nGeneric ingredients: '}, {'role': 'assistant', 'content': '[\"beef\", \"chicken breasts\", \"cream of mushroom soup\", \"sour cream\"]'}]}\n",
"{'messages': [{'role': 'system', 'content': 'You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.'}, {'role': 'user', 'content': 'Title: Creamy Corn\\n\\nIngredients: [\"2 (16 oz.) pkg. frozen corn\", \"1 (8 oz.) pkg. cream cheese, cubed\", \"1/3 c. butter, cubed\", \"1/2 tsp. garlic powder\", \"1/2 tsp. salt\", \"1/4 tsp. pepper\"]\\n\\nGeneric ingredients: '}, {'role': 'assistant', 'content': '[\"frozen corn\", \"cream cheese\", \"butter\", \"garlic powder\", \"salt\", \"pepper\"]'}]}\n",
"{'messages': [{'role': 'system', 'content': 'You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.'}, {'role': 'user', 'content': 'Title: Chicken Funny\\n\\nIngredients: [\"1 large whole chicken\", \"2 (10 1/2 oz.) cans chicken gravy\", \"1 (10 1/2 oz.) can cream of mushroom soup\", \"1 (6 oz.) box Stove Top stuffing\", \"4 oz. shredded cheese\"]\\n\\nGeneric ingredients: '}, {'role': 'assistant', 'content': '[\"chicken\", \"chicken gravy\", \"cream of mushroom soup\", \"shredded cheese\"]'}]}\n",
"{'messages': [{'role': 'system', 'content': 'You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.'}, {'role': 'user', 'content': 'Title: Reeses Cups(Candy) \\n\\nIngredients: [\"1 c. peanut butter\", \"3/4 c. graham cracker crumbs\", \"1 c. melted butter\", \"1 lb. (3 1/2 c.) powdered sugar\", \"1 large pkg. chocolate chips\"]\\n\\nGeneric ingredients: '}, {'role': 'assistant', 'content': '[\"peanut butter\", \"graham cracker crumbs\", \"butter\", \"powdered sugar\", \"chocolate chips\"]'}]}\n"
]
}
],
"source": [
"# use the first 100 rows of the dataset for training\n",
"training_df = recipe_df.loc[0:100]\n",
"\n",
" user_message = create_user_message(y)\n",
" training_message.append({\"role\": \"user\", \"content\": user_message})\n",
"# apply the prepare_example_conversation function to each row of the training_df\n",
"training_data = training_df.apply(prepare_example_conversation, axis=1).tolist()\n",
"\n",
" training_message.append({\"role\": \"assistant\", \"content\": y[\"NER\"]})\n",
" training_message_dict = {\"messages\": training_message}\n",
" training_data.append(training_message_dict)"
"for example in training_data[:5]:\n",
" print(example)"
]
},
{
"cell_type": "markdown",
"id": "6fb0f6a7",
"metadata": {},
"source": [
"In addition to training data, we can also **optionally** provide validation data, which will be used to make sure that the model does not overfit your training set."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "5b853efa-dfea-4770-ab88-9b7e17794421",
"metadata": {},
"outputs": [],
"source": [
"validation_data = []\n",
"\n",
"# We'll pick a test set from further on in the dataset\n",
"test_df = recipe_df.loc[100:200]\n",
"\n",
"for x, y in test_df.iterrows():\n",
" validation_message = []\n",
" validation_message.append({\"role\": \"system\", \"content\": system_message})\n",
"\n",
" user_message = create_user_message(y)\n",
" validation_message.append({\"role\": \"user\", \"content\": user_message})\n",
"\n",
" validation_message.append({\"role\": \"assistant\", \"content\": y[\"NER\"]})\n",
" validation_message_dict = {\"messages\": validation_message}\n",
" validation_data.append(validation_message_dict)"
"validation_df = recipe_df.loc[101:200]\n",
"validation_data = validation_df.apply(prepare_example_conversation, axis=1).tolist()"
]
},
{
@ -271,30 +328,17 @@
"id": "1d5e7bfe-f6c8-4a23-a951-3df3f3791d7f",
"metadata": {},
"source": [
"We then need to export these as `.jsonl` files, with each row being one training example.\n"
"We then need to save our data as `.jsonl` files, with each line being one training example conversation.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "8d2eb207-2c2b-43f6-a613-64a7e92d494d",
"metadata": {},
"outputs": [],
"source": [
"def dicts_to_jsonl(data_list: list, filename: str) -> None:\n",
" \"\"\"\n",
" Method saves list of dicts into jsonl file.\n",
" :param data: (list) list of dicts to be stored,\n",
" :param filename: (str) path to the output file. If suffix .jsonl is not given then methods appends\n",
" .jsonl suffix into the file.\n",
" \"\"\"\n",
" sjsonl = \".jsonl\"\n",
"\n",
" # Check filename\n",
" if not filename.endswith(sjsonl):\n",
" filename = filename + \".jsonl\"\n",
"\n",
" # Save data\n",
"def write_jsonl(data_list: list, filename: str) -> None:\n",
" with open(filename, \"w\") as out:\n",
" for ddict in data_list:\n",
" jout = json.dumps(ddict) + \"\\n\"\n",
@ -303,18 +347,47 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"id": "8b53e7a2-1cac-4c5f-8ba4-3292ba2a0770",
"metadata": {},
"outputs": [],
"source": [
"# Save training_data to JSONL\n",
"training_file_name = \"tmp_recipe_finetune_training\"\n",
"dicts_to_jsonl(training_data, training_file_name)\n",
"training_file_name = \"tmp_recipe_finetune_training.jsonl\"\n",
"write_jsonl(training_data, training_file_name)\n",
"\n",
"# Save validation_data to JSONL\n",
"validation_file_name = \"tmp_recipe_finetune_validation\"\n",
"dicts_to_jsonl(validation_data, validation_file_name)"
"validation_file_name = \"tmp_recipe_finetune_validation.jsonl\"\n",
"write_jsonl(validation_data, validation_file_name)"
]
},
{
"cell_type": "markdown",
"id": "80b274a3",
"metadata": {},
"source": [
"This is what the first 5 lines of our training `.jsonl` file look like:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "283c4ec2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.\"}, {\"role\": \"user\", \"content\": \"Title: No-Bake Nut Cookies\\n\\nIngredients: [\\\"1 c. firmly packed brown sugar\\\", \\\"1/2 c. evaporated milk\\\", \\\"1/2 tsp. vanilla\\\", \\\"1/2 c. broken nuts (pecans)\\\", \\\"2 Tbsp. butter or margarine\\\", \\\"3 1/2 c. bite size shredded rice biscuits\\\"]\\n\\nGeneric ingredients: \"}, {\"role\": \"assistant\", \"content\": \"[\\\"brown sugar\\\", \\\"milk\\\", \\\"vanilla\\\", \\\"nuts\\\", \\\"butter\\\", \\\"bite size shredded rice biscuits\\\"]\"}]}\n",
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.\"}, {\"role\": \"user\", \"content\": \"Title: Jewell Ball'S Chicken\\n\\nIngredients: [\\\"1 small jar chipped beef, cut up\\\", \\\"4 boned chicken breasts\\\", \\\"1 can cream of mushroom soup\\\", \\\"1 carton sour cream\\\"]\\n\\nGeneric ingredients: \"}, {\"role\": \"assistant\", \"content\": \"[\\\"beef\\\", \\\"chicken breasts\\\", \\\"cream of mushroom soup\\\", \\\"sour cream\\\"]\"}]}\n",
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.\"}, {\"role\": \"user\", \"content\": \"Title: Creamy Corn\\n\\nIngredients: [\\\"2 (16 oz.) pkg. frozen corn\\\", \\\"1 (8 oz.) pkg. cream cheese, cubed\\\", \\\"1/3 c. butter, cubed\\\", \\\"1/2 tsp. garlic powder\\\", \\\"1/2 tsp. salt\\\", \\\"1/4 tsp. pepper\\\"]\\n\\nGeneric ingredients: \"}, {\"role\": \"assistant\", \"content\": \"[\\\"frozen corn\\\", \\\"cream cheese\\\", \\\"butter\\\", \\\"garlic powder\\\", \\\"salt\\\", \\\"pepper\\\"]\"}]}\n",
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.\"}, {\"role\": \"user\", \"content\": \"Title: Chicken Funny\\n\\nIngredients: [\\\"1 large whole chicken\\\", \\\"2 (10 1/2 oz.) cans chicken gravy\\\", \\\"1 (10 1/2 oz.) can cream of mushroom soup\\\", \\\"1 (6 oz.) box Stove Top stuffing\\\", \\\"4 oz. shredded cheese\\\"]\\n\\nGeneric ingredients: \"}, {\"role\": \"assistant\", \"content\": \"[\\\"chicken\\\", \\\"chicken gravy\\\", \\\"cream of mushroom soup\\\", \\\"shredded cheese\\\"]\"}]}\n",
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful recipe assistant. You are to extract the generic ingredients from each of the recipes provided.\"}, {\"role\": \"user\", \"content\": \"Title: Reeses Cups(Candy) \\n\\nIngredients: [\\\"1 c. peanut butter\\\", \\\"3/4 c. graham cracker crumbs\\\", \\\"1 c. melted butter\\\", \\\"1 lb. (3 1/2 c.) powdered sugar\\\", \\\"1 large pkg. chocolate chips\\\"]\\n\\nGeneric ingredients: \"}, {\"role\": \"assistant\", \"content\": \"[\\\"peanut butter\\\", \\\"graham cracker crumbs\\\", \\\"butter\\\", \\\"powdered sugar\\\", \\\"chocolate chips\\\"]\"}]}\n"
]
}
],
"source": [
"# print the first 5 lines of the training file\n",
"!head -n 5 tmp_recipe_finetune_training.jsonl"
]
},
{
@ -324,12 +397,12 @@
"source": [
"### Upload files\n",
"\n",
"You can then upload the files to our `Files` endpoint to be used by the fine-tuned model.\n"
"You can now upload the files to our `Files` endpoint to be used by the fine-tuned model.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 10,
"id": "69462d9e-e6bd-49b9-a064-9eae4ea5b7a8",
"metadata": {},
"outputs": [
@ -337,19 +410,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Training file ID: file-XEpJAyCL2qMbMEgXwYfS9ypT\n",
"Validation file ID: file-MAGY5QyvCMdOWUDuXJB79uw9\n"
"Training file ID: file-jcdvNl27iuBMZfwi4q30IIka\n",
"Validation file ID: file-O144OIHkZ1xjB32ednBmbOXP\n"
]
}
],
"source": [
"training_response = openai.File.create(\n",
" file=open(training_file_name + \".jsonl\", \"rb\"), purpose=\"fine-tune\"\n",
" file=open(training_file_name, \"rb\"), purpose=\"fine-tune\"\n",
")\n",
"training_file_id = training_response[\"id\"]\n",
"\n",
"validation_response = openai.File.create(\n",
" file=open(validation_file_name + \".jsonl\", \"rb\"), purpose=\"fine-tune\"\n",
" file=open(validation_file_name, \"rb\"), purpose=\"fine-tune\"\n",
")\n",
"validation_file_id = validation_response[\"id\"]\n",
"\n",
@ -371,7 +444,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "05541ceb-5628-447e-962d-7e57c112439c",
"metadata": {},
"outputs": [
@ -412,7 +485,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "d7392f48",
"metadata": {},
"outputs": [
@ -444,7 +517,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "08cace28",
"metadata": {},
"outputs": [

Loading…
Cancel
Save