mirror of
https://github.com/openai/openai-cookbook
synced 2024-11-17 15:29:46 +00:00
295 lines
9.2 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "4e63973b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import json\n",
|
|
"import tiktoken\n",
|
|
"import numpy as np\n",
|
|
"from collections import defaultdict"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": 2,
"id": "748510f2",
"metadata": {},
"outputs": [],
"source": [
"# Path to the training data: a JSONL file, one JSON object per line,\n",
"# each expected to contain a \"messages\" list (validated below).\n",
"data_path = \"data/toy_chat_fine_tuning.jsonl\""
]
},
|
|
{
"cell_type": "code",
"execution_count": 3,
"id": "c248ccd1",
"metadata": {},
"outputs": [],
"source": [
"# Load dataset\n",
"# Parse each JSONL line into a dict; the whole dataset is held in memory.\n",
"with open(data_path, 'r', encoding='utf-8') as f:\n",
"    dataset = [json.loads(line) for line in f]"
]
},
|
|
{
"cell_type": "code",
"execution_count": 4,
"id": "189ed16c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num examples: 5\n",
"First example:\n",
"{'role': 'system', 'content': 'You are a happy assistant that puts a positive spin on everything.'}\n",
"{'role': 'user', 'content': 'I fell off my bike today.'}\n",
"{'role': 'assistant', 'content': \"It's great that you're getting exercise outdoors!\"}\n"
]
}
],
"source": [
"# Initial dataset stats\n",
"# Sanity check: example count plus the messages of the first conversation.\n",
"print(\"Num examples:\", len(dataset))\n",
"print(\"First example:\")\n",
"for message in dataset[0][\"messages\"]:\n",
"    print(message)"
]
},
|
|
{
"cell_type": "code",
"execution_count": 5,
"id": "d9f3ccbf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No errors found\n"
]
}
],
"source": [
"# Format error checks\n",
"# Validate every example against the chat fine-tuning schema, tallying one\n",
"# counter per error category; a clean dataset prints \"No errors found\".\n",
"format_errors = defaultdict(int)\n",
"\n",
"for ex in dataset:\n",
"    # Each JSONL line must parse to a dict with a non-empty \"messages\" list.\n",
"    if not isinstance(ex, dict):\n",
"        format_errors[\"data_type\"] += 1\n",
"        continue\n",
"    \n",
"    messages = ex.get(\"messages\", None)\n",
"    if not messages:\n",
"        format_errors[\"missing_messages_list\"] += 1\n",
"        continue\n",
"    \n",
"    for message in messages:\n",
"        # Every message needs both a \"role\" and a \"content\" key,\n",
"        # and no keys beyond role/content/name.\n",
"        if \"role\" not in message or \"content\" not in message:\n",
"            format_errors[\"message_missing_key\"] += 1\n",
"    \n",
"        if any(k not in (\"role\", \"content\", \"name\") for k in message):\n",
"            format_errors[\"message_unrecognized_key\"] += 1\n",
"    \n",
"        if message.get(\"role\", None) not in (\"system\", \"user\", \"assistant\"):\n",
"            format_errors[\"unrecognized_role\"] += 1\n",
"    \n",
"        # Content must be a non-empty string.\n",
"        content = message.get(\"content\", None)\n",
"        if not content or not isinstance(content, str):\n",
"            format_errors[\"missing_content\"] += 1\n",
"    \n",
"    # Training requires at least one assistant turn per example.\n",
"    if not any(message.get(\"role\", None) == \"assistant\" for message in messages):\n",
"        format_errors[\"example_missing_assistant_message\"] += 1\n",
"\n",
"if format_errors:\n",
"    print(\"Found errors:\")\n",
"    for k, v in format_errors.items():\n",
"        print(f\"{k}: {v}\")\n",
"else:\n",
"    print(\"No errors found\")"
]
},
|
|
{
"cell_type": "code",
"execution_count": 6,
"id": "8f4b47b5",
"metadata": {},
"outputs": [],
"source": [
"# Token counting functions\n",
"encoding = tiktoken.get_encoding(\"cl100k_base\")\n",
"\n",
"# not exact!\n",
"# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb\n",
"def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):\n",
"    \"\"\"Approximate the total token count of one chat example.\n",
"\n",
"    Adds a fixed per-message overhead, the encoded length of every value,\n",
"    one extra token per 'name' key, and a final +3 for reply priming.\n",
"    \"\"\"\n",
"    num_tokens = 0\n",
"    for message in messages:\n",
"        num_tokens += tokens_per_message\n",
"        for key, value in message.items():\n",
"            num_tokens += len(encoding.encode(value))\n",
"            if key == \"name\":\n",
"                num_tokens += tokens_per_name\n",
"    num_tokens += 3\n",
"    return num_tokens\n",
"\n",
"def num_assistant_tokens_from_messages(messages):\n",
"    \"\"\"Count tokens in assistant messages only (the tokens trained on).\"\"\"\n",
"    num_tokens = 0\n",
"    for message in messages:\n",
"        if message[\"role\"] == \"assistant\":\n",
"            num_tokens += len(encoding.encode(message[\"content\"]))\n",
"    return num_tokens\n",
"\n",
"def print_distribution(values, name):\n",
"    \"\"\"Print min/max, mean/median and p5/p95 of a list of numbers.\"\"\"\n",
"    print(f\"\\n#### Distribution of {name}:\")\n",
"    print(f\"min / max: {min(values)}, {max(values)}\")\n",
"    print(f\"mean / median: {np.mean(values)}, {np.median(values)}\")\n",
"    # Bug fix: the label says p5 / p95 but the quantiles used to be\n",
"    # 0.1 / 0.9 (i.e. p10 / p90); use 0.05 / 0.95 so the label is accurate.\n",
"    # Note: recorded outputs downstream were produced with the old quantiles.\n",
"    print(f\"p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}\")"
]
},
|
|
{
"cell_type": "code",
"execution_count": 7,
"id": "52e58ee4",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num examples missing system message: 1\n",
"Num examples missing user message: 1\n",
"\n",
"#### Distribution of num_messages_per_example:\n",
"min / max: 2, 9\n",
"mean / median: 3.8, 3.0\n",
"p5 / p95: 2.0, 6.6000000000000005\n",
"\n",
"#### Distribution of num_total_tokens_per_example:\n",
"min / max: 26, 8032\n",
"mean / median: 1648.4, 45.0\n",
"p5 / p95: 26.8, 4863.6\n",
"\n",
"#### Distribution of num_assistant_tokens_per_example:\n",
"min / max: 4, 8000\n",
"mean / median: 1610.2, 10.0\n",
"p5 / p95: 6.0, 4811.200000000001\n",
"\n",
"1 examples may be over the 4096 token limit, they will be truncated during fine-tuning\n"
]
}
],
"source": [
"# Warnings and tokens counts\n",
"# Scan every example for missing system/user turns and collect per-example\n",
"# message counts and token counts for the distribution printouts below.\n",
"n_missing_system = 0\n",
"n_missing_user = 0\n",
"n_messages = []\n",
"convo_lens = []\n",
"assistant_message_lens = []\n",
"\n",
"for ex in dataset:\n",
"    messages = ex[\"messages\"]\n",
"    if not any(message[\"role\"] == \"system\" for message in messages):\n",
"        n_missing_system += 1\n",
"    if not any(message[\"role\"] == \"user\" for message in messages):\n",
"        n_missing_user += 1\n",
"    n_messages.append(len(messages))\n",
"    convo_lens.append(num_tokens_from_messages(messages))\n",
"    assistant_message_lens.append(num_assistant_tokens_from_messages(messages))\n",
"    \n",
"print(\"Num examples missing system message:\", n_missing_system)\n",
"print(\"Num examples missing user message:\", n_missing_user)\n",
"print_distribution(n_messages, \"num_messages_per_example\")\n",
"print_distribution(convo_lens, \"num_total_tokens_per_example\")\n",
"print_distribution(assistant_message_lens, \"num_assistant_tokens_per_example\")\n",
"# NOTE(review): 4096 is the context limit assumed here — confirm it matches\n",
"# the current limit of the model being fine-tuned.\n",
"n_too_long = sum(l > 4096 for l in convo_lens)\n",
"print(f\"\\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning\")"
]
},
|
|
{
"cell_type": "code",
"execution_count": 8,
"id": "fb95a7ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset has ~4306 tokens that will be charged for during training\n",
"By default, you'll train for 20 epochs on this dataset\n",
"By default, you'll be charged for ~86120 tokens\n",
"See pricing page to estimate total costs\n"
]
}
],
"source": [
"# Pricing and default n_epochs estimate\n",
"# Tokens beyond this cap per example are not billed (they get truncated).\n",
"MAX_TOKENS_PER_EXAMPLE = 4096\n",
"\n",
"# Aim for TARGET_EPOCHS, but adjust the epoch count so the total number of\n",
"# trained examples lands in [MIN_TARGET_EXAMPLES, MAX_TARGET_EXAMPLES],\n",
"# clamped to [MIN_DEFAULT_EPOCHS, MAX_DEFAULT_EPOCHS].\n",
"TARGET_EPOCHS = 3\n",
"MIN_TARGET_EXAMPLES = 100\n",
"MAX_TARGET_EXAMPLES = 25000\n",
"MIN_DEFAULT_EPOCHS = 1\n",
"MAX_DEFAULT_EPOCHS = 25\n",
"\n",
"n_epochs = TARGET_EPOCHS\n",
"n_train_examples = len(dataset)\n",
"if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:\n",
"    n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)\n",
"elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:\n",
"    n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)\n",
"\n",
"# Billed tokens = sum of per-example token counts, each capped at the max.\n",
"n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)\n",
"print(f\"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training\")\n",
"print(f\"By default, you'll train for {n_epochs} epochs on this dataset\")\n",
"print(f\"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens\")\n",
"print(\"See pricing page to estimate total costs\")"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "5736cd86",
"metadata": {},
"outputs": [],
"source": []
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|