Function call cookbook (#1213)

pull/1212/head^2
msingh-openai 4 weeks ago committed by GitHub
parent cf15304c39
commit a5197083bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -51,19 +51,24 @@
},
{
"cell_type": "code",
"execution_count": 2,
"id": "dab872c5",
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:23.563149Z",
"start_time": "2024-05-15T17:45:22.925978Z"
}
},
"source": [
"import json\n",
"from openai import OpenAI\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from termcolor import colored \n",
"\n",
"GPT_MODEL = \"gpt-3.5-turbo-0613\"\n",
"GPT_MODEL = \"gpt-4o\"\n",
"client = OpenAI()"
]
],
"outputs": [],
"execution_count": 2
},
{
"attachments": {},
@ -78,10 +83,13 @@
},
{
"cell_type": "code",
"execution_count": 3,
"id": "745ceec5",
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:28.816345Z",
"start_time": "2024-05-15T17:45:28.814155Z"
}
},
"source": [
"@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))\n",
"def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):\n",
@ -97,14 +105,19 @@
" print(\"Unable to generate ChatCompletion response\")\n",
" print(f\"Exception: {e}\")\n",
" return e\n"
]
],
"outputs": [],
"execution_count": 3
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c4d1c99f",
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:30.003910Z",
"start_time": "2024-05-15T17:45:30.001259Z"
}
},
"source": [
"def pretty_print_conversation(messages):\n",
" role_to_color = {\n",
@ -125,7 +138,9 @@
" print(colored(f\"assistant: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
" elif message[\"role\"] == \"function\":\n",
" print(colored(f\"function ({message['name']}): {message['content']}\\n\", role_to_color[message[\"role\"]]))\n"
]
],
"outputs": [],
"execution_count": 4
},
{
"attachments": {},
@ -140,10 +155,13 @@
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d2e25069",
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:31.794879Z",
"start_time": "2024-05-15T17:45:31.792617Z"
}
},
"source": [
"tools = [\n",
" {\n",
@ -195,7 +213,9 @@
" }\n",
" },\n",
"]"
]
],
"outputs": [],
"execution_count": 5
},
{
"attachments": {},
@ -208,21 +228,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"id": "518d6827",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content='Sure, could you please tell me the location for which you would like to know the weather?', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:35.282310Z",
"start_time": "2024-05-15T17:45:33.861496Z"
}
],
},
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
@ -233,7 +245,20 @@
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
],
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=\"I need to know your location to provide you with the current weather. Could you please specify the city and state (or country) you're in?\", role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 6
},
{
"attachments": {},
@ -246,14 +271,27 @@
},
{
"cell_type": "code",
"execution_count": 7,
"id": "23c42a6e",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:43.553403Z",
"start_time": "2024-05-15T17:45:42.205590Z"
}
},
"source": [
"messages.append({\"role\": \"user\", \"content\": \"I'm in Glasgow, Scotland.\"})\n",
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
],
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_2PArU89L2uf4uIzRqnph4SrN', function=Function(arguments='{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\"\\n}', name='get_current_weather'), type='function')])"
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_Dn2RJJSxzDm49vlVTehseJ0k', function=Function(arguments='{\"location\":\"Glasgow, Scotland\",\"format\":\"celsius\"}', name='get_current_weather'), type='function')])"
]
},
"execution_count": 7,
@ -261,15 +299,7 @@
"output_type": "execute_result"
}
],
"source": [
"messages.append({\"role\": \"user\", \"content\": \"I'm in Glasgow, Scotland.\"})\n",
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
"execution_count": 7
},
{
"attachments": {},
@ -282,21 +312,13 @@
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fa232e54",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content='Sure, I can help you with that. How many days would you like to get the weather forecast for?', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:47.090638Z",
"start_time": "2024-05-15T17:45:46.302475Z"
}
],
},
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
@ -307,7 +329,20 @@
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
],
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content='Please specify the number of days (x) for which you want the weather forecast for Glasgow, Scotland.', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 8
},
{
"attachments": {},
@ -320,14 +355,25 @@
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c7d8a543",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:49.790820Z",
"start_time": "2024-05-15T17:45:48.847752Z"
}
},
"source": [
"messages.append({\"role\": \"user\", \"content\": \"5 days\"})\n",
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"chat_response.choices[0]\n"
],
"outputs": [
{
"data": {
"text/plain": [
"Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_ujD1NwPxzeOSCbgw2NOabOin', function=Function(arguments='{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\",\\n \"num_days\": 5\\n}', name='get_n_day_weather_forecast'), type='function')]), internal_metrics=[{'cached_prompt_tokens': 128, 'total_accepted_tokens': 0, 'total_batched_tokens': 273, 'total_predicted_tokens': 0, 'total_rejected_tokens': 0, 'total_tokens_in_completion': 274, 'cached_embeddings_bytes': 0, 'cached_embeddings_n': 0, 'uncached_embeddings_bytes': 0, 'uncached_embeddings_n': 0, 'fetched_embeddings_bytes': 0, 'fetched_embeddings_n': 0, 'n_evictions': 0, 'sampling_steps': 40, 'sampling_steps_with_predictions': 0, 'batcher_ttft': 0.035738229751586914, 'batcher_initial_queue_time': 0.0007979869842529297}])"
"Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_Yg5ydH9lHhLjjYQyXbNvh004', function=Function(arguments='{\"location\":\"Glasgow, Scotland\",\"format\":\"celsius\",\"num_days\":5}', name='get_n_day_weather_forecast'), type='function')]))"
]
},
"execution_count": 9,
@ -335,13 +381,7 @@
"output_type": "execute_result"
}
],
"source": [
"messages.append({\"role\": \"user\", \"content\": \"5 days\"})\n",
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"chat_response.choices[0]\n"
]
"execution_count": 9
},
{
"attachments": {},
@ -363,21 +403,13 @@
},
{
"cell_type": "code",
"execution_count": 10,
"id": "559371b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_MapM0kaNZBR046H4tAB2UGVu', function=Function(arguments='{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\",\\n \"num_days\": 1\\n}', name='get_n_day_weather_forecast'), type='function')])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:54.194255Z",
"start_time": "2024-05-15T17:45:52.975746Z"
}
],
},
"source": [
"# in this cell we force the model to use get_n_day_weather_forecast\n",
"messages = []\n",
@ -387,25 +419,30 @@
" messages, tools=tools, tool_choice={\"type\": \"function\", \"function\": {\"name\": \"get_n_day_weather_forecast\"}}\n",
")\n",
"chat_response.choices[0].message"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a7ab0f58",
"metadata": {},
],
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_z8ijGSoMLS7xcaU7MjLmpRL8', function=Function(arguments='{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}', name='get_current_weather'), type='function')])"
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_aP8ZEtGcyseL0btTMYxTCKbk', function=Function(arguments='{\"location\":\"Toronto, Canada\",\"format\":\"celsius\",\"num_days\":1}', name='get_n_day_weather_forecast'), type='function')])"
]
},
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 10
},
{
"cell_type": "code",
"id": "a7ab0f58",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:56.841233Z",
"start_time": "2024-05-15T17:45:55.433397Z"
}
},
"source": [
"# if we don't force the model to use get_n_day_weather_forecast it may not\n",
"messages = []\n",
@ -415,7 +452,20 @@
" messages, tools=tools\n",
")\n",
"chat_response.choices[0].message"
]
],
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_5HqCVRaAoBuU0uTlO3MUwaWX', function=Function(arguments='{\"location\": \"Toronto, Canada\", \"format\": \"celsius\"}', name='get_current_weather'), type='function'), ChatCompletionMessageToolCall(id='call_C9kCha28xHEsxYl4PxZ1l5LI', function=Function(arguments='{\"location\": \"Toronto, Canada\", \"format\": \"celsius\", \"num_days\": 3}', name='get_n_day_weather_forecast'), type='function')])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 11
},
{
"attachments": {},
@ -428,14 +478,27 @@
},
{
"cell_type": "code",
"execution_count": 12,
"id": "acfe54e6",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:45:59.800346Z",
"start_time": "2024-05-15T17:45:59.289603Z"
}
},
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"Give me the current weather (use Celcius) for Toronto, Canada.\"})\n",
"chat_response = chat_completion_request(\n",
" messages, tools=tools, tool_choice=\"none\"\n",
")\n",
"chat_response.choices[0].message\n"
],
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content='{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}', role='assistant', function_call=None, tool_calls=None)"
"ChatCompletionMessage(content=\"I'll get the current weather for Toronto, Canada in Celsius.\", role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 12,
@ -443,15 +506,7 @@
"output_type": "execute_result"
}
],
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"Give me the current weather (use Celcius) for Toronto, Canada.\"})\n",
"chat_response = chat_completion_request(\n",
" messages, tools=tools, tool_choice=\"none\"\n",
")\n",
"chat_response.choices[0].message\n"
]
"execution_count": 12
},
{
"cell_type": "markdown",
@ -460,20 +515,35 @@
"source": [
"### Parallel Function Calling\n",
"\n",
"Newer models like gpt-4-1106-preview or gpt-3.5-turbo-1106 can call multiple functions in one turn."
"Newer models such as gpt-4o or gpt-3.5-turbo can call multiple functions in one turn."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "380eeb68",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:46:04.048553Z",
"start_time": "2024-05-15T17:46:01.273501Z"
}
},
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"what is the weather going to be like in San Francisco and Glasgow over the next 4 days\"})\n",
"chat_response = chat_completion_request(\n",
" messages, tools=tools, model=GPT_MODEL\n",
")\n",
"\n",
"assistant_message = chat_response.choices[0].message.tool_calls\n",
"assistant_message"
],
"outputs": [
{
"data": {
"text/plain": [
"[ChatCompletionMessageToolCall(id='call_8BlkS2yvbkkpL3V1Yxc6zR6u', function=Function(arguments='{\"location\": \"San Francisco, CA\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function'),\n",
" ChatCompletionMessageToolCall(id='call_vSZMy3f24wb3vtNXucpFfAbG', function=Function(arguments='{\"location\": \"Glasgow\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function')]"
"[ChatCompletionMessageToolCall(id='call_pFdKcCu5taDTtOOfX14vEDRp', function=Function(arguments='{\"location\": \"San Francisco, CA\", \"format\": \"fahrenheit\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function'),\n",
" ChatCompletionMessageToolCall(id='call_Veeyp2hYJOKp0wT7ODxmTjaS', function=Function(arguments='{\"location\": \"Glasgow, UK\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function')]"
]
},
"execution_count": 13,
@ -481,17 +551,7 @@
"output_type": "execute_result"
}
],
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"what is the weather going to be like in San Francisco and Glasgow over the next 4 days\"})\n",
"chat_response = chat_completion_request(\n",
" messages, tools=tools, model='gpt-3.5-turbo-1106'\n",
")\n",
"\n",
"assistant_message = chat_response.choices[0].message.tool_calls\n",
"assistant_message"
]
"execution_count": 13
},
{
"attachments": {},
@ -519,9 +579,19 @@
},
{
"cell_type": "code",
"execution_count": 14,
"id": "30f6b60e",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:46:07.270851Z",
"start_time": "2024-05-15T17:46:07.265545Z"
}
},
"source": [
"import sqlite3\n",
"\n",
"conn = sqlite3.connect(\"data/Chinook.db\")\n",
"print(\"Opened database successfully\")"
],
"outputs": [
{
"name": "stdout",
@ -531,19 +601,17 @@
]
}
],
"source": [
"import sqlite3\n",
"\n",
"conn = sqlite3.connect(\"data/Chinook.db\")\n",
"print(\"Opened database successfully\")"
]
"execution_count": 14
},
{
"cell_type": "code",
"execution_count": 15,
"id": "abec0214",
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:46:09.345308Z",
"start_time": "2024-05-15T17:46:09.342998Z"
}
},
"source": [
"def get_table_names(conn):\n",
" \"\"\"Return a list of table names.\"\"\"\n",
@ -570,7 +638,9 @@
" columns_names = get_column_names(conn, table_name)\n",
" table_dicts.append({\"table_name\": table_name, \"column_names\": columns_names})\n",
" return table_dicts\n"
]
],
"outputs": [],
"execution_count": 15
},
{
"attachments": {},
@ -583,10 +653,13 @@
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0c0104cd",
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:46:11.303746Z",
"start_time": "2024-05-15T17:46:11.301210Z"
}
},
"source": [
"database_schema_dict = get_database_info(conn)\n",
"database_schema_string = \"\\n\".join(\n",
@ -595,7 +668,9 @@
" for table in database_schema_dict\n",
" ]\n",
")"
]
],
"outputs": [],
"execution_count": 16
},
{
"attachments": {},
@ -608,10 +683,13 @@
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0258813a",
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:46:16.569530Z",
"start_time": "2024-05-15T17:46:16.567801Z"
}
},
"source": [
"tools = [\n",
" {\n",
@ -637,7 +715,9 @@
" }\n",
" }\n",
"]"
]
],
"outputs": [],
"execution_count": 17
},
{
"attachments": {},
@ -652,10 +732,13 @@
},
{
"cell_type": "code",
"execution_count": 18,
"id": "65585e74",
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:46:19.198723Z",
"start_time": "2024-05-15T17:46:19.197043Z"
}
},
"source": [
"def ask_database(conn, query):\n",
" \"\"\"Function to query SQLite database with a provided SQL query.\"\"\"\n",
@ -663,92 +746,117 @@
" results = str(conn.execute(query).fetchall())\n",
" except Exception as e:\n",
" results = f\"query failed with error: {e}\"\n",
" return results\n",
"\n",
"def execute_function_call(message):\n",
" if message.tool_calls[0].function.name == \"ask_database\":\n",
" query = json.loads(message.tool_calls[0].function.arguments)[\"query\"]\n",
" results = ask_database(conn, query)\n",
" else:\n",
" results = f\"Error: function {message.tool_calls[0].function.name} does not exist\"\n",
" return results"
],
"outputs": [],
"execution_count": 18
},
{
"cell_type": "markdown",
"id": "8f6885e9f0af5c40",
"metadata": {},
"source": [
"##### Steps to invoke a function call using Chat Completions API: \n",
"\n",
"**Step 1**: Prompt the model with content that may result in model selecting a tool to use. The description of the tools such as a function names and signature is defined in the 'Tools' list and passed to the model in API call. If selected, the function name and parameters are included in the response.<br>\n",
" \n",
"**Step 2**: Check programmatically if model wanted to call a function. If true, proceed to step 3. <br> \n",
"**Step 3**: Extract the function name and parameters from response, call the function with parameters. Append the result to messages. <br> \n",
"**Step 4**: Invoke the chat completions API with the message list to get the response. "
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "38c55083",
"metadata": {},
"id": "e8b7cb9cdc7a7616",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-15T17:46:25.725379Z",
"start_time": "2024-05-15T17:46:24.255505Z"
}
},
"source": [
"# Step #1: Prompt with content that may result in function call. In this case the model can identify the information requested by the user is potentially available in the database schema passed to the model in Tools description. \n",
"messages = [{\n",
" \"role\":\"user\", \n",
" \"content\": \"What is the name of the album with the most tracks?\"\n",
"}]\n",
"\n",
"response = client.chat.completions.create(\n",
" model='gpt-4o', \n",
" messages=messages, \n",
" tools= tools, \n",
" tool_choice=\"auto\"\n",
")\n",
"\n",
"# Append the message to messages list\n",
"response_message = response.choices[0].message \n",
"messages.append(response_message)\n",
"\n",
"print(response_message)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31msystem: Answer user questions by generating SQL queries against the Chinook Music Database.\n",
"\u001b[0m\n",
"\u001b[32muser: Hi, who are the top 5 artists by number of tracks?\n",
"\u001b[0m\n",
"\u001b[34massistant: Function(arguments='{\\n \"query\": \"SELECT Artist.Name, COUNT(Track.TrackId) AS TrackCount FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.ArtistId ORDER BY TrackCount DESC LIMIT 5;\"\\n}', name='ask_database')\n",
"\u001b[0m\n",
"\u001b[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n",
"\u001b[0m\n"
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_bXMf903yO78sdsMZble4yu90', function=Function(arguments='{\"query\":\"SELECT A.Title, COUNT(T.TrackId) AS TrackCount FROM Album A JOIN Track T ON A.AlbumId = T.AlbumId GROUP BY A.Title ORDER BY TrackCount DESC LIMIT 1;\"}', name='ask_database'), type='function')])\n"
]
}
],
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Answer user questions by generating SQL queries against the Chinook Music Database.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"Hi, who are the top 5 artists by number of tracks?\"})\n",
"chat_response = chat_completion_request(messages, tools)\n",
"assistant_message = chat_response.choices[0].message\n",
"assistant_message.content = str(assistant_message.tool_calls[0].function)\n",
"messages.append({\"role\": assistant_message.role, \"content\": assistant_message.content})\n",
"if assistant_message.tool_calls:\n",
" results = execute_function_call(assistant_message)\n",
" messages.append({\"role\": \"function\", \"tool_call_id\": assistant_message.tool_calls[0].id, \"name\": assistant_message.tool_calls[0].function.name, \"content\": results})\n",
"pretty_print_conversation(messages)"
]
"execution_count": 19
},
{
"cell_type": "code",
"execution_count": 20,
"id": "710481dc",
"id": "351c39def3417776",
"metadata": {
"scrolled": true
"ExecuteTime": {
"end_time": "2024-05-15T17:46:30.346444Z",
"start_time": "2024-05-15T17:46:29.699046Z"
}
},
"source": [
"# Step 2: determine if the response from the model includes a tool call. \n",
"tool_calls = response_message.tool_calls\n",
"if tool_calls:\n",
" # If true the model will return the name of the tool / function to call and the argument(s) \n",
" tool_call_id = tool_calls[0].id\n",
" tool_function_name = tool_calls[0].function.name\n",
" tool_query_string = eval(tool_calls[0].function.arguments)['query']\n",
" \n",
" # Step 3: Call the function and retrieve results. Append the results to the messages list. \n",
" if tool_function_name == 'ask_database':\n",
" results = ask_database(conn, tool_query_string)\n",
" \n",
" messages.append({\n",
" \"role\":\"tool\", \n",
" \"tool_call_id\":tool_call_id, \n",
" \"name\": tool_function_name, \n",
" \"content\":results\n",
" })\n",
" \n",
" # Step 4: Invoke the chat completions API with the function response appended to the messages list\n",
" # Note that messages with role 'tool' must be a response to a preceding message with 'tool_calls'\n",
" model_response_with_function_call = client.chat.completions.create(\n",
" model=\"gpt-4o\",\n",
" messages=messages,\n",
" ) # get a new response from the model where it can see the function response\n",
" print(model_response_with_function_call.choices[0].message.content)\n",
" else: \n",
" print(f\"Error: function {tool_function_name} does not exist\")\n",
"else: \n",
" # Model did not identify a function to call, result can be returned to the user \n",
" print(response_message.content) "
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31msystem: Answer user questions by generating SQL queries against the Chinook Music Database.\n",
"\u001b[0m\n",
"\u001b[32muser: Hi, who are the top 5 artists by number of tracks?\n",
"\u001b[0m\n",
"\u001b[34massistant: Function(arguments='{\\n \"query\": \"SELECT Artist.Name, COUNT(Track.TrackId) AS TrackCount FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.ArtistId ORDER BY TrackCount DESC LIMIT 5;\"\\n}', name='ask_database')\n",
"\u001b[0m\n",
"\u001b[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n",
"\u001b[0m\n",
"\u001b[32muser: What is the name of the album with the most tracks?\n",
"\u001b[0m\n",
"\u001b[34massistant: Function(arguments='{\\n \"query\": \"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.AlbumId ORDER BY TrackCount DESC LIMIT 1;\"\\n}', name='ask_database')\n",
"\u001b[0m\n",
"\u001b[35mfunction (ask_database): [('Greatest Hits', 57)]\n",
"\u001b[0m\n"
"The album with the most tracks is titled \"Greatest Hits,\" and it contains 57 tracks.\n"
]
}
],
"source": [
"messages.append({\"role\": \"user\", \"content\": \"What is the name of the album with the most tracks?\"})\n",
"chat_response = chat_completion_request(messages, tools)\n",
"assistant_message = chat_response.choices[0].message\n",
"assistant_message.content = str(assistant_message.tool_calls[0].function)\n",
"messages.append({\"role\": assistant_message.role, \"content\": assistant_message.content})\n",
"if assistant_message.tool_calls:\n",
" results = execute_function_call(assistant_message)\n",
" messages.append({\"role\": \"function\", \"tool_call_id\": assistant_message.tool_calls[0].id, \"name\": assistant_message.tool_calls[0].function.name, \"content\": results})\n",
"pretty_print_conversation(messages)"
]
"execution_count": 20
},
{
"attachments": {},

Loading…
Cancel
Save