Update function calling examples to Python SDK (#1025)

pull/1028/head
teomusatoiu 3 months ago committed by GitHub
parent f1e13cfcc7
commit 4d37365182
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -27,21 +27,68 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "80e71f33",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: scipy in /usr/local/lib/python3.11/site-packages (1.12.0)\n",
"Requirement already satisfied: numpy<1.29.0,>=1.22.4 in /usr/local/lib/python3.11/site-packages (from scipy) (1.26.3)\n",
"Requirement already satisfied: tenacity in /usr/local/lib/python3.11/site-packages (8.2.3)\n",
"Requirement already satisfied: tiktoken==0.3.3 in /usr/local/lib/python3.11/site-packages (0.3.3)\n",
"Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.11/site-packages (from tiktoken==0.3.3) (2023.12.25)\n",
"Requirement already satisfied: requests>=2.26.0 in /usr/local/lib/python3.11/site-packages (from tiktoken==0.3.3) (2.31.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken==0.3.3) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken==0.3.3) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken==0.3.3) (2.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken==0.3.3) (2023.11.17)\n",
"Requirement already satisfied: termcolor in /usr/local/lib/python3.11/site-packages (2.4.0)\n",
"Requirement already satisfied: openai in /usr/local/lib/python3.11/site-packages (1.10.0)\n",
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.11/site-packages (from openai) (4.2.0)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.11/site-packages (from openai) (1.9.0)\n",
"Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.11/site-packages (from openai) (0.26.0)\n",
"Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.11/site-packages (from openai) (2.5.3)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.11/site-packages (from openai) (1.3.0)\n",
"Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.11/site-packages (from openai) (4.66.1)\n",
"Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.11/site-packages (from openai) (4.9.0)\n",
"Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.11/site-packages (from anyio<5,>=3.5.0->openai) (3.6)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.11/site-packages (from httpx<1,>=0.23.0->openai) (2023.11.17)\n",
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/site-packages (from httpx<1,>=0.23.0->openai) (1.0.2)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/site-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai) (0.14.0)\n",
"Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.11/site-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n",
"Requirement already satisfied: pydantic-core==2.14.6 in /usr/local/lib/python3.11/site-packages (from pydantic<3,>=1.9.0->openai) (2.14.6)\n",
"Requirement already satisfied: arxiv in /usr/local/lib/python3.11/site-packages (2.1.0)\n",
"Requirement already satisfied: feedparser==6.0.10 in /usr/local/lib/python3.11/site-packages (from arxiv) (6.0.10)\n",
"Requirement already satisfied: requests==2.31.0 in /usr/local/lib/python3.11/site-packages (from arxiv) (2.31.0)\n",
"Requirement already satisfied: sgmllib3k in /usr/local/lib/python3.11/site-packages (from feedparser==6.0.10->arxiv) (1.0.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/site-packages (from requests==2.31.0->arxiv) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/site-packages (from requests==2.31.0->arxiv) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/site-packages (from requests==2.31.0->arxiv) (2.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/site-packages (from requests==2.31.0->arxiv) (2023.11.17)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.11/site-packages (2.2.0)\n",
"Requirement already satisfied: numpy<2,>=1.23.2 in /usr/local/lib/python3.11/site-packages (from pandas) (1.26.3)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/site-packages (from pandas) (2023.3.post1)\n",
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/site-packages (from pandas) (2023.4)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Requirement already satisfied: PyPDF2 in /usr/local/lib/python3.11/site-packages (3.0.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.11/site-packages (4.66.1)\n"
]
}
],
"source": [
"!pip install scipy\n",
"!pip install tenacity\n",
"!pip install tiktoken==0.3.3\n",
"!pip install termcolor \n",
"!pip install openai\n",
"!pip install requests\n",
"!pip install arxiv\n",
"!pip install pandas\n",
"!pip install PyPDF2\n",
@ -56,28 +103,25 @@
"outputs": [],
"source": [
"import os\n",
"\n",
"import arxiv\n",
"import ast\n",
"import concurrent\n",
"from csv import writer\n",
"from IPython.display import display, Markdown, Latex\n",
"import json\n",
"import openai\n",
"import os\n",
"import pandas as pd\n",
"import tiktoken\n",
"from csv import writer\n",
"from IPython.display import display, Markdown, Latex\n",
"from openai import OpenAI\n",
"from PyPDF2 import PdfReader\n",
"import requests\n",
"from scipy import spatial\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"import tiktoken\n",
"from tqdm import tqdm\n",
"from termcolor import colored\n",
"\n",
"client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"GPT_MODEL = \"gpt-3.5-turbo\"\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n"
"GPT_MODEL = \"gpt-3.5-turbo-0613\"\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"client = OpenAI()"
]
},
{
@ -98,7 +142,15 @@
"execution_count": 3,
"id": "2de5d32d",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Directory './data/papers' already exists.\n"
]
}
],
"source": [
"directory = './data/papers'\n",
"\n",
@ -114,7 +166,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "ae5cb7a1",
"metadata": {},
"outputs": [],
@ -130,7 +182,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "57217b9d",
"metadata": {},
"outputs": [],
@ -141,15 +193,19 @@
" return response\n",
"\n",
"\n",
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def get_articles(query, library=paper_dir_filepath, top_k=5):\n",
" \"\"\"This function gets the top_k articles based on a user's query, sorted by relevance.\n",
" It also downloads the files and stores them in arxiv_library.csv to be retrieved by the read_article_and_summarize.\n",
" \"\"\"\n",
" client = arxiv.Client()\n",
" search = arxiv.Search(\n",
" query=query, max_results=top_k, sort_by=arxiv.SortCriterion.Relevance\n",
" query = \"quantum\",\n",
" max_results = 10,\n",
" sort_by = arxiv.SortCriterion.SubmittedDate\n",
" )\n",
" result_list = []\n",
" for result in search.results():\n",
" for result in client.results(search):\n",
" result_dict = {}\n",
" result_dict.update({\"title\": result.title})\n",
" result_dict.update({\"summary\": result.summary})\n",
@ -177,20 +233,20 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "dda02bdb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': 'Proximal Policy Optimization and its Dynamic Version for Sequence Generation',\n",
" 'summary': 'In sequence generation task, many works use policy gradient for model\\noptimization to tackle the intractable backpropagation issue when maximizing\\nthe non-differentiable evaluation metrics or fooling the discriminator in\\nadversarial learning. In this paper, we replace policy gradient with proximal\\npolicy optimization (PPO), which is a proved more efficient reinforcement\\nlearning algorithm, and propose a dynamic approach for PPO (PPO-dynamic). We\\ndemonstrate the efficacy of PPO and PPO-dynamic on conditional sequence\\ngeneration tasks including synthetic experiment and chit-chat chatbot. The\\nresults show that PPO and PPO-dynamic can beat policy gradient by stability and\\nperformance.',\n",
" 'article_url': 'http://arxiv.org/abs/1808.07982v1',\n",
" 'pdf_url': 'http://arxiv.org/pdf/1808.07982v1'}"
"{'title': 'Entanglement entropy and deconfined criticality: emergent SO(5) symmetry and proper lattice bipartition',\n",
" 'summary': \"We study the R\\\\'enyi entanglement entropy (EE) of the two-dimensional $J$-$Q$\\nmodel, the emblematic quantum spin model of deconfined criticality at the phase\\ntransition between antiferromagnetic and valence-bond-solid ground states.\\nQuantum Monte Carlo simulations with an improved EE scheme reveal critical\\ncorner contributions that scale logarithmically with the system size, with a\\ncoefficient in remarkable agreement with the form expected from a large-$N$\\nconformal field theory with SO($N=5$) symmetry. However, details of the\\nbipartition of the lattice are crucial in order to observe this behavior. If\\nthe subsystem for the reduced density matrix does not properly accommodate\\nvalence-bond fluctuations, logarithmic contributions appear even for\\ncorner-less bipartitions. We here use a $45^\\\\circ$ tilted cut on the square\\nlattice. Beyond supporting an SO($5$) deconfined quantum critical point, our\\nresults for both the regular and tilted cuts demonstrate important microscopic\\naspects of the EE that are not captured by conformal field theory.\",\n",
" 'article_url': 'http://arxiv.org/abs/2401.14396v1',\n",
" 'pdf_url': 'http://arxiv.org/pdf/2401.14396v1'}"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -203,7 +259,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "11675627",
"metadata": {},
"outputs": [],
@ -228,7 +284,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "7211df2c",
"metadata": {},
"outputs": [],
@ -269,7 +325,7 @@
"def extract_chunk(content, template_prompt):\n",
" \"\"\"This function applies a prompt to some input content. In this case it returns a summarized chunk of text\"\"\"\n",
" prompt = template_prompt + content\n",
" response = openai.chat.completions.create(\n",
" response = client.chat.completions.create(\n",
" model=GPT_MODEL, messages=[{\"role\": \"user\", \"content\": prompt}], temperature=0\n",
" )\n",
" return response.choices[0].message.content\n",
@ -344,7 +400,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "898b94d4",
"metadata": {},
"outputs": [
@ -360,7 +416,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.19s/it]\n"
"100%|██████████| 15/15 [00:08<00:00, 1.76it/s]\n"
]
},
{
@ -378,7 +434,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "c715f60d",
"metadata": {},
"outputs": [
@ -386,21 +442,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Core Argument:\n",
"- The paper discusses the use of Proximal Policy Optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.\n",
"- The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, commonly used in text generation tasks.\n",
"- They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n",
"\n",
"Evidence:\n",
"- PPO-dynamic achieves high precision scores comparable to other algorithms in a synthetic counting task.\n",
"- PPO-dynamic shows faster progress and more stable learning curves compared to PPO in the synthetic counting task.\n",
"- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than other algorithms.\n",
"- PPO and PPO-dynamic have more stable learning curves and converge faster than policy gradient.\n",
"\n",
"Conclusions:\n",
"- PPO is a better optimization method for sequence learning compared to policy gradient.\n",
"- PPO-dynamic further improves the optimization process by dynamically adjusting hyperparameters.\n",
"- PPO can be used as a new optimization method for GAN-based sequence learning for better performance.\n"
"The academic paper discusses the unique decomposition of generators of completely positive dynamical semigroups in infinite dimensions. The main result of the paper is that for any separable complex Hilbert space, any trace-class operator B that does not have a purely imaginary trace, and any generator L of a norm-continuous one-parameter semigroup of completely positive maps, there exists a unique bounded operator K and a unique completely positive map Φ such that L=K(·) + (·)K+ Φ. The paper also introduces a modified version of the Choi formalism, which relates completely positive maps to positive semi-definite operators, and characterizes when this correspondence is injective and surjective. The paper concludes by discussing the challenges and questions that arise when generalizing the results to non-separable Hilbert spaces.\n"
]
}
],
@ -421,25 +463,18 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "77a6fb4f",
"metadata": {},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def chat_completion_request(messages, functions=None, model=GPT_MODEL):\n",
" headers = {\n",
" \"Content-Type\": \"application/json\",\n",
" \"Authorization\": \"Bearer \" + openai.api_key,\n",
" }\n",
" json_data = {\"model\": model, \"messages\": messages}\n",
" if functions is not None:\n",
" json_data.update({\"tools\": functions})\n",
" try:\n",
" response = requests.post(\n",
" \"https://api.openai.com/v1/chat/completions\",\n",
" headers=headers,\n",
" json=json_data,\n",
" response = client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" functions=functions,\n",
" )\n",
" return response\n",
" except Exception as e:\n",
@ -450,7 +485,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "73f7672d",
"metadata": {},
"outputs": [],
@ -468,7 +503,7 @@
" \"system\": \"red\",\n",
" \"user\": \"green\",\n",
" \"assistant\": \"blue\",\n",
" \"tools\": \"magenta\",\n",
" \"function\": \"magenta\",\n",
" }\n",
" for message in self.conversation_history:\n",
" print(\n",
@ -481,58 +516,52 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "978b7877",
"metadata": {},
"outputs": [],
"source": [
"# Initiate our get_articles and read_article_and_summarize functions\n",
"arxiv_functions = [\n",
" { \n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_articles\",\n",
" \"description\": \"\"\"Use this function to get academic papers from arXiv to answer user questions.\"\"\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"query\": {\n",
" \"type\": \"string\",\n",
" \"description\": f\"\"\"\n",
" User query in JSON. Responses should be summarized and should include the article URL reference\n",
" \"\"\",\n",
" }\n",
" },\n",
" \"required\": [\"query\"],\n",
" {\n",
" \"name\": \"get_articles\",\n",
" \"description\": \"\"\"Use this function to get academic papers from arXiv to answer user questions.\"\"\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"query\": {\n",
" \"type\": \"string\",\n",
" \"description\": f\"\"\"\n",
" User query in JSON. Responses should be summarized and should include the article URL reference\n",
" \"\"\",\n",
" }\n",
" },\n",
" }\n",
" \"required\": [\"query\"],\n",
" },\n",
" },\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"read_article_and_summarize\",\n",
" \"description\": \"\"\"Use this function to read whole papers and provide a summary for users.\n",
" You should NEVER call this function before get_articles has been called in the conversation.\"\"\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"query\": {\n",
" \"type\": \"string\",\n",
" \"description\": f\"\"\"\n",
" Description of the article in plain text based on the user's query\n",
" \"\"\",\n",
" }\n",
" },\n",
" \"required\": [\"query\"],\n",
" \"name\": \"read_article_and_summarize\",\n",
" \"description\": \"\"\"Use this function to read whole papers and provide a summary for users.\n",
" You should NEVER call this function before get_articles has been called in the conversation.\"\"\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"query\": {\n",
" \"type\": \"string\",\n",
" \"description\": f\"\"\"\n",
" Description of the article in plain text based on the user's query\n",
" \"\"\",\n",
" }\n",
" },\n",
" }\n",
" \"required\": [\"query\"],\n",
" },\n",
" }\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "0c88ae15",
"metadata": {},
"outputs": [],
@ -540,23 +569,23 @@
"def chat_completion_with_function_execution(messages, functions=[None]):\n",
" \"\"\"This function makes a ChatCompletion API call with the option of adding functions\"\"\"\n",
" response = chat_completion_request(messages, functions)\n",
" full_message = response.json().choices[0]\n",
" if full_message.finish_reason == \"tool_calls\":\n",
" full_message = response.choices[0]\n",
" if full_message.finish_reason == \"function_call\":\n",
" print(f\"Function generation requested, calling function\")\n",
" return call_arxiv_function(messages, full_message)\n",
" else:\n",
" print(f\"Function not required, responding to user\")\n",
" return response.json()\n",
" return response\n",
"\n",
"\n",
"def call_arxiv_function(messages, full_message):\n",
" \"\"\"Function calling function which executes function calls when the model believes it is necessary.\n",
" Currently extended by adding clauses to this if statement.\"\"\"\n",
"\n",
" if full_message.message.tool_calls[0].function.name == \"get_articles\":\n",
" if full_message.message.function_call.name == \"get_articles\":\n",
" try:\n",
" parsed_output = json.loads(\n",
" full_message.message.tool_calls[0].function.arguments\n",
" full_message.message.function_call.arguments\n",
" )\n",
" print(\"Getting search results\")\n",
" results = get_articles(parsed_output[\"query\"])\n",
@ -566,24 +595,24 @@
" print(f\"Error message: {e}\")\n",
" messages.append(\n",
" {\n",
" \"role\": \"tool\",\n",
" \"tool_call_id\": full_message.message.tool_calls[0].id,\n",
" \"role\": \"function\",\n",
" \"name\": full_message.message.function_call.name,\n",
" \"content\": str(results),\n",
" }\n",
" )\n",
" try:\n",
" print(\"Got search results, summarizing content\")\n",
" response = chat_completion_request(messages)\n",
" return response.json()\n",
" return response\n",
" except Exception as e:\n",
" print(type(e))\n",
" raise Exception(\"Function chat request failed\")\n",
"\n",
" elif (\n",
" full_message.message.tool_calls[0].function.name == \"read_article_and_summarize\"\n",
" full_message.message.function_call.name == \"read_article_and_summarize\"\n",
" ):\n",
" parsed_output = json.loads(\n",
" full_message.message.tool_calls[0].function.arguments\n",
" full_message.message.function_call.arguments\n",
" )\n",
" print(\"Finding and reading paper\")\n",
" summary = summarize_text(parsed_output[\"query\"])\n",
@ -606,7 +635,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "c39a1d80",
"metadata": {},
"outputs": [],
@ -622,7 +651,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "253fd0f7",
"metadata": {},
"outputs": [
@ -631,43 +660,30 @@
"output_type": "stream",
"text": [
"Function generation requested, calling function\n",
"Finding and reading paper\n",
"Chunking text from paper\n",
"Summarizing each chunk of text\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [00:06<00:00, 2.65it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Summarizing into overall summary\n"
"Getting search results\n",
"Got search results, summarizing content\n"
]
},
{
"data": {
"text/markdown": [
"Core Argument:\n",
"- The paper focuses on the theoretical analysis of the PPO-Clip algorithm in the context of deep reinforcement learning.\n",
"- The authors propose two core ideas: reinterpreting PPO-Clip from the perspective of hinge loss and introducing a two-step policy improvement scheme.\n",
"- The paper establishes the global convergence of PPO-Clip and characterizes its convergence rate.\n",
"PPO (Proximal Policy Optimization) is a reinforcement learning algorithm used in training agents to make sequential decisions in dynamic environments. It belongs to the family of policy optimization algorithms and addresses the challenge of optimizing policies in a stable and sample-efficient manner. \n",
"\n",
"PPO works by iteratively collecting a batch of data from interacting with the environment, computing advantages to estimate the quality of actions, and then performing multiple policy updates using a clipped surrogate objective. This objective function helps prevent excessive policy updates that could lead to policy divergence and instability. \n",
"\n",
"By iteratively updating the policy using the collected data, PPO seeks to maximize the expected cumulative rewards obtained by the agent. It has been used successfully in a variety of reinforcement learning tasks, including robotic control, game playing, and simulated environments. \n",
"\n",
"Evidence:\n",
"- The paper addresses the challenges posed by the clipping mechanism and neural function approximation.\n",
"- The authors provide theoretical proofs, lemmas, and mathematical analysis to support their arguments.\n",
"- The paper presents empirical experiments on various reinforcement learning benchmark tasks to validate the effectiveness of PPO-Clip.\n",
"To learn more about PPO reinforcement learning, you can read the following papers:\n",
"\n",
"Conclusions:\n",
"- The paper offers theoretical insights into the performance of PPO-Clip and provides a framework for analyzing its convergence properties.\n",
"- PPO-Clip is shown to have a global convergence rate of O(1/sqrt(T)), where T is the number of iterations.\n",
"- The hinge loss reinterpretation of PPO-Clip allows for variants with comparable empirical performance.\n",
"- The paper contributes to a better understanding of PPO-Clip in the reinforcement learning community."
"1. Title: \"Proximal Policy Optimization Algorithms\"\n",
" Article URL: [arxiv.org/abs/1707.06347v2](http://arxiv.org/abs/1707.06347v2)\n",
" Summary: This paper introduces PPO and presents two versions of the algorithm: PPO-Penalty and PPO-Clip. It provides a detailed description of PPO's update rule and compares its performance against other popular reinforcement learning algorithms.\n",
"\n",
"2. Title: \"Emergent Properties of PPO Reinforcement Learning in Resource-Limited Environments\"\n",
" Article URL: [arxiv.org/abs/2001.14342v1](http://arxiv.org/abs/2001.14342v1)\n",
" Summary: This paper explores the emergent properties of PPO reinforcement learning algorithms in resource-limited environments. It discusses the impact of varying the resource constraints and agent population sizes on the learning process and performance.\n",
"\n",
"Reading these papers will give you a deeper understanding of PPO reinforcement learning and its applications in different domains."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
@ -690,7 +706,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "3ca3e18a",
"metadata": {},
"outputs": [
@ -708,7 +724,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.08s/it]\n"
"100%|██████████| 15/15 [00:09<00:00, 1.67it/s]\n"
]
},
{
@ -721,21 +737,7 @@
{
"data": {
"text/markdown": [
"Core Argument:\n",
"- The paper discusses the use of proximal policy optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.\n",
"- The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, which is commonly used in text generation tasks.\n",
"- They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n",
"\n",
"Evidence:\n",
"- The authors derive the constraints for PPO-dynamic and provide the pseudo code for both PPO and PPO-dynamic.\n",
"- They compare the performance of PPO-dynamic with other algorithms, including REINFORCE, MIXER, and SeqGAN, on a synthetic counting task and a chit-chat chatbot task using the OpenSubtitles dataset.\n",
"- In the synthetic counting task, PPO-dynamic achieves a high precision score comparable to REINFORCE and MIXER, with a faster learning curve compared to PPO.\n",
"- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than REINFORCE and PPO, with a more stable and faster learning curve than policy gradient.\n",
"\n",
"Conclusions:\n",
"- The results suggest that PPO is a better optimization method for sequence learning compared to policy gradient.\n",
"- PPO-dynamic further improves the optimization process by dynamically adjusting the hyperparameters.\n",
"- The authors conclude that PPO can be used as a new optimization method for GAN-based sequence learning for better performance."
"The paper discusses the unique decomposition of generators of completely positive dynamical semigroups in infinite dimensions. The main result is that for any separable complex Hilbert space, any trace-class operator B that does not have a purely imaginary trace, and any generator L of a norm-continuous one-parameter semigroup of completely positive maps, there exists a unique bounded operator K and a unique completely positive map Φ such that L=K(·) + (·)K+ Φ. The paper also introduces a modified version of the Choi formalism and characterizes when this correspondence is injective and surjective. The paper concludes by discussing the challenges and questions that arise when generalizing the results to non-separable Hilbert spaces."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
@ -760,9 +762,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "tua_test",
"display_name": "Python 3",
"language": "python",
"name": "tua_test"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -774,7 +776,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.12.1"
}
},
"nbformat": 4,

@ -12,7 +12,7 @@
"\n",
"`tools` is an optional parameter in the Chat Completion API which can be used to provide function specifications. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications. Note that the API will not actually execute any function calls. It is up to developers to execute function calls using model outputs.\n",
"\n",
"Within the `tools` parameter, if the `functions` parameter is provided then by default the model will decide when it is appropriate to use one of the functions. The API can be forced to use a specific function by setting the `tool_choice` parameter to `{\"type\": \"function\", \"function\": {\"name\": \"<insert-function-name>\"}}`. The API can also be forced to not use any function by setting the `tool_choice` parameter to `\"none\"`. If a function is used, the output will contain `\"finish_reason\": \"tool_calls\"` in the response, as well as a `tool_choice` object that has the name of the function and the generated function arguments. For details, see the API [Documentation](https://platform.openai.com/docs/api-reference/chat/create)\n",
"Within the `tools` parameter, if the `functions` parameter is provided then by default the model will decide when it is appropriate to use one of the functions. The API can be forced to use a specific function by setting the `tool_choice` parameter to `{\"name\": \"<insert-function-name>\"}`. The API can also be forced to not use any function by setting the `tool_choice` parameter to `\"none\"`. If a function is used, the output will contain `\"finish_reason\": \"function_call\"` in the response, as well as a `tool_choice` object that has the name of the function and the generated function arguments.\n",
"\n",
"### Overview\n",
"\n",
@ -33,39 +33,68 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "80e71f33",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: scipy in /usr/local/lib/python3.11/site-packages (1.12.0)\n",
"Requirement already satisfied: numpy<1.29.0,>=1.22.4 in /usr/local/lib/python3.11/site-packages (from scipy) (1.26.3)\n",
"Requirement already satisfied: tenacity in /usr/local/lib/python3.11/site-packages (8.2.3)\n",
"Requirement already satisfied: tiktoken in /usr/local/lib/python3.11/site-packages (0.3.3)\n",
"Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.11/site-packages (from tiktoken) (2023.12.25)\n",
"Requirement already satisfied: requests>=2.26.0 in /usr/local/lib/python3.11/site-packages (from tiktoken) (2.31.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2023.11.17)\n",
"Requirement already satisfied: termcolor in /usr/local/lib/python3.11/site-packages (2.4.0)\n",
"Requirement already satisfied: openai in /usr/local/lib/python3.11/site-packages (1.10.0)\n",
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.11/site-packages (from openai) (4.2.0)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.11/site-packages (from openai) (1.9.0)\n",
"Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.11/site-packages (from openai) (0.26.0)\n",
"Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.11/site-packages (from openai) (2.5.3)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.11/site-packages (from openai) (1.3.0)\n",
"Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.11/site-packages (from openai) (4.66.1)\n",
"Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.11/site-packages (from openai) (4.9.0)\n",
"Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.11/site-packages (from anyio<5,>=3.5.0->openai) (3.6)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.11/site-packages (from httpx<1,>=0.23.0->openai) (2023.11.17)\n",
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/site-packages (from httpx<1,>=0.23.0->openai) (1.0.2)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/site-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai) (0.14.0)\n",
"Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.11/site-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n",
"Requirement already satisfied: pydantic-core==2.14.6 in /usr/local/lib/python3.11/site-packages (from pydantic<3,>=1.9.0->openai) (2.14.6)\n"
]
}
],
"source": [
"!pip install scipy\n",
"!pip install tenacity\n",
"!pip install tiktoken\n",
"!pip install termcolor \n",
"!pip install openai\n",
"!pip install requests"
"!pip install openai"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "dab872c5",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import openai\n",
"import requests\n",
"from openai import OpenAI\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from termcolor import colored\n",
"from termcolor import colored \n",
"\n",
"openai.api_key = \"YOUR_API_KEY\" # or set via environment variable 'OPENAI_API_KEY'\n",
"\n",
"GPT_MODEL = \"gpt-3.5-turbo-0613\""
"GPT_MODEL = \"gpt-3.5-turbo-0613\"\n",
"client = OpenAI()"
]
},
{
@ -81,27 +110,19 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "745ceec5",
"metadata": {},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))\n",
"def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):\n",
" headers = {\n",
" \"Content-Type\": \"application/json\",\n",
" \"Authorization\": \"Bearer \" + openai.api_key,\n",
" }\n",
" json_data = {\"model\": model, \"messages\": messages}\n",
" if tools is not None:\n",
" json_data.update({\"tools\": tools})\n",
" if tool_choice is not None:\n",
" json_data.update({\"tool_choice\": tool_choice})\n",
" try:\n",
" response = requests.post(\n",
" \"https://api.openai.com/v1/chat/completions\",\n",
" headers=headers,\n",
" json=json_data,\n",
" response = client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" tools=tools,\n",
" tool_choice=tool_choice,\n",
" )\n",
" return response\n",
" except Exception as e:\n",
@ -112,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "c4d1c99f",
"metadata": {},
"outputs": [],
@ -122,7 +143,7 @@
" \"system\": \"red\",\n",
" \"user\": \"green\",\n",
" \"assistant\": \"blue\",\n",
" \"tool\": \"magenta\",\n",
" \"function\": \"magenta\",\n",
" }\n",
" \n",
" for message in messages:\n",
@ -134,7 +155,7 @@
" print(colored(f\"assistant: {message['function_call']}\\n\", role_to_color[message[\"role\"]]))\n",
" elif message[\"role\"] == \"assistant\" and not message.get(\"function_call\"):\n",
" print(colored(f\"assistant: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
" elif message[\"role\"] == \"tool\":\n",
" elif message[\"role\"] == \"function\":\n",
" print(colored(f\"function ({message['name']}): {message['content']}\\n\", role_to_color[message[\"role\"]]))\n"
]
},
@ -151,7 +172,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "d2e25069",
"metadata": {},
"outputs": [],
@ -219,18 +240,17 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "518d6827",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': 'Sure, I can help you with that. Could you please tell me the city and state you are in or the location you want to know the weather for?'}"
"ChatCompletionMessage(content='Sure, I can help you with that. Could you please provide me with your location?', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -242,7 +262,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
@ -258,22 +278,17 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "23c42a6e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': None,\n",
" 'tool_calls': [{'id': 'call_o7uyztQLeVIoRdjcDkDJY3ni',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_current_weather',\n",
" 'arguments': '{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\"\\n}'}}]}"
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_qOYhFO7fKaU6wpG2f1XzkDjW', function=Function(arguments='{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\"\\n}', name='get_current_weather'), type='function')])"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -283,7 +298,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
@ -299,18 +314,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "fa232e54",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': 'Sure, I can help you with that. Please let me know the value for x.'}"
"ChatCompletionMessage(content='Sure! Please provide the number of days you would like to know the weather forecast for.', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@ -322,7 +336,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
@ -338,24 +352,17 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "c7d8a543",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'index': 0,\n",
" 'message': {'role': 'assistant',\n",
" 'content': None,\n",
" 'tool_calls': [{'id': 'call_drz2YpGPWEMVySzYgsWYY249',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_n_day_weather_forecast',\n",
" 'arguments': '{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\",\\n \"num_days\": 5\\n}'}}]},\n",
" 'finish_reason': 'tool_calls'}"
"Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_HwWHsNZsmkZUroPj6glmEgA5', function=Function(arguments='{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\",\\n \"num_days\": 5\\n}', name='get_n_day_weather_forecast'), type='function')]), internal_metrics=[{'cached_prompt_tokens': 0, 'total_accepted_tokens': 0, 'total_batched_tokens': 269, 'total_predicted_tokens': 0, 'total_rejected_tokens': 0, 'total_tokens_in_completion': 270, 'cached_embeddings_bytes': 0, 'cached_embeddings_n': 0, 'uncached_embeddings_bytes': 0, 'uncached_embeddings_n': 0, 'fetched_embeddings_bytes': 0, 'fetched_embeddings_n': 0, 'n_evictions': 0, 'sampling_steps': 40, 'sampling_steps_with_predictions': 0, 'batcher_ttft': 0.055008649826049805, 'batcher_initial_queue_time': 0.00098419189453125}])"
]
},
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -365,7 +372,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"chat_response.json()[\"choices\"][0]\n"
"chat_response.choices[0]\n"
]
},
{
@ -388,22 +395,17 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "559371b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': None,\n",
" 'tool_calls': [{'id': 'call_jdmoJQ4lqsu4mBWcVBYtt5cU',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_n_day_weather_forecast',\n",
" 'arguments': '{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\",\\n \"num_days\": 1\\n}'}}]}"
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_240XQedt4Gi8VZsUwOvFpQfZ', function=Function(arguments='{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\",\\n \"num_days\": 1\\n}', name='get_n_day_weather_forecast'), type='function')])"
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@ -416,27 +418,22 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools, tool_choice={\"type\": \"function\", \"function\": {\"name\": \"get_n_day_weather_forecast\"}}\n",
")\n",
"chat_response.json()[\"choices\"][0][\"message\"]\n"
"chat_response.choices[0].message"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "a7ab0f58",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': None,\n",
" 'tool_calls': [{'id': 'call_RYXaDjxpUCfWmpXU7BZEYVqS',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_current_weather',\n",
" 'arguments': '{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}'}}]}"
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_lQhrFlzIVPpeYG1QrSv7e3H3', function=Function(arguments='{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}', name='get_current_weather'), type='function')])"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@ -449,7 +446,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"chat_response.json()[\"choices\"][0][\"message\"]\n"
"chat_response.choices[0].message"
]
},
{
@ -463,18 +460,17 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "acfe54e6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': '{ \"location\": \"Toronto, Canada\", \"format\": \"celsius\" }'}"
"ChatCompletionMessage(content='{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@ -486,11 +482,12 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools, tool_choice=\"none\"\n",
")\n",
"chat_response.json()[\"choices\"][0][\"message\"]\n"
"chat_response.choices[0].message\n"
]
},
{
"cell_type": "markdown",
"id": "b616353b",
"metadata": {},
"source": [
"### Parallel Function Calling\n",
@ -500,23 +497,18 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "380eeb68",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'id': 'call_fLsKR5vGllhbWxvpqsDT3jBj',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_n_day_weather_forecast',\n",
" 'arguments': '{\"location\": \"San Francisco, CA\", \"format\": \"celsius\", \"num_days\": 4}'}},\n",
" {'id': 'call_CchlsGE8OE03QmeyFbg7pkDz',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_n_day_weather_forecast',\n",
" 'arguments': '{\"location\": \"Glasgow\", \"format\": \"celsius\", \"num_days\": 4}'}}]"
"[ChatCompletionMessageToolCall(id='call_q8k4geh0uGPRtIfOXYPB0yM8', function=Function(arguments='{\"location\": \"San Francisco, CA\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function'),\n",
" ChatCompletionMessageToolCall(id='call_Hdl7Py7aLswCBPptrD4y5BD3', function=Function(arguments='{\"location\": \"Glasgow\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function')]"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@ -529,8 +521,7 @@
" messages, tools=tools, model='gpt-3.5-turbo-1106'\n",
")\n",
"\n",
"chat_response.json()\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]['tool_calls']\n",
"assistant_message = chat_response.choices[0].message.tool_calls\n",
"assistant_message"
]
},
@ -560,7 +551,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "30f6b60e",
"metadata": {},
"outputs": [
@ -581,7 +572,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "abec0214",
"metadata": {},
"outputs": [],
@ -624,7 +615,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "0c0104cd",
"metadata": {},
"outputs": [],
@ -649,7 +640,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "0258813a",
"metadata": {},
"outputs": [],
@ -693,7 +684,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"id": "65585e74",
"metadata": {},
"outputs": [],
@ -707,17 +698,17 @@
" return results\n",
"\n",
"def execute_function_call(message):\n",
" if message[\"tool_calls\"][0][\"function\"][\"name\"] == \"ask_database\":\n",
" query = json.loads(message[\"tool_calls\"][0][\"function\"][\"arguments\"])[\"query\"]\n",
" if message.tool_calls[0].function.name == \"ask_database\":\n",
" query = json.loads(message.tool_calls[0].function.arguments)[\"query\"]\n",
" results = ask_database(conn, query)\n",
" else:\n",
" results = f\"Error: function {message['tool_calls'][0]['function']['name']} does not exist\"\n",
" results = f\"Error: function {message.tool_calls[0].function.name} does not exist\"\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"id": "38c55083",
"metadata": {},
"outputs": [
@ -729,7 +720,7 @@
"\u001b[0m\n",
"\u001b[32muser: Hi, who are the top 5 artists by number of tracks?\n",
"\u001b[0m\n",
"\u001b[34massistant: {'name': 'ask_database', 'arguments': '{\\n \"query\": \"SELECT Artist.Name, COUNT(Track.TrackId) AS TrackCount FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.Name ORDER BY TrackCount DESC LIMIT 5\"\\n}'}\n",
"\u001b[34massistant: Function(arguments='{\\n \"query\": \"SELECT artist.Name, COUNT(track.TrackId) AS num_tracks FROM artist JOIN album ON artist.ArtistId = album.ArtistId JOIN track ON album.AlbumId = track.AlbumId GROUP BY artist.ArtistId ORDER BY num_tracks DESC LIMIT 5\"\\n}', name='ask_database')\n",
"\u001b[0m\n",
"\u001b[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n",
"\u001b[0m\n"
@ -741,18 +732,18 @@
"messages.append({\"role\": \"system\", \"content\": \"Answer user questions by generating SQL queries against the Chinook Music Database.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"Hi, who are the top 5 artists by number of tracks?\"})\n",
"chat_response = chat_completion_request(messages, tools)\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message['content'] = str(assistant_message[\"tool_calls\"][0][\"function\"])\n",
"messages.append(assistant_message)\n",
"if assistant_message.get(\"tool_calls\"):\n",
"assistant_message = chat_response.choices[0].message\n",
"assistant_message.content = str(assistant_message.tool_calls[0].function)\n",
"messages.append({\"role\": assistant_message.role, \"content\": assistant_message.content})\n",
"if assistant_message.tool_calls:\n",
" results = execute_function_call(assistant_message)\n",
" messages.append({\"role\": \"tool\", \"tool_call_id\": assistant_message[\"tool_calls\"][0]['id'], \"name\": assistant_message[\"tool_calls\"][0][\"function\"][\"name\"], \"content\": results})\n",
" messages.append({\"role\": \"function\", \"tool_call_id\": assistant_message.tool_calls[0].id, \"name\": assistant_message.tool_calls[0].function.name, \"content\": results})\n",
"pretty_print_conversation(messages)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"id": "710481dc",
"metadata": {
"scrolled": true
@ -766,13 +757,13 @@
"\u001b[0m\n",
"\u001b[32muser: Hi, who are the top 5 artists by number of tracks?\n",
"\u001b[0m\n",
"\u001b[34massistant: {'name': 'ask_database', 'arguments': '{\\n \"query\": \"SELECT Artist.Name, COUNT(Track.TrackId) AS TrackCount FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.Name ORDER BY TrackCount DESC LIMIT 5\"\\n}'}\n",
"\u001b[34massistant: Function(arguments='{\\n \"query\": \"SELECT artist.Name, COUNT(track.TrackId) AS num_tracks FROM artist JOIN album ON artist.ArtistId = album.ArtistId JOIN track ON album.AlbumId = track.AlbumId GROUP BY artist.ArtistId ORDER BY num_tracks DESC LIMIT 5\"\\n}', name='ask_database')\n",
"\u001b[0m\n",
"\u001b[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n",
"\u001b[0m\n",
"\u001b[32muser: What is the name of the album with the most tracks?\n",
"\u001b[0m\n",
"\u001b[34massistant: {'name': 'ask_database', 'arguments': '{\\n \"query\": \"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1\"\\n}'}\n",
"\u001b[34massistant: Function(arguments='{\\n \"query\": \"SELECT album.Title, COUNT(track.TrackId) AS num_tracks FROM album JOIN track ON album.AlbumId = track.AlbumId GROUP BY album.AlbumId ORDER BY num_tracks DESC LIMIT 1\"\\n}', name='ask_database')\n",
"\u001b[0m\n",
"\u001b[35mfunction (ask_database): [('Greatest Hits', 57)]\n",
"\u001b[0m\n"
@ -782,12 +773,12 @@
"source": [
"messages.append({\"role\": \"user\", \"content\": \"What is the name of the album with the most tracks?\"})\n",
"chat_response = chat_completion_request(messages, tools)\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message['content'] = str(assistant_message[\"tool_calls\"][0][\"function\"])\n",
"messages.append(assistant_message)\n",
"if assistant_message.get(\"tool_calls\"):\n",
"assistant_message = chat_response.choices[0].message\n",
"assistant_message.content = str(assistant_message.tool_calls[0].function)\n",
"messages.append({\"role\": assistant_message.role, \"content\": assistant_message.content})\n",
"if assistant_message.tool_calls:\n",
" results = execute_function_call(assistant_message)\n",
" messages.append({\"role\": \"tool\", \"tool_call_id\": assistant_message[\"tool_calls\"][0]['id'], \"name\": assistant_message[\"tool_calls\"][0][\"function\"][\"name\"], \"content\": results})\n",
" messages.append({\"role\": \"function\", \"tool_call_id\": assistant_message.tool_calls[0].id, \"name\": assistant_message.tool_calls[0].function.name, \"content\": results})\n",
"pretty_print_conversation(messages)"
]
},
@ -801,6 +792,12 @@
"\n",
"See our other [notebook](How_to_call_functions_for_knowledge_retrieval.ipynb) that demonstrates how to use the Chat Completions API and functions for knowledge retrieval to interact conversationally with a knowledge base."
]
},
{
"cell_type": "markdown",
"id": "ec721d07",
"metadata": {},
"source": []
}
],
"metadata": {
@ -819,7 +816,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.12.1"
}
},
"nbformat": 4,

Loading…
Cancel
Save