{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "3e67f200",
"metadata": {},
"source": [
"# How to call functions with chat models\n",
"\n",
"This notebook covers how to use the Chat Completions API in combination with external functions to extend the capabilities of GPT models.\n",
"\n",
"`tools` is an optional parameter in the Chat Completions API which can be used to provide function specifications. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications. Note that the API will not actually execute any function calls. It is up to developers to execute function calls using model outputs.\n",
"\n",
"If the `tools` parameter is provided, then by default the model will decide when it is appropriate to use one of the functions. The API can be forced to use a specific function by setting the `tool_choice` parameter to `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}`. The API can also be forced to not use any function by setting the `tool_choice` parameter to `\"none\"`. If a function is used, the response will contain `\"finish_reason\": \"tool_calls\"`, as well as a `tool_calls` object that has the name of the function and the generated function arguments.\n",
"\n",
"### Overview\n",
"\n",
"This notebook contains the following two sections:\n",
"\n",
"- **How to generate function arguments:** Specify a set of functions and use the API to generate function arguments.\n",
"- **How to call functions with model generated arguments:** Close the loop by actually executing functions with model generated arguments."
]
},
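{
"attachments": {},
"cell_type": "markdown",
"id": "tool-call-shape-note",
"metadata": {},
"source": [
"As a quick preview before any code, a tool call comes back on the assistant message rather than in `content`. Roughly (an illustrative sketch, not a literal API response; the exact objects appear in the cell outputs below):\n",
"\n",
"```python\n",
"response.choices[0].finish_reason         # \"tool_calls\"\n",
"message = response.choices[0].message\n",
"message.tool_calls[0].function.name       # e.g. \"get_current_weather\"\n",
"message.tool_calls[0].function.arguments  # JSON string of generated arguments\n",
"```"
]
},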
{
"attachments": {},
"cell_type": "markdown",
"id": "64c85e26",
"metadata": {},
"source": [
"## How to generate function arguments"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "80e71f33",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"!pip install scipy --quiet\n",
"!pip install tenacity --quiet\n",
"!pip install tiktoken --quiet\n",
"!pip install termcolor --quiet\n",
"!pip install openai --quiet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "dab872c5",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:41:58.148850Z",
"start_time": "2024-07-12T22:41:58.133412Z"
}
},
"outputs": [],
"source": [
"import json\n",
"from openai import OpenAI\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from termcolor import colored\n",
"\n",
"GPT_MODEL = \"gpt-4o\"\n",
"client = OpenAI()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "69ee6a93",
"metadata": {},
"source": [
"### Utilities\n",
"\n",
"First let's define a few utilities for making calls to the Chat Completions API and for keeping track of the conversation state."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "745ceec5",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:41:59.531820Z",
"start_time": "2024-07-12T22:41:59.529870Z"
}
},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))\n",
"def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):\n",
"    try:\n",
"        response = client.chat.completions.create(\n",
"            model=model,\n",
"            messages=messages,\n",
"            tools=tools,\n",
"            tool_choice=tool_choice,\n",
"        )\n",
"        return response\n",
"    except Exception as e:\n",
"        print(\"Unable to generate ChatCompletion response\")\n",
"        print(f\"Exception: {e}\")\n",
"        return e\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c4d1c99f",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:00.463896Z",
"start_time": "2024-07-12T22:42:00.461258Z"
}
},
"outputs": [],
"source": [
"def pretty_print_conversation(messages):\n",
"    role_to_color = {\n",
"        \"system\": \"red\",\n",
"        \"user\": \"green\",\n",
"        \"assistant\": \"blue\",\n",
"        \"tool\": \"magenta\",\n",
"    }\n",
"\n",
"    for message in messages:\n",
"        if message[\"role\"] == \"system\":\n",
"            print(colored(f\"system: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
"        elif message[\"role\"] == \"user\":\n",
"            print(colored(f\"user: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
"        elif message[\"role\"] == \"assistant\" and message.get(\"tool_calls\"):\n",
"            print(colored(f\"assistant: {message['tool_calls']}\\n\", role_to_color[message[\"role\"]]))\n",
"        elif message[\"role\"] == \"assistant\" and not message.get(\"tool_calls\"):\n",
"            print(colored(f\"assistant: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
"        elif message[\"role\"] == \"tool\":\n",
"            print(colored(f\"tool ({message['name']}): {message['content']}\\n\", role_to_color[message[\"role\"]]))\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "29d4e02b",
"metadata": {},
"source": [
"### Basic concepts\n",
"\n",
"Let's create some function specifications to interface with a hypothetical weather API. We'll pass these function specifications to the Chat Completions API in order to generate function arguments that adhere to the specification."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d2e25069",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:01.676606Z",
"start_time": "2024-07-12T22:42:01.674348Z"
}
},
"outputs": [],
"source": [
"tools = [\n",
"    {\n",
"        \"type\": \"function\",\n",
"        \"function\": {\n",
"            \"name\": \"get_current_weather\",\n",
"            \"description\": \"Get the current weather\",\n",
"            \"parameters\": {\n",
"                \"type\": \"object\",\n",
"                \"properties\": {\n",
"                    \"location\": {\n",
"                        \"type\": \"string\",\n",
"                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
"                    },\n",
"                    \"format\": {\n",
"                        \"type\": \"string\",\n",
"                        \"enum\": [\"celsius\", \"fahrenheit\"],\n",
"                        \"description\": \"The temperature unit to use. Infer this from the user's location.\",\n",
"                    },\n",
"                },\n",
"                \"required\": [\"location\", \"format\"],\n",
"            },\n",
"        }\n",
"    },\n",
"    {\n",
"        \"type\": \"function\",\n",
"        \"function\": {\n",
"            \"name\": \"get_n_day_weather_forecast\",\n",
"            \"description\": \"Get an N-day weather forecast\",\n",
"            \"parameters\": {\n",
"                \"type\": \"object\",\n",
"                \"properties\": {\n",
"                    \"location\": {\n",
"                        \"type\": \"string\",\n",
"                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
"                    },\n",
"                    \"format\": {\n",
"                        \"type\": \"string\",\n",
"                        \"enum\": [\"celsius\", \"fahrenheit\"],\n",
"                        \"description\": \"The temperature unit to use. Infer this from the user's location.\",\n",
"                    },\n",
"                    \"num_days\": {\n",
"                        \"type\": \"integer\",\n",
"                        \"description\": \"The number of days to forecast\",\n",
"                    }\n",
"                },\n",
"                \"required\": [\"location\", \"format\", \"num_days\"]\n",
"            },\n",
"        }\n",
"    },\n",
"]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "bfc39899",
"metadata": {},
"source": [
"If we prompt the model about the current weather, it will ask a clarifying question rather than guessing the missing arguments."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "518d6827",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:03.726604Z",
"start_time": "2024-07-12T22:42:03.154689Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content='Sure, can you please provide me with the name of your city and state?', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"What's the weather like today\"})\n",
"chat_response = chat_completion_request(\n",
"    messages, tools=tools\n",
")\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "4c999375",
"metadata": {},
"source": [
"Once we provide the missing information, it will generate the appropriate function arguments for us."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "23c42a6e",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:05.778263Z",
"start_time": "2024-07-12T22:42:05.277346Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_xb7QwwNnx90LkmhtlW0YrgP2', function=Function(arguments='{\"location\":\"Glasgow, Scotland\",\"format\":\"celsius\"}', name='get_current_weather'), type='function')])"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages.append({\"role\": \"user\", \"content\": \"I'm in Glasgow, Scotland.\"})\n",
"chat_response = chat_completion_request(\n",
"    messages, tools=tools\n",
")\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c14d4762",
"metadata": {},
"source": [
"By prompting it differently, we can get it to target the other function we've told it about."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "fa232e54",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:07.575820Z",
"start_time": "2024-07-12T22:42:07.018764Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content='To provide you with the weather forecast for Glasgow, Scotland, could you please specify the number of days you would like the forecast for?', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"what is the weather going to be like in Glasgow, Scotland over the next x days\"})\n",
"chat_response = chat_completion_request(\n",
"    messages, tools=tools\n",
")\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6172ddac",
"metadata": {},
"source": [
"Once again, the model is asking us for clarification because it doesn't have enough information yet. In this case it already knows the location for the forecast, but it needs to know how many days are required in the forecast."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c7d8a543",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:09.587530Z",
"start_time": "2024-07-12T22:42:08.666795Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_34PBraFdNN6KR95uD5rHF8Aw', function=Function(arguments='{\"location\":\"Glasgow, Scotland\",\"format\":\"celsius\",\"num_days\":5}', name='get_n_day_weather_forecast'), type='function')]))"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages.append({\"role\": \"user\", \"content\": \"5 days\"})\n",
"chat_response = chat_completion_request(\n",
"    messages, tools=tools\n",
")\n",
"chat_response.choices[0]\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "4b758a0a",
"metadata": {},
"source": [
"#### Forcing the use of specific functions or no function"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "412f79ba",
"metadata": {},
"source": [
"We can force the model to use a specific function, for example `get_n_day_weather_forecast`, by using the `tool_choice` argument. By doing so, we force the model to make assumptions about how to use it."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "559371b7",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:12.216712Z",
"start_time": "2024-07-12T22:42:11.714246Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_FImGxrLowOAOszCaaQqQWmEN', function=Function(arguments='{\"location\":\"Toronto, Canada\",\"format\":\"celsius\",\"num_days\":7}', name='get_n_day_weather_forecast'), type='function')])"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# in this cell we force the model to use get_n_day_weather_forecast\n",
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"Give me a weather report for Toronto, Canada.\"})\n",
"chat_response = chat_completion_request(\n",
"    messages, tools=tools, tool_choice={\"type\": \"function\", \"function\": {\"name\": \"get_n_day_weather_forecast\"}}\n",
")\n",
"chat_response.choices[0].message"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a7ab0f58",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:14.264601Z",
"start_time": "2024-07-12T22:42:13.001306Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_n84kYFqjNFDPNGDEnjnrd2KC', function=Function(arguments='{\"location\": \"Toronto, Canada\", \"format\": \"celsius\"}', name='get_current_weather'), type='function'), ChatCompletionMessageToolCall(id='call_AEs3AFhJc9pn42hWSbHTaIDh', function=Function(arguments='{\"location\": \"Toronto, Canada\", \"format\": \"celsius\", \"num_days\": 3}', name='get_n_day_weather_forecast'), type='function')])"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# if we don't force the model to use get_n_day_weather_forecast it may not\n",
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"Give me a weather report for Toronto, Canada.\"})\n",
"chat_response = chat_completion_request(\n",
"    messages, tools=tools\n",
")\n",
"chat_response.choices[0].message"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3bd70e48",
"metadata": {},
"source": [
"We can also force the model to not use a function at all by setting `tool_choice` to `\"none\"`. By doing so, we prevent it from producing a function call."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "acfe54e6",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:16.928643Z",
"start_time": "2024-07-12T22:42:16.295006Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=\"Sure, I'll get the current weather for Toronto, Canada in Celsius.\", role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"Give me the current weather (use Celsius) for Toronto, Canada.\"})\n",
"chat_response = chat_completion_request(\n",
"    messages, tools=tools, tool_choice=\"none\"\n",
")\n",
"chat_response.choices[0].message\n"
]
},
{
"cell_type": "markdown",
"id": "b616353b",
"metadata": {},
"source": [
"### Parallel Function Calling\n",
"\n",
"Newer models such as gpt-4o or gpt-3.5-turbo can call multiple functions in one turn."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "380eeb68",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:18.988762Z",
"start_time": "2024-07-12T22:42:18.041914Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[ChatCompletionMessageToolCall(id='call_ObhLiJwaHwc3U1KyB4Pdpx8y', function=Function(arguments='{\"location\": \"San Francisco, CA\", \"format\": \"fahrenheit\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function'),\n",
" ChatCompletionMessageToolCall(id='call_5YRgeZ0MGBMFKE3hZiLouwg7', function=Function(arguments='{\"location\": \"Glasgow, SCT\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function')]"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = []\n",
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"what is the weather going to be like in San Francisco and Glasgow over the next 4 days\"})\n",
"chat_response = chat_completion_request(\n",
"    messages, tools=tools, model=GPT_MODEL\n",
")\n",
"\n",
"assistant_message = chat_response.choices[0].message.tool_calls\n",
"assistant_message"
]
},
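{
"attachments": {},
"cell_type": "markdown",
"id": "parallel-tool-calls-sketch",
"metadata": {},
"source": [
"Each entry in `tool_calls` carries its own `id`, function name, and JSON-encoded arguments. To close the loop on a parallel call, you execute every requested function and append one `\"tool\"` message per call, each referencing the matching `tool_call_id`. The cell below is a minimal sketch of that pattern: the weather functions here are hypothetical stubs, since this notebook never implements a real weather API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "parallel-tool-calls-loop",
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: stand-in implementations so the loop below has something to call.\n",
"def get_current_weather(location, format):\n",
"    return f\"(stub) current weather for {location} in {format}\"\n",
"\n",
"def get_n_day_weather_forecast(location, format, num_days):\n",
"    return f\"(stub) {num_days}-day forecast for {location} in {format}\"\n",
"\n",
"available_functions = {\n",
"    \"get_current_weather\": get_current_weather,\n",
"    \"get_n_day_weather_forecast\": get_n_day_weather_forecast,\n",
"}\n",
"\n",
"# Append the assistant message that contains the tool calls, then one \"tool\" message per call.\n",
"full_message = chat_response.choices[0].message\n",
"messages.append(full_message)\n",
"for tool_call in full_message.tool_calls or []:\n",
"    function_to_call = available_functions[tool_call.function.name]\n",
"    function_args = json.loads(tool_call.function.arguments)\n",
"    messages.append({\n",
"        \"role\": \"tool\",\n",
"        \"tool_call_id\": tool_call.id,\n",
"        \"name\": tool_call.function.name,\n",
"        \"content\": function_to_call(**function_args),\n",
"    })\n",
"\n",
"# Let the model summarise the (stub) results for the user.\n",
"follow_up = chat_completion_request(messages, tools=tools)\n",
"print(follow_up.choices[0].message.content)"
]
},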
{
"attachments": {},
"cell_type": "markdown",
"id": "b4482aee",
"metadata": {},
"source": [
"## How to call functions with model generated arguments\n",
"\n",
"In our next example, we'll demonstrate how to execute functions whose inputs are model-generated, and use this to implement an agent that can answer questions for us about a database. For simplicity we'll use the [Chinook sample database](https://www.sqlitetutorial.net/sqlite-sample-database/).\n",
"\n",
"*Note:* SQL generation can be high-risk in a production environment since models are not perfectly reliable at generating correct SQL."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f7654fef",
"metadata": {},
"source": [
"### Specifying a function to execute SQL queries\n",
"\n",
"First let's define some helpful utility functions to extract data from a SQLite database."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "30f6b60e",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:20.742187Z",
"start_time": "2024-07-12T22:42:20.737751Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Opened database successfully\n"
]
}
],
"source": [
"import sqlite3\n",
"\n",
"conn = sqlite3.connect(\"data/Chinook.db\")\n",
"print(\"Opened database successfully\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "abec0214",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:21.370623Z",
"start_time": "2024-07-12T22:42:21.368246Z"
}
},
"outputs": [],
"source": [
"def get_table_names(conn):\n",
"    \"\"\"Return a list of table names.\"\"\"\n",
"    table_names = []\n",
"    tables = conn.execute(\"SELECT name FROM sqlite_master WHERE type='table';\")\n",
"    for table in tables.fetchall():\n",
"        table_names.append(table[0])\n",
"    return table_names\n",
"\n",
"\n",
"def get_column_names(conn, table_name):\n",
"    \"\"\"Return a list of column names.\"\"\"\n",
"    column_names = []\n",
"    columns = conn.execute(f\"PRAGMA table_info('{table_name}');\").fetchall()\n",
"    for col in columns:\n",
"        column_names.append(col[1])\n",
"    return column_names\n",
"\n",
"\n",
"def get_database_info(conn):\n",
"    \"\"\"Return a list of dicts containing the table name and columns for each table in the database.\"\"\"\n",
"    table_dicts = []\n",
"    for table_name in get_table_names(conn):\n",
"        columns_names = get_column_names(conn, table_name)\n",
"        table_dicts.append({\"table_name\": table_name, \"column_names\": columns_names})\n",
"    return table_dicts\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "77e6e5ea",
"metadata": {},
"source": [
"Now we can use these utility functions to extract a representation of the database schema."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0c0104cd",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:22.668456Z",
"start_time": "2024-07-12T22:42:22.665839Z"
}
},
"outputs": [],
"source": [
"database_schema_dict = get_database_info(conn)\n",
"database_schema_string = \"\\n\".join(\n",
"    [\n",
"        f\"Table: {table['table_name']}\\nColumns: {', '.join(table['column_names'])}\"\n",
"        for table in database_schema_dict\n",
"    ]\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ae73c9ee",
"metadata": {},
"source": [
"As before, we'll define a function specification for the function we'd like the API to generate arguments for. Notice that we are inserting the database schema into the function specification: the model needs to know the schema in order to generate valid SQL."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0258813a",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:24.156291Z",
"start_time": "2024-07-12T22:42:24.154372Z"
}
},
"outputs": [],
"source": [
"tools = [\n",
"    {\n",
"        \"type\": \"function\",\n",
"        \"function\": {\n",
"            \"name\": \"ask_database\",\n",
"            \"description\": \"Use this function to answer user questions about music. Input should be a fully formed SQL query.\",\n",
"            \"parameters\": {\n",
"                \"type\": \"object\",\n",
"                \"properties\": {\n",
"                    \"query\": {\n",
"                        \"type\": \"string\",\n",
"                        \"description\": f\"\"\"\n",
"                                SQL query extracting info to answer the user's question.\n",
"                                SQL should be written using this database schema:\n",
"                                {database_schema_string}\n",
"                                The query should be returned in plain text, not in JSON.\n",
"                                \"\"\",\n",
"                    }\n",
"                },\n",
"                \"required\": [\"query\"],\n",
"            },\n",
"        }\n",
"    }\n",
"]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "da08c121",
"metadata": {},
"source": [
"### Executing SQL queries\n",
"\n",
"Now let's implement the function that will actually execute queries against the database."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "65585e74",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:25.444734Z",
"start_time": "2024-07-12T22:42:25.442757Z"
}
},
"outputs": [],
"source": [
"def ask_database(conn, query):\n",
"    \"\"\"Function to query SQLite database with a provided SQL query.\"\"\"\n",
"    try:\n",
"        results = str(conn.execute(query).fetchall())\n",
"    except Exception as e:\n",
"        results = f\"query failed with error: {e}\"\n",
"    return results"
]
},
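{
"attachments": {},
"cell_type": "markdown",
"id": "ask-database-sanity-check",
"metadata": {},
"source": [
"Before handing this helper to the model, it can be worth sanity-checking it with a simple hand-written query. The cell below is an optional check; it assumes the Chinook database loaded above, which includes an `Album` table."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ask-database-sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: run a trivial query directly, without the model in the loop.\n",
"print(ask_database(conn, \"SELECT COUNT(*) FROM Album;\"))"
]
},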
{
"cell_type": "markdown",
"id": "8f6885e9f0af5c40",
"metadata": {},
"source": [
"##### Steps to invoke a function call using the Chat Completions API:\n",
"\n",
"**Step 1**: Prompt the model with content that may result in the model selecting a tool to use. The descriptions of the tools, such as function names and signatures, are defined in the `tools` list and passed to the model in the API call. If a function is selected, its name and arguments are included in the response.<br>\n",
"\n",
"**Step 2**: Check programmatically whether the model wants to call a function. If true, proceed to step 3.<br>\n",
"**Step 3**: Extract the function name and arguments from the response, and call the function with those arguments. Append the result to the messages list.<br>\n",
"**Step 4**: Invoke the Chat Completions API with the updated message list to get the final response."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e8b7cb9cdc7a7616",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:28.395683Z",
"start_time": "2024-07-12T22:42:27.415626Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_wDN8uLjq2ofuU6rVx1k8Gw0e', function=Function(arguments='{\"query\":\"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album INNER JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1;\"}', name='ask_database'), type='function')])\n"
]
}
],
"source": [
"# Step 1: Prompt with content that may result in a function call. In this case the model can identify that the information requested by the user is potentially available in the database schema passed to the model in the tools description.\n",
"messages = [{\n",
"    \"role\": \"user\",\n",
"    \"content\": \"What is the name of the album with the most tracks?\"\n",
"}]\n",
"\n",
"response = client.chat.completions.create(\n",
"    model='gpt-4o',\n",
"    messages=messages,\n",
"    tools=tools,\n",
"    tool_choice=\"auto\"\n",
")\n",
"\n",
"# Append the assistant message to the messages list\n",
"response_message = response.choices[0].message\n",
"messages.append(response_message)\n",
"\n",
"print(response_message)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "351c39def3417776",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-12T22:42:30.439519Z",
"start_time": "2024-07-12T22:42:29.799492Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The album with the most tracks is titled \"Greatest Hits,\" which contains 57 tracks.\n"
]
}
],
"source": [
"# Step 2: determine if the response from the model includes a tool call.\n",
"tool_calls = response_message.tool_calls\n",
"if tool_calls:\n",
"    # If true, the model returned the name of the tool / function to call and its argument(s).\n",
"    # Note: for simplicity this example only handles the first tool call; iterate over tool_calls to handle several.\n",
"    tool_call_id = tool_calls[0].id\n",
"    tool_function_name = tool_calls[0].function.name\n",
"    tool_query_string = json.loads(tool_calls[0].function.arguments)['query']\n",
"\n",
"    # Step 3: Call the function and retrieve results. Append the results to the messages list.\n",
"    if tool_function_name == 'ask_database':\n",
"        results = ask_database(conn, tool_query_string)\n",
"\n",
"        messages.append({\n",
"            \"role\": \"tool\",\n",
"            \"tool_call_id\": tool_call_id,\n",
"            \"name\": tool_function_name,\n",
"            \"content\": results\n",
"        })\n",
"\n",
"        # Step 4: Invoke the chat completions API with the function response appended to the messages list\n",
"        # Note that messages with role 'tool' must be a response to a preceding message with 'tool_calls'\n",
"        model_response_with_function_call = client.chat.completions.create(\n",
"            model=\"gpt-4o\",\n",
"            messages=messages,\n",
"        )  # get a new response from the model where it can see the function response\n",
"        print(model_response_with_function_call.choices[0].message.content)\n",
"    else:\n",
"        print(f\"Error: function {tool_function_name} does not exist\")\n",
"else:\n",
"    # Model did not identify a function to call, result can be returned to the user\n",
"    print(response_message.content)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2d89073c",
"metadata": {},
"source": [
"## Next Steps\n",
"\n",
"See our other [notebook](How_to_call_functions_for_knowledge_retrieval.ipynb) that demonstrates how to use the Chat Completions API and functions for knowledge retrieval to interact conversationally with a knowledge base."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}