add descriptions to fine-tuning dataprep notebook (#673)

pull/675/head
Simón Fishman 10 months ago committed by GitHub
parent 15f3fda4a3
commit 5783656852
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,17 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "a06ec76c",
"metadata": {},
"source": [
"# Data preparation and analysis for chat model fine-tuning\n",
"\n",
"This notebook serves as a tool to preprocess and analyze the chat dataset used for fine-tuning a chat model. \n",
"It checks for format errors, provides basic statistics, and estimates token counts for fine-tuning costs.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
@ -8,38 +20,27 @@
"outputs": [],
"source": [
"import json\n",
"import tiktoken\n",
"import tiktoken # for token counting\n",
"import numpy as np\n",
"from collections import defaultdict"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "748510f2",
"attachments": {},
"cell_type": "markdown",
"id": "013bdbc4",
"metadata": {},
"outputs": [],
"source": [
"data_path = \"data/toy_chat_fine_tuning.jsonl\""
"## Data loading\n",
"\n",
"We first load the chat dataset from an [example JSONL file](https://github.com/openai/openai-cookbook/blob/main/examples/data/toy_chat_fine_tuning.jsonl)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "c248ccd1",
"metadata": {},
"outputs": [],
"source": [
"# Load dataset\n",
"with open(data_path, 'r', encoding='utf-8') as f:\n",
" dataset = [json.loads(line) for line in f]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "189ed16c",
"metadata": {},
"outputs": [
{
"name": "stdout",
@ -54,6 +55,12 @@
}
],
"source": [
"data_path = \"data/toy_chat_fine_tuning.jsonl\"\n",
"\n",
"# Load the dataset\n",
"with open(data_path, 'r', encoding='utf-8') as f:\n",
" dataset = [json.loads(line) for line in f]\n",
"\n",
"# Initial dataset stats\n",
"print(\"Num examples:\", len(dataset))\n",
"print(\"First example:\")\n",
@ -61,9 +68,30 @@
" print(message)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "17903d61",
"metadata": {},
"source": [
"## Format validation\n",
"\n",
"We can perform a variety of error checks to validate that each conversation in the dataset adheres to the format expected by the fine-tuning API. Errors are categorized based on their nature for easier debugging.\n",
"\n",
"1. **Data Type Check**: Checks whether each entry in the dataset is a dictionary (`dict`). Error type: `data_type`.\n",
"2. **Presence of Message List**: Checks if a `messages` list is present in each entry. Error type: `missing_messages_list`.\n",
"3. **Message Keys Check**: Validates that each message in the `messages` list contains the keys `role` and `content`. Error type: `message_missing_key`.\n",
"4. **Unrecognized Keys in Messages**: Logs if a message has keys other than `role`, `content`, and `name`. Error type: `message_unrecognized_key`.\n",
"5. **Role Validation**: Ensures the `role` is one of \"system\", \"user\", or \"assistant\". Error type: `unrecognized_role`.\n",
"6. **Content Validation**: Verifies that `content` has textual data and is a string. Error type: `missing_content`.\n",
"7. **Assistant Message Presence**: Checks that each conversation has at least one message from the assistant. Error type: `example_missing_assistant_message`.\n",
"\n",
"The code below performs these checks, and outputs counts for each type of error found are printed. This is useful for debugging and ensuring the dataset is ready for the next steps.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"id": "d9f3ccbf",
"metadata": {},
"outputs": [
@ -114,14 +142,24 @@
" print(\"No errors found\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "981e77da",
"metadata": {},
"source": [
"## Token Counting Utilities\n",
"\n",
"Lets define a few helpful utilities to be used in the rest of the notebook."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"id": "8f4b47b5",
"metadata": {},
"outputs": [],
"source": [
"# Token counting functions\n",
"encoding = tiktoken.get_encoding(\"cl100k_base\")\n",
"\n",
"# not exact!\n",
@ -151,9 +189,26 @@
" print(f\"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0fdff67d",
"metadata": {},
"source": [
"## Data Warnings and Token Counts \n",
"\n",
"With some lightweight analysis we can identify potential issues in the dataset, like missing messages, and provide statistical insights into message and token counts.\n",
"\n",
"1. **Missing System/User Messages**: Counts the number of conversations missing a \"system\" or \"user\" message. Such messages are critical for defining the assistant's behavior and initiating the conversation.\n",
"2. **Number of Messages Per Example**: Summarizes the distribution of the number of messages in each conversation, providing insight into dialogue complexity.\n",
"3. **Total Tokens Per Example**: Calculates and summarizes the distribution of the total number of tokens in each conversation. Important for understanding fine-tuning costs.\n",
"4. **Tokens in Assistant's Messages**: Calculates the number of tokens in the assistant's messages per conversation and summarizes this distribution. Useful for understanding the assistant's verbosity.\n",
"5. **Token Limit Warnings**: Checks if any examples exceed the maximum token limit (4096 tokens), as such examples will be truncated during fine-tuning, potentially resulting in data loss.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"id": "52e58ee4",
"metadata": {
"scrolled": true
@ -212,9 +267,20 @@
"print(f\"\\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2afb04df",
"metadata": {},
"source": [
"## Cost Estimation\n",
"\n",
"In this final section, we estimate the total number of tokens that will be used for fine-tuning, which allows us to approximate the cost. It is worth noting that the duration of the fine-tuning jobs will also increase with the token count. "
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"id": "fb95a7ce",
"metadata": {},
"outputs": [
@ -224,8 +290,7 @@
"text": [
"Dataset has ~4306 tokens that will be charged for during training\n",
"By default, you'll train for 20 epochs on this dataset\n",
"By default, you'll be charged for ~86120 tokens\n",
"See pricing page to estimate total costs\n"
"By default, you'll be charged for ~86120 tokens\n"
]
}
],
@ -249,25 +314,17 @@
"n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)\n",
"print(f\"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training\")\n",
"print(f\"By default, you'll train for {n_epochs} epochs on this dataset\")\n",
"print(f\"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens\")\n",
"print(\"See pricing page to estimate total costs\")"
"print(f\"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5736cd86",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ecbb2f6",
"attachments": {},
"cell_type": "markdown",
"id": "a0ad0369",
"metadata": {},
"outputs": [],
"source": []
"source": [
"See https://openai.com/pricing to estimate total costs."
]
}
],
"metadata": {
@ -286,7 +343,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.9.13"
}
},
"nbformat": 4,

Loading…
Cancel
Save