Fix finetuning example

pull/1077/head
Christian Muertz 2 years ago
parent 26ab363c26
commit 83f359aa95

@ -37,7 +37,7 @@
"openai.api_base = '' # Please add your endpoint here\n", "openai.api_base = '' # Please add your endpoint here\n",
"\n", "\n",
"openai.api_type = 'azure'\n", "openai.api_type = 'azure'\n",
"openai.api_version = '2022-03-01-preview' # this may change in the future" "openai.api_version = '2022-12-01'"
] ]
}, },
{ {
@ -89,9 +89,9 @@
"training_file_name = 'training.jsonl'\n", "training_file_name = 'training.jsonl'\n",
"validation_file_name = 'validation.jsonl'\n", "validation_file_name = 'validation.jsonl'\n",
"\n", "\n",
"sample_data = [{\"prompt\": \"When I go to the store, I want an\", \"completion\": \"apple\"},\n", "sample_data = [{\"prompt\": \"When I go to the store, I want an\", \"completion\": \"apple.\"},\n",
" {\"prompt\": \"When I go to work, I want a\", \"completion\": \"coffe\"},\n", " {\"prompt\": \"When I go to work, I want a\", \"completion\": \"coffee.\"},\n",
" {\"prompt\": \"When I go home, I want a\", \"completion\": \"soda\"}]\n", " {\"prompt\": \"When I go home, I want a\", \"completion\": \"soda.\"}]\n",
"\n", "\n",
"print(f'Generating the training file: {training_file_name}')\n", "print(f'Generating the training file: {training_file_name}')\n",
"with open(training_file_name, 'w') as training_file:\n", "with open(training_file_name, 'w') as training_file:\n",
@ -141,7 +141,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"print(f'Deleting already uploaded files.')\n", "print(f'Deleting already uploaded files...')\n",
"for id in results:\n", "for id in results:\n",
" openai.File.delete(sid = id)\n" " openai.File.delete(sid = id)\n"
] ]
@ -197,7 +197,7 @@
"source": [ "source": [
"print(f'Downloading training file: {training_id}')\n", "print(f'Downloading training file: {training_id}')\n",
"result = openai.File.download(training_id)\n", "result = openai.File.download(training_id)\n",
"print(result)" "print(result.decode('utf-8'))"
] ]
}, },
{ {
@ -225,9 +225,12 @@
"create_args = {\n", "create_args = {\n",
" \"training_file\": training_id,\n", " \"training_file\": training_id,\n",
" \"validation_file\": validation_id,\n", " \"validation_file\": validation_id,\n",
" \"model\": \"curie\",\n", " \"model\": \"babbage\",\n",
" \"compute_classification_metrics\": True,\n", " \"compute_classification_metrics\": True,\n",
" \"classification_n_classes\": 3\n", " \"classification_n_classes\": 3,\n",
" \"n_epochs\": 20,\n",
" \"batch_size\": 3,\n",
" \"learning_rate_multiplier\": 0.3\n",
"}\n", "}\n",
"resp = openai.FineTune.create(**create_args)\n", "resp = openai.FineTune.create(**create_args)\n",
"job_id = resp[\"id\"]\n", "job_id = resp[\"id\"]\n",
@ -258,7 +261,7 @@
" print(f\"Stream interrupted. Job is still {status}.\")\n", " print(f\"Stream interrupted. Job is still {status}.\")\n",
" return\n", " return\n",
"\n", "\n",
"print('Streaming events for the fine-tuning job: {job_id}')\n", "print(f'Streaming events for the fine-tuning job: {job_id}')\n",
"signal.signal(signal.SIGINT, signal_handler)\n", "signal.signal(signal.SIGINT, signal_handler)\n",
"\n", "\n",
"events = openai.FineTune.stream_events(job_id)\n", "events = openai.FineTune.stream_events(job_id)\n",
@ -296,7 +299,7 @@
"\n", "\n",
"print('Checking other finetune jobs in the subscription.')\n", "print('Checking other finetune jobs in the subscription.')\n",
"result = openai.FineTune.list()\n", "result = openai.FineTune.list()\n",
"print(f'Found {len(result)} finetune jobs.')" "print(f'Found {len(result.data)} finetune jobs.')"
] ]
}, },
{ {
@ -413,10 +416,10 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"print('Sending a test completion job')\n", "print('Sending a test completion job')\n",
"start_phrase = 'When I go to the store, I want a'\n", "start_phrase = 'When I go home, I want a'\n",
"response = openai.Completion.create(deployment_id=deployment_id, prompt=start_phrase, max_tokens=4)\n", "response = openai.Completion.create(deployment_id=deployment_id, prompt=start_phrase, temperature=0, stop=\".\")\n",
"text = response['choices'][0]['text'].replace('\\n', '').replace(' .', '.').strip()\n", "text = response['choices'][0]['text'].replace('\\n', '').replace(' .', '.').strip()\n",
"print(f'\"{start_phrase} {text}\"')\n" "print(f'\"{start_phrase} {text}.\"')"
] ]
}, },
{ {

Loading…
Cancel
Save