You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
openai-cookbook/examples/Chat_finetuning_data_prep.i...

295 lines
9.2 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4e63973b",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from collections import defaultdict\n",
"\n",
"import numpy as np\n",
"import tiktoken"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "748510f2",
"metadata": {},
"outputs": [],
"source": [
"# Dataset to validate: a JSONL file with one chat example (JSON object) per line.\n",
"# Set the FINETUNE_DATA_PATH environment variable to check a different file\n",
"# without editing this cell; otherwise the bundled toy dataset is used.\n",
"data_path = os.environ.get(\"FINETUNE_DATA_PATH\", \"data/toy_chat_fine_tuning.jsonl\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c248ccd1",
"metadata": {},
"outputs": [],
"source": [
"# Load dataset: one JSON object per line. Whitespace-only lines (e.g. a stray\n",
"# trailing blank line) are skipped instead of raising json.JSONDecodeError.\n",
"with open(data_path, 'r', encoding='utf-8') as f:\n",
"    dataset = [json.loads(line) for line in f if line.strip()]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "189ed16c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num examples: 5\n",
"First example:\n",
"{'role': 'system', 'content': 'You are a happy assistant that puts a positive spin on everything.'}\n",
"{'role': 'user', 'content': 'I fell off my bike today.'}\n",
"{'role': 'assistant', 'content': \"It's great that you're getting exercise outdoors!\"}\n"
]
}
],
"source": [
"# Initial dataset stats\n",
"print(\"Num examples:\", len(dataset))\n",
"# Guard the peek at the first example so an empty file doesn't raise IndexError.\n",
"if dataset:\n",
"    print(\"First example:\")\n",
"    for message in dataset[0][\"messages\"]:\n",
"        print(message)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d9f3ccbf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No errors found\n"
]
}
],
"source": [
"# Format error checks\n",
"# Tally structural problems per category instead of stopping at the first one,\n",
"# so every issue in the file is reported in one pass.\n",
"format_errors = defaultdict(int)\n",
"\n",
"for ex in dataset:\n",
"    if not isinstance(ex, dict):\n",
"        format_errors[\"data_type\"] += 1\n",
"        continue\n",
"\n",
"    messages = ex.get(\"messages\", None)\n",
"    if not messages:\n",
"        format_errors[\"missing_messages_list\"] += 1\n",
"        continue\n",
"    # A non-list value (e.g. a string) would otherwise be iterated item by item\n",
"    # and produce misleading per-message errors.\n",
"    if not isinstance(messages, list):\n",
"        format_errors[\"messages_not_list\"] += 1\n",
"        continue\n",
"\n",
"    for message in messages:\n",
"        # `message.get` below needs a dict, and `in` on a string is a substring\n",
"        # test, so non-dict entries are counted separately and skipped.\n",
"        if not isinstance(message, dict):\n",
"            format_errors[\"message_not_dict\"] += 1\n",
"            continue\n",
"\n",
"        if \"role\" not in message or \"content\" not in message:\n",
"            format_errors[\"message_missing_key\"] += 1\n",
"\n",
"        if any(k not in (\"role\", \"content\", \"name\") for k in message):\n",
"            format_errors[\"message_unrecognized_key\"] += 1\n",
"\n",
"        if message.get(\"role\", None) not in (\"system\", \"user\", \"assistant\"):\n",
"            format_errors[\"unrecognized_role\"] += 1\n",
"\n",
"        content = message.get(\"content\", None)\n",
"        if not content or not isinstance(content, str):\n",
"            format_errors[\"missing_content\"] += 1\n",
"\n",
"    # Flag examples that contain no assistant message at all.\n",
"    if not any(isinstance(m, dict) and m.get(\"role\", None) == \"assistant\" for m in messages):\n",
"        format_errors[\"example_missing_assistant_message\"] += 1\n",
"\n",
"if format_errors:\n",
"    print(\"Found errors:\")\n",
"    for k, v in format_errors.items():\n",
"        print(f\"{k}: {v}\")\n",
"else:\n",
"    print(\"No errors found\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8f4b47b5",
"metadata": {},
"outputs": [],
"source": [
"# Token counting functions\n",
"encoding = tiktoken.get_encoding(\"cl100k_base\")\n",
"\n",
"# not exact!\n",
"# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb\n",
"def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):\n",
"    \"\"\"Approximate the total token count of one chat example.\n",
"\n",
"    Each message costs `tokens_per_message` of overhead plus the encoded length\n",
"    of every value it holds; a \"name\" key costs an extra `tokens_per_name`.\n",
"    A further 3 tokens are added once per example.\n",
"\n",
"    Args:\n",
"        messages: list of message dicts; every value must be a str, since\n",
"            each one is passed to `encoding.encode`.\n",
"        tokens_per_message: fixed overhead added per message.\n",
"        tokens_per_name: extra overhead when a message has a \"name\" key.\n",
"\n",
"    Returns:\n",
"        Estimated token count (int).\n",
"    \"\"\"\n",
"    num_tokens = 0\n",
"    for message in messages:\n",
"        num_tokens += tokens_per_message\n",
"        for key, value in message.items():\n",
"            num_tokens += len(encoding.encode(value))\n",
"            if key == \"name\":\n",
"                num_tokens += tokens_per_name\n",
"    # presumably the reply-priming overhead from the notebook linked above\n",
"    num_tokens += 3\n",
"    return num_tokens\n",
"\n",
"def num_assistant_tokens_from_messages(messages):\n",
"    \"\"\"Count tokens in the content of assistant messages only (no per-message overhead).\"\"\"\n",
"    return sum(\n",
"        len(encoding.encode(message[\"content\"]))\n",
"        for message in messages\n",
"        if message[\"role\"] == \"assistant\"\n",
"    )\n",
"\n",
"def print_distribution(values, name):\n",
"    \"\"\"Print min/max, mean/median and the 5th/95th percentiles of `values`.\"\"\"\n",
"    print(f\"\\n#### Distribution of {name}:\")\n",
"    # min()/max() raise on an empty sequence, so report it instead of crashing.\n",
"    if len(values) == 0:\n",
"        print(\"no values\")\n",
"        return\n",
"    print(f\"min / max: {min(values)}, {max(values)}\")\n",
"    print(f\"mean / median: {np.mean(values)}, {np.median(values)}\")\n",
"    # Quantiles match the printed p5 / p95 label.\n",
"    print(f\"p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "52e58ee4",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num examples missing system message: 1\n",
"Num examples missing user message: 1\n",
"\n",
"#### Distribution of num_messages_per_example:\n",
"min / max: 2, 9\n",
"mean / median: 3.8, 3.0\n",
"p5 / p95: 2.0, 6.6000000000000005\n",
"\n",
"#### Distribution of num_total_tokens_per_example:\n",
"min / max: 26, 8032\n",
"mean / median: 1648.4, 45.0\n",
"p5 / p95: 26.8, 4863.6\n",
"\n",
"#### Distribution of num_assistant_tokens_per_example:\n",
"min / max: 4, 8000\n",
"mean / median: 1610.2, 10.0\n",
"p5 / p95: 6.0, 4811.200000000001\n",
"\n",
"1 examples may be over the 4096 token limit, they will be truncated during fine-tuning\n"
]
}
],
"source": [
"# Warnings and tokens counts\n",
"# Examples longer than this many tokens are reported as too long below.\n",
"MAX_TOKENS_PER_EXAMPLE = 4096\n",
"\n",
"n_missing_system = 0\n",
"n_missing_user = 0\n",
"n_messages = []\n",
"convo_lens = []\n",
"assistant_message_lens = []\n",
"\n",
"for ex in dataset:\n",
"    messages = ex[\"messages\"]\n",
"    if not any(message[\"role\"] == \"system\" for message in messages):\n",
"        n_missing_system += 1\n",
"    if not any(message[\"role\"] == \"user\" for message in messages):\n",
"        n_missing_user += 1\n",
"    n_messages.append(len(messages))\n",
"    convo_lens.append(num_tokens_from_messages(messages))\n",
"    assistant_message_lens.append(num_assistant_tokens_from_messages(messages))\n",
"\n",
"print(\"Num examples missing system message:\", n_missing_system)\n",
"print(\"Num examples missing user message:\", n_missing_user)\n",
"print_distribution(n_messages, \"num_messages_per_example\")\n",
"print_distribution(convo_lens, \"num_total_tokens_per_example\")\n",
"print_distribution(assistant_message_lens, \"num_assistant_tokens_per_example\")\n",
"n_too_long = sum(length > MAX_TOKENS_PER_EXAMPLE for length in convo_lens)\n",
"print(f\"\\n{n_too_long} examples may be over the {MAX_TOKENS_PER_EXAMPLE} token limit, they will be truncated during fine-tuning\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fb95a7ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset has ~4306 tokens that will be charged for during training\n",
"By default, you'll train for 20 epochs on this dataset\n",
"By default, you'll be charged for ~86120 tokens\n",
"See pricing page to estimate total costs\n"
]
}
],
"source": [
"# Pricing and default n_epochs estimate\n",
"MAX_TOKENS_PER_EXAMPLE = 4096\n",
"\n",
"TARGET_EPOCHS = 3\n",
"MIN_TARGET_EXAMPLES = 100\n",
"MAX_TARGET_EXAMPLES = 25000\n",
"MIN_DEFAULT_EPOCHS = 1\n",
"MAX_DEFAULT_EPOCHS = 25\n",
"\n",
"# Start from TARGET_EPOCHS; raise it for small datasets (toward MIN_TARGET_EXAMPLES\n",
"# examples seen in total, capped at MAX_DEFAULT_EPOCHS) and lower it for large ones\n",
"# (toward MAX_TARGET_EXAMPLES, floored at MIN_DEFAULT_EPOCHS).\n",
"# NOTE(review): presumably mirrors the fine-tuning API's default; confirm against current docs.\n",
"n_epochs = TARGET_EPOCHS\n",
"n_train_examples = len(dataset)\n",
"if n_train_examples == 0:\n",
"    # Both scaling branches divide by n_train_examples, so skip them when empty.\n",
"    print(\"Dataset is empty; skipping the n_epochs adjustment\")\n",
"elif n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:\n",
"    n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)\n",
"elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:\n",
"    n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)\n",
"\n",
"# Each example contributes at most MAX_TOKENS_PER_EXAMPLE tokens to the estimate.\n",
"n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)\n",
"print(f\"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training\")\n",
"print(f\"By default, you'll train for {n_epochs} epochs on this dataset\")\n",
"print(f\"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens\")\n",
"print(\"See pricing page to estimate total costs\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5736cd86",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ecbb2f6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}