Update FT'ing for function calling notebook to match new python SDK (#866)

pull/890/head
jhills20 6 months ago committed by GitHub
parent bcbb505a4a
commit 16e5a1c2f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -82,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -90,12 +90,12 @@
"import numpy as np\n",
"import json\n",
"import os\n",
"import openai\n",
"from openai import OpenAI\n",
"import itertools\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from typing import Any, Dict, List, Generator\n",
"import ast\n",
"openai.api_key = os.getenv('OPENAI_API_KEY')\n"
"client = OpenAI()\n"
]
},
{
@ -114,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -137,7 +137,7 @@
" if functions:\n",
" params['functions'] = functions\n",
"\n",
" completion = openai.ChatCompletion.create(**params)\n",
" completion = client.chat.completions.create(**params)\n",
" return completion.choices[0].message\n"
]
},
@ -159,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -177,7 +177,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -430,7 +430,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@ -442,9 +442,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Land the drone at the home base\n",
"FunctionCall(arguments='{\\n \"location\": \"home_base\"\\n}', name='land_drone') \n",
"\n",
"Take off the drone to 50 meters\n",
"FunctionCall(arguments='{\\n \"altitude\": 50\\n}', name='takeoff_drone') \n",
"\n",
"change speed to 15 kilometers per hour\n",
"FunctionCall(arguments='{\\n \"speed\": 15\\n}', name='set_drone_speed') \n",
"\n",
"turn into an elephant!\n",
"FunctionCall(arguments='{}', name='reject_request') \n",
"\n"
]
}
],
"source": [
"for prompt in straightforward_prompts:\n",
" messages = []\n",
@ -464,7 +483,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@ -477,9 +496,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Play pre-recorded audio message\n",
"FunctionCall(arguments='{}', name='reject_request')\n",
"\n",
"\n",
"Initiate live-streaming on social media\n",
"FunctionCall(arguments='{\\n\"mode\": \"video\",\\n\"duration\": 0\\n}', name='control_camera')\n",
"\n",
"\n",
"Scan environment for heat signatures\n",
"None\n",
"\n",
"\n",
"Enable stealth mode\n",
"FunctionCall(arguments='{\\n \"mode\": \"off\"\\n}', name='set_drone_lighting')\n",
"\n",
"\n",
"Change drone's paint job color\n",
"FunctionCall(arguments='{\\n \"pattern\": \"solid\",\\n \"color\": \"blue\"\\n}', name='configure_led_display')\n",
"\n",
"\n"
]
}
],
"source": [
"for prompt in challenging_prompts:\n",
" messages = []\n",
@ -537,7 +583,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -557,7 +603,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -664,7 +710,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@ -734,7 +780,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@ -784,7 +830,7 @@
" request_prompt = COMMAND_GENERATION_PROMPT.format(invocation=invocation)\n",
"\n",
" messages = [{\"role\": \"user\", \"content\": f\"{request_prompt}\"}]\n",
" completion = get_chat_completion(messages,temperature=0.8)\n",
" completion = get_chat_completion(messages,temperature=0.8).content\n",
" command_dict = {\n",
" \"Input\": invocation,\n",
" \"Prompt\": completion\n",
@ -926,13 +972,13 @@
"outputs": [],
"source": [
"if __name__ == \"__main__\":\n",
" file = openai.File.create(\n",
" file = client.files.create(\n",
" file=open(training_file, \"rb\"),\n",
" purpose=\"fine-tune\",\n",
" )\n",
" file_id = file.id\n",
" print(file_id)\n",
" ft = openai.FineTuningJob.create(\n",
" ft = client.fine_tuning.jobs.create(\n",
" # model=\"gpt-4-0613\",\n",
" model=\"gpt-3.5-turbo\",\n",
" training_file=file_id,\n",
@ -980,7 +1026,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Conclustion"
"### Conclusion"
]
},
{

Loading…
Cancel
Save