From ca51ea2c4079ce0eeef564cd5a21b982fc1f426f Mon Sep 17 00:00:00 2001 From: joe-at-openai Date: Wed, 14 Jun 2023 10:48:11 -0700 Subject: [PATCH] integrate improvements to original function calling notebook + factor out arxiv example into new notebook --- ...ll_functions_for_knowledge_retrieval.ipynb | 842 ++++++++++ ...w_to_call_functions_with_chat_models.ipynb | 1390 +++++------------ 2 files changed, 1247 insertions(+), 985 deletions(-) create mode 100644 examples/How_to_call_functions_for_knowledge_retrieval.ipynb diff --git a/examples/How_to_call_functions_for_knowledge_retrieval.ipynb b/examples/How_to_call_functions_for_knowledge_retrieval.ipynb new file mode 100644 index 00000000..20b2e252 --- /dev/null +++ b/examples/How_to_call_functions_for_knowledge_retrieval.ipynb @@ -0,0 +1,842 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3e67f200", + "metadata": {}, + "source": [ + "# How to use functions with a knowledge base\n", + "\n", + "This notebook builds on the concepts in the [argument generation]('How_to_generate_function_arguments_with_chat_models.ipynb') notebook, by creating an agent with access to a knowledge base and two functions that it can call based on the user requirement.\n", + "\n", + "We'll create an agent that uses data from arXiv to answer questions about academic subjects. It has two functions at its disposal:\n", + "- **get_articles**: A function that gets arXiv articles on a subject and summarizes them for the user with links.\n", + "- **read_article_and_summarize**: This function takes one of the previously searched articles, reads it in its entirety and summarizes the core argument, evidence and conclusions.\n", + "\n", + "This will get you comfortable with a multi-function workflow that can choose from multiple services, and where some of the data from the first function is persisted to be used by the second.\n", + "\n", + "## Walkthrough\n", + "\n", + "This cookbook takes you through the following workflow:\n", + "\n", + "- **Search utilities:** Creating the two functions that access arXiv for answers.\n", + "- **Configure Agent:** Building up the Agent behaviour that will assess the need for a function and, if one is required, call that function and present results back to the agent.\n", + "- **arXiv conversation:** Put all of this together in live conversation.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "80e71f33", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: scipy in /opt/homebrew/lib/python3.11/site-packages (1.10.1)\n", + "Requirement already satisfied: numpy<1.27.0,>=1.19.5 in /opt/homebrew/lib/python3.11/site-packages (from scipy) (1.24.3)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: tenacity in /opt/homebrew/lib/python3.11/site-packages (8.2.2)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: tiktoken in /opt/homebrew/lib/python3.11/site-packages (0.4.0)\n", + "Requirement already satisfied: regex>=2022.1.18 in /opt/homebrew/lib/python3.11/site-packages (from tiktoken) (2023.6.3)\n", + "Requirement already satisfied: requests>=2.26.0 in /opt/homebrew/lib/python3.11/site-packages (from tiktoken) (2.30.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.1.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (1.25.11)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2023.5.7)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: termcolor in /opt/homebrew/lib/python3.11/site-packages (2.3.0)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: openai in /opt/homebrew/lib/python3.11/site-packages (0.27.6)\n", + "Requirement already satisfied: requests>=2.20 in /opt/homebrew/lib/python3.11/site-packages (from openai) (2.30.0)\n", + "Requirement already satisfied: tqdm in /opt/homebrew/lib/python3.11/site-packages (from openai) (4.65.0)\n", + "Requirement already satisfied: aiohttp in /opt/homebrew/lib/python3.11/site-packages (from openai) (3.8.4)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (3.1.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (1.25.11)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (2023.5.7)\n", + "Requirement already satisfied: attrs>=17.3.0 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (23.1.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (6.0.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (4.0.2)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (1.9.2)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (1.3.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (1.3.1)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: requests in /opt/homebrew/lib/python3.11/site-packages (2.30.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests) (3.1.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests) (1.25.11)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests) (2023.5.7)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: arxiv in /opt/homebrew/lib/python3.11/site-packages (1.4.7)\n", + "Requirement already satisfied: feedparser in /opt/homebrew/lib/python3.11/site-packages (from arxiv) (6.0.10)\n", + "Requirement already satisfied: sgmllib3k in /opt/homebrew/lib/python3.11/site-packages (from feedparser->arxiv) (1.0.0)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0mRequirement already satisfied: pandas in /opt/homebrew/lib/python3.11/site-packages (2.0.1)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2023.3)\n", + "Requirement already satisfied: tzdata>=2022.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2023.3)\n", + "Requirement already satisfied: numpy>=1.21.0 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (1.24.3)\n", + "Requirement already satisfied: six>=1.5 in /opt/homebrew/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: PyPDF2 in /opt/homebrew/lib/python3.11/site-packages (3.0.1)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: tqdm in /opt/homebrew/lib/python3.11/site-packages (4.65.0)\n", + "\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install scipy\n", + "!pip install tenacity\n", + "!pip install tiktoken==0.3.3\n", + "!pip install termcolor \n", + "!pip install openai\n", + "!pip install requests\n", + "!pip install arxiv\n", + "!pip install pandas\n", + "!pip install PyPDF2\n", + "!pip install tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dab872c5", + "metadata": {}, + "outputs": [], + "source": [ + "import arxiv\n", + "import ast\n", + "import concurrent\n", + "from csv import writer\n", + "from IPython.display import display, Markdown, Latex\n", + "import json\n", + "import openai\n", + "import os\n", + "import pandas as pd\n", + "from PyPDF2 import PdfReader\n", + "import requests\n", + "from scipy import spatial\n", + "from tenacity import retry, wait_random_exponential, stop_after_attempt\n", + "import tiktoken\n", + "from tqdm import tqdm\n", + "from termcolor import colored\n", + "\n", + "GPT_MODEL = \"gpt-3.5-turbo-0613\"\n", + "EMBEDDING_MODEL = \"text-embedding-ada-002\"\n" + ] + }, + { + "cell_type": "markdown", + "id": "f2e47962", + "metadata": {}, + "source": [ + "## Search utilities\n", + "\n", + "We'll first set up some utilities that will underpin our two functions.\n", + "\n", + "Downloaded papers will be stored in a directory (we use ```./data/papers``` here). We create a file ```arxiv_library.csv``` to store the embeddings and details for downloaded papers to retrieve against using ```summarize_text```." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2de5d32d", + "metadata": {}, + "outputs": [], + "source": [ + "# Set a directory to store downloaded papers\n", + "data_dir = os.path.join(os.curdir, \"data\", \"papers\")\n", + "paper_dir_filepath = \"./data/arxiv_library.csv\"\n", + "\n", + "# Generate a blank dataframe where we can store downloaded files\n", + "df = pd.DataFrame(list())\n", + "df.to_csv(paper_dir_filepath)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "57217b9d", + "metadata": {}, + "outputs": [], + "source": [ + "@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n", + "def embedding_request(text):\n", + " response = openai.Embedding.create(input=text, model=EMBEDDING_MODEL)\n", + " return response\n", + "\n", + "\n", + "def get_articles(query, library=paper_dir_filepath, top_k=5):\n", + " \"\"\"This function gets the top_k articles based on a user's query, sorted by relevance.\n", + " It also downloads the files and stores them in arxiv_library.csv to be retrieved by the read_article_and_summarize.\n", + " \"\"\"\n", + " search = arxiv.Search(\n", + " query=query, max_results=top_k, sort_by=arxiv.SortCriterion.Relevance\n", + " )\n", + " result_list = []\n", + " for result in search.results():\n", + " result_dict = {}\n", + " result_dict.update({\"title\": result.title})\n", + " result_dict.update({\"summary\": result.summary})\n", + "\n", + " # Taking the first url provided\n", + " result_dict.update({\"article_url\": [x.href for x in result.links][0]})\n", + " result_dict.update({\"pdf_url\": [x.href for x in result.links][1]})\n", + " result_list.append(result_dict)\n", + "\n", + " # Store references in library file\n", + " response = embedding_request(text=result.title)\n", + " file_reference = [\n", + " result.title,\n", + " result.download_pdf(data_dir),\n", + " response[\"data\"][0][\"embedding\"],\n", + " ]\n", + "\n", + " # Write to file\n", + " with open(library, \"a\") as f_object:\n", + " writer_object = writer(f_object)\n", + " writer_object.writerow(file_reference)\n", + " f_object.close()\n", + " return result_list\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "dda02bdb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'title': 'Proximal Policy Optimization and its Dynamic Version for Sequence Generation',\n", + " 'summary': 'In sequence generation task, many works use policy gradient for model\\noptimization to tackle the intractable backpropagation issue when maximizing\\nthe non-differentiable evaluation metrics or fooling the discriminator in\\nadversarial learning. In this paper, we replace policy gradient with proximal\\npolicy optimization (PPO), which is a proved more efficient reinforcement\\nlearning algorithm, and propose a dynamic approach for PPO (PPO-dynamic). We\\ndemonstrate the efficacy of PPO and PPO-dynamic on conditional sequence\\ngeneration tasks including synthetic experiment and chit-chat chatbot. The\\nresults show that PPO and PPO-dynamic can beat policy gradient by stability and\\nperformance.',\n", + " 'article_url': 'http://arxiv.org/abs/1808.07982v1',\n", + " 'pdf_url': 'http://arxiv.org/pdf/1808.07982v1'}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Test that the search is working\n", + "result_output = get_articles(\"ppo reinforcement learning\")\n", + "result_output[0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "11675627", + "metadata": {}, + "outputs": [], + "source": [ + "def strings_ranked_by_relatedness(\n", + " query: str,\n", + " df: pd.DataFrame,\n", + " relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),\n", + " top_n: int = 100,\n", + ") -> list[str]:\n", + " \"\"\"Returns a list of strings and relatednesses, sorted from most related to least.\"\"\"\n", + " query_embedding_response = embedding_request(query)\n", + " query_embedding = query_embedding_response[\"data\"][0][\"embedding\"]\n", + " strings_and_relatednesses = [\n", + " (row[\"filepath\"], relatedness_fn(query_embedding, row[\"embedding\"]))\n", + " for i, row in df.iterrows()\n", + " ]\n", + " strings_and_relatednesses.sort(key=lambda x: x[1], reverse=True)\n", + " strings, relatednesses = zip(*strings_and_relatednesses)\n", + " return strings[:top_n]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7211df2c", + "metadata": {}, + "outputs": [], + "source": [ + "def read_pdf(filepath):\n", + " \"\"\"Takes a filepath to a PDF and returns a string of the PDF's contents\"\"\"\n", + " # creating a pdf reader object\n", + " reader = PdfReader(filepath)\n", + " pdf_text = \"\"\n", + " page_number = 0\n", + " for page in reader.pages:\n", + " page_number += 1\n", + " pdf_text += page.extract_text() + f\"\\nPage Number: {page_number}\"\n", + " return pdf_text\n", + "\n", + "\n", + "# Split a text into smaller chunks of size n, preferably ending at the end of a sentence\n", + "def create_chunks(text, n, tokenizer):\n", + " \"\"\"Returns successive n-sized chunks from provided text.\"\"\"\n", + " tokens = tokenizer.encode(text)\n", + " i = 0\n", + " while i < len(tokens):\n", + " # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens\n", + " j = min(i + int(1.5 * n), len(tokens))\n", + " while j > i + int(0.5 * n):\n", + " # Decode the tokens and check for full stop or newline\n", + " chunk = tokenizer.decode(tokens[i:j])\n", + " if chunk.endswith(\".\") or chunk.endswith(\"\\n\"):\n", + " break\n", + " j -= 1\n", + " # If no end of sentence found, use n tokens as the chunk size\n", + " if j == i + int(0.5 * n):\n", + " j = min(i + n, len(tokens))\n", + " yield tokens[i:j]\n", + " i = j\n", + "\n", + "\n", + "def extract_chunk(content, template_prompt):\n", + " \"\"\"This function applies a prompt to some input content. In this case it returns a summarize chunk of text\"\"\"\n", + " prompt = template_prompt + content\n", + " response = openai.ChatCompletion.create(\n", + " model=GPT_MODEL, messages=[{\"role\": \"user\", \"content\": prompt}], temperature=0\n", + " )\n", + " return response[\"choices\"][0][\"message\"][\"content\"]\n", + "\n", + "\n", + "def summarize_text(query):\n", + " \"\"\"This function does the following:\n", + " - Reads in the arxiv_library.csv file in including the embeddings\n", + " - Finds the closest file to the user's query\n", + " - Scrapes the text out of the file and chunks it\n", + " - Summarizes each chunk in parallel\n", + " - Does one final summary and returns this to the user\"\"\"\n", + "\n", + " # A prompt to dictate how the recursive summarizations should approach the input paper\n", + " summary_prompt = \"\"\"Summarize this text from an academic paper. Extract any key points with reasoning.\\n\\nContent:\"\"\"\n", + "\n", + " # If the library is empty (no searches have been performed yet), we perform one and download the results\n", + " library_df = pd.read_csv(paper_dir_filepath).reset_index()\n", + " if len(library_df) == 0:\n", + " print(\"No papers searched yet, downloading first.\")\n", + " get_articles(query)\n", + " print(\"Papers downloaded, continuing\")\n", + " library_df = pd.read_csv(paper_dir_filepath).reset_index()\n", + " library_df.columns = [\"title\", \"filepath\", \"embedding\"]\n", + " library_df[\"embedding\"] = library_df[\"embedding\"].apply(ast.literal_eval)\n", + " strings = strings_ranked_by_relatedness(query, library_df, top_n=1)\n", + " print(\"Chunking text from paper\")\n", + " pdf_text = read_pdf(strings[0])\n", + "\n", + " # Initialise tokenizer\n", + " tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n", + " results = \"\"\n", + "\n", + " # Chunk up the document into 1500 token chunks\n", + " chunks = create_chunks(pdf_text, 1500, tokenizer)\n", + " text_chunks = [tokenizer.decode(chunk) for chunk in chunks]\n", + " print(\"Summarizing each chunk of text\")\n", + "\n", + " # Parallel process the summaries\n", + " with concurrent.futures.ThreadPoolExecutor(\n", + " max_workers=len(text_chunks)\n", + " ) as executor:\n", + " futures = [\n", + " executor.submit(extract_chunk, chunk, summary_prompt)\n", + " for chunk in text_chunks\n", + " ]\n", + " with tqdm(total=len(text_chunks)) as pbar:\n", + " for _ in concurrent.futures.as_completed(futures):\n", + " pbar.update(1)\n", + " for future in futures:\n", + " data = future.result()\n", + " results += data\n", + "\n", + " # Final summary\n", + " print(\"Summarizing into overall summary\")\n", + " response = openai.ChatCompletion.create(\n", + " model=GPT_MODEL,\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"\"\"Write a summary collated from this collection of key points extracted from an academic paper.\n", + " The summary should highlight the core argument, conclusions and evidence, and answer the user's query.\n", + " User query: {query}\n", + " The summary should be structured in bulleted lists following the headings Core Argument, Evidence, and Conclusions.\n", + " Key points:\\n{results}\\nSummary:\\n\"\"\",\n", + " }\n", + " ],\n", + " temperature=0,\n", + " )\n", + " return response\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "898b94d4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Chunking text from paper\n", + "Summarizing each chunk of text\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.19s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Summarizing into overall summary\n" + ] + } + ], + "source": [ + "# Test the summarize_text function works\n", + "chat_test_response = summarize_text(\"PPO reinforcement learning sequence generation\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c715f60d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Core Argument:\n", + "- The paper discusses the use of Proximal Policy Optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.\n", + "- The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, commonly used in text generation tasks.\n", + "- They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n", + "\n", + "Evidence:\n", + "- PPO-dynamic achieves high precision scores comparable to other algorithms in a synthetic counting task.\n", + "- PPO-dynamic shows faster progress and more stable learning curves compared to PPO in the synthetic counting task.\n", + "- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than other algorithms.\n", + "- PPO and PPO-dynamic have more stable learning curves and converge faster than policy gradient.\n", + "\n", + "Conclusions:\n", + "- PPO is a better optimization method for sequence learning compared to policy gradient.\n", + "- PPO-dynamic further improves the optimization process by dynamically adjusting hyperparameters.\n", + "- PPO can be used as a new optimization method for GAN-based sequence learning for better performance.\n" + ] + } + ], + "source": [ + "print(chat_test_response[\"choices\"][0][\"message\"][\"content\"])\n" + ] + }, + { + "cell_type": "markdown", + "id": "dab07e98", + "metadata": {}, + "source": [ + "## Configure Agent\n", + "\n", + "We'll create our agent in this step, including a ```Conversation``` class to support multiple turns with the API, and some Python functions to enable interaction between the ```ChatCompletion``` API and our knowledge base functions." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "77a6fb4f", + "metadata": {}, + "outputs": [], + "source": [ + "@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n", + "def chat_completion_request(messages, functions=None, model=GPT_MODEL):\n", + " headers = {\n", + " \"Content-Type\": \"application/json\",\n", + " \"Authorization\": \"Bearer \" + openai.api_key,\n", + " }\n", + " json_data = {\"model\": model, \"messages\": messages}\n", + " if functions is not None:\n", + " json_data.update({\"functions\": functions})\n", + " try:\n", + " response = requests.post(\n", + " \"https://api.openai.com/v1/chat/completions\",\n", + " headers=headers,\n", + " json=json_data,\n", + " )\n", + " return response\n", + " except Exception as e:\n", + " print(\"Unable to generate ChatCompletion response\")\n", + " print(f\"Exception: {e}\")\n", + " return e\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "73f7672d", + "metadata": {}, + "outputs": [], + "source": [ + "class Conversation:\n", + " def __init__(self):\n", + " self.conversation_history = []\n", + "\n", + " def add_message(self, role, content):\n", + " message = {\"role\": role, \"content\": content}\n", + " self.conversation_history.append(message)\n", + "\n", + " def display_conversation(self, detailed=False):\n", + " role_to_color = {\n", + " \"system\": \"red\",\n", + " \"user\": \"green\",\n", + " \"assistant\": \"blue\",\n", + " \"function\": \"magenta\",\n", + " }\n", + " for message in self.conversation_history:\n", + " print(\n", + " colored(\n", + " f\"{message['role']}: {message['content']}\\n\\n\",\n", + " role_to_color[message[\"role\"]],\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "978b7877", + "metadata": {}, + "outputs": [], + "source": [ + "# Initiate our get_articles and read_article_and_summarize functions\n", + "arxiv_functions = [\n", + " {\n", + " \"name\": \"get_articles\",\n", + " \"description\": \"\"\"Use this function to get academic papers from arXiv to answer user questions.\"\"\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"query\": {\n", + " \"type\": \"string\",\n", + " \"description\": f\"\"\"\n", + " User query in JSON. Responses should be summarized and should include the article URL reference\n", + " \"\"\",\n", + " }\n", + " },\n", + " \"required\": [\"query\"],\n", + " },\n", + " \"name\": \"read_article_and_summarize\",\n", + " \"description\": \"\"\"Use this function to read whole papers and provide a summary for users.\n", + " You should NEVER call this function before get_articles has been called in the conversation.\"\"\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"query\": {\n", + " \"type\": \"string\",\n", + " \"description\": f\"\"\"\n", + " Description of the article in plain text based on the user's query\n", + " \"\"\",\n", + " }\n", + " },\n", + " \"required\": [\"query\"],\n", + " },\n", + " }\n", + "]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0c88ae15", + "metadata": {}, + "outputs": [], + "source": [ + "def chat_completion_with_function_execution(messages, functions=[None]):\n", + " \"\"\"This function makes a ChatCompletion API call with the option of adding functions\"\"\"\n", + " response = chat_completion_request(messages, functions)\n", + " full_message = response.json()[\"choices\"][0]\n", + " if full_message[\"finish_reason\"] == \"function_call\":\n", + " print(f\"Function generation requested, calling function\")\n", + " return call_arxiv_function(messages, full_message)\n", + " else:\n", + " print(f\"Function not required, responding to user\")\n", + " return response.json()\n", + "\n", + "\n", + "def call_arxiv_function(messages, full_message):\n", + " \"\"\"Function calling function which executes function calls when the model believes it is necessary.\n", + " Currently extended by adding clauses to this if statement.\"\"\"\n", + "\n", + " if full_message[\"message\"][\"function_call\"][\"name\"] == \"get_articles\":\n", + " try:\n", + " parsed_output = json.loads(\n", + " full_message[\"message\"][\"function_call\"][\"arguments\"]\n", + " )\n", + " print(\"Getting search results\")\n", + " results = get_articles(parsed_output[\"query\"])\n", + " except Exception as e:\n", + " print(parsed_output)\n", + " print(f\"Function execution failed\")\n", + " print(f\"Error message: {e}\")\n", + " messages.append(\n", + " {\n", + " \"role\": \"function\",\n", + " \"name\": full_message[\"message\"][\"function_call\"][\"name\"],\n", + " \"content\": str(results),\n", + " }\n", + " )\n", + " try:\n", + " print(\"Got search results, summarizing content\")\n", + " response = chat_completion_request(messages)\n", + " return response.json()\n", + " except Exception as e:\n", + " print(type(e))\n", + " raise Exception(\"Function chat request failed\")\n", + "\n", + " elif (\n", + " full_message[\"message\"][\"function_call\"][\"name\"] == \"read_article_and_summarize\"\n", + " ):\n", + " parsed_output = json.loads(\n", + " full_message[\"message\"][\"function_call\"][\"arguments\"]\n", + " )\n", + " print(\"Finding and reading paper\")\n", + " summary = summarize_text(parsed_output[\"query\"])\n", + " return summary\n", + "\n", + " else:\n", + " raise Exception(\"Function does not exist and cannot be called\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "dd3e7868", + "metadata": {}, + "source": [ + "## arXiv conversation\n", + "\n", + "Let's put this all together by testing our functions out in conversation." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c39a1d80", + "metadata": {}, + "outputs": [], + "source": [ + "# Start with a system message\n", + "paper_system_message = \"\"\"You are arXivGPT, a helpful assistant pulls academic papers to answer user questions.\n", + "You summarize the papers clearly so the customer can decide which to read to answer their question.\n", + "You always provide the article_url and title so the user can understand the name of the paper and click through to access it.\n", + "Begin!\"\"\"\n", + "paper_conversation = Conversation()\n", + "paper_conversation.add_message(\"system\", paper_system_message)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "253fd0f7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Function generation requested, calling function\n", + "Finding and reading paper\n", + "Chunking text from paper\n", + "Summarizing each chunk of text\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [00:06<00:00, 2.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Summarizing into overall summary\n" + ] + }, + { + "data": { + "text/markdown": [ + "Core Argument:\n", + "- The paper focuses on the theoretical analysis of the PPO-Clip algorithm in the context of deep reinforcement learning.\n", + "- The authors propose two core ideas: reinterpreting PPO-Clip from the perspective of hinge loss and introducing a two-step policy improvement scheme.\n", + "- The paper establishes the global convergence of PPO-Clip and characterizes its convergence rate.\n", + "\n", + "Evidence:\n", + "- The paper addresses the challenges posed by the clipping mechanism and neural function approximation.\n", + "- The authors provide theoretical proofs, lemmas, and mathematical analysis to support their arguments.\n", + "- The paper presents empirical experiments on various reinforcement learning benchmark tasks to validate the effectiveness of PPO-Clip.\n", + "\n", + "Conclusions:\n", + "- The paper offers theoretical insights into the performance of PPO-Clip and provides a framework for analyzing its convergence properties.\n", + "- PPO-Clip is shown to have a global convergence rate of O(1/sqrt(T)), where T is the number of iterations.\n", + "- The hinge loss reinterpretation of PPO-Clip allows for variants with comparable empirical performance.\n", + "- The paper contributes to a better understanding of PPO-Clip in the reinforcement learning community." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Add a user message\n", + "paper_conversation.add_message(\"user\", \"Hi, how does PPO reinforcement learning work?\")\n", + "chat_response = chat_completion_with_function_execution(\n", + " paper_conversation.conversation_history, functions=arxiv_functions\n", + ")\n", + "assistant_message = chat_response[\"choices\"][0][\"message\"][\"content\"]\n", + "paper_conversation.add_message(\"assistant\", assistant_message)\n", + "display(Markdown(assistant_message))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3ca3e18a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Function generation requested, calling function\n", + "Finding and reading paper\n", + "Chunking text from paper\n", + "Summarizing each chunk of text\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.08s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Summarizing into overall summary\n" + ] + }, + { + "data": { + "text/markdown": [ + "Core Argument:\n", + "- The paper discusses the use of proximal policy optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.\n", + "- The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, which is commonly used in text generation tasks.\n", + "- They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n", + "\n", + "Evidence:\n", + "- The authors derive the constraints for PPO-dynamic and provide the pseudo code for both PPO and PPO-dynamic.\n", + "- They compare the performance of PPO-dynamic with other algorithms, including REINFORCE, MIXER, and SeqGAN, on a synthetic counting task and a chit-chat chatbot task using the OpenSubtitles dataset.\n", + "- In the synthetic counting task, PPO-dynamic achieves a high precision score comparable to REINFORCE and MIXER, with a faster learning curve compared to PPO.\n", + "- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than REINFORCE and PPO, with a more stable and faster learning curve than policy gradient.\n", + "\n", + "Conclusions:\n", + "- The results suggest that PPO is a better optimization method for sequence learning compared to policy gradient.\n", + "- PPO-dynamic further improves the optimization process by dynamically adjusting the hyperparameters.\n", + "- The authors conclude that PPO can be used as a new optimization method for GAN-based sequence learning for better performance." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Add another user message to induce our system to use the second tool\n", + "paper_conversation.add_message(\n", + " \"user\",\n", + " \"Can you read the PPO sequence generation paper for me and give me a summary\",\n", + ")\n", + "updated_response = chat_completion_with_function_execution(\n", + " paper_conversation.conversation_history, functions=arxiv_functions\n", + ")\n", + "display(Markdown(updated_response[\"choices\"][0][\"message\"][\"content\"]))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tua_test", + "language": "python", + "name": "tua_test" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/How_to_call_functions_with_chat_models.ipynb b/examples/How_to_call_functions_with_chat_models.ipynb index 0788a081..963bcd7e 100644 --- a/examples/How_to_call_functions_with_chat_models.ipynb +++ b/examples/How_to_call_functions_with_chat_models.ipynb @@ -9,28 +9,24 @@ "\n", "This notebook covers how to use the Chat Completions API in combination with external functions to extend the capabilities of GPT models.\n", "\n", - "## How to use functions\n", + "`functions` is an optional parameter in the Chat Completion API which can be used to provide function specifications. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications. Note that the API will not actually execute any function calls. It is up to developers to execute function calls using model outputs.\n", "\n", - "`functions` is an optional parameter in the ChatCompletion API which can be used to provide function specifications. The purpose of this is to enable models to generate outputs which adhere to function input schemas. Note that the API will not actually execute any function calls. It is up to developers to execute function calls using model outputs.\n", + "If the `functions` parameter is provided then by default the model will decide when it is appropriate to use one of the functions. The API can be forced to use a specific function by setting the `function_call` parameter to `{\"name\": \"\"}`. The API can also be forced to not use any function by setting the `function_call` parameter to `\"none\"`. If a function is used, the output will contain `\"finish_reason\": \"function_message\"` in the response, as well as a `function_call` object that has the name of the function and the generated function arguments.\n", "\n", - "If the `functions` parameter is provided then by default the model will decide when it is appropriate to use one of the functions. The API can also be forced to use a specific function by setting the `function_call` parameter to `{\"name\": \"\"}`. If a function is used, the output will contain `\"finish_reason\": \"function_call\"` in the response, as well as a `function_call` object that has the name of the function and the generated function arguments.\n", + "### Overview\n", "\n", - "Functions are specified with the following fields:\n", + "This notebook contains the following 2 sections:\n", "\n", - "- **Name:** The name of the function.\n", - "- **Description:** A description of what the function does. The model will use this to decide when to call the function.\n", - "- **Parameters:** The parameters object contains all of the input fields the function requires. These inputs can be of the following types: String, Number, Boolean, Object, Null, AnyOf. Refer to the [API reference docs](https://platform.openai.com/docs/api-reference/chat) for details.\n", - "- **Required:** Which of the parameters are required to make a query. The rest will be treated as optional.\n", - "\n", - "You can chain function calls by executing the function and passing the output of the function execution directly back to the assistant. This can lead to _infinite loop_ behaviour where the model continues calling functions indefinitely, however guardrails can be put in place to prevent this.\n", - "\n", - "## Walkthrough\n", - "\n", - "This cookbook takes you through the following workflow:\n", - "\n", - "- **Basic concepts:** Creating an example function and getting the API to use it if appropriate.\n", - "- **Integrating API calls with function execution:** Creating an agent that uses API calls to generate function arguments and then executes the function.\n", - "- **Using multiple functions:** Allowing multiple functions to be called in sequence before responding to the user.\n" + "- **How to generate function arguments:** Specify a set of functions and use the API to generate function arguments.\n", + "- **How to call functions with model generated arguments:** Close the loop by actually executing functions with model generated arguments." + ] + }, + { + "cell_type": "markdown", + "id": "64c85e26", + "metadata": {}, + "source": [ + "## How to generate function arguments" ] }, { @@ -47,45 +43,36 @@ "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: scipy in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (1.10.1)\n", - "Requirement already satisfied: numpy<1.27.0,>=1.19.5 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from scipy) (1.24.2)\n", - "Requirement already satisfied: tenacity in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (8.2.2)\n", - "Requirement already satisfied: tiktoken in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (0.4.0)\n", - "Requirement already satisfied: regex>=2022.1.18 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from tiktoken) (2023.6.3)\n", - "Requirement already satisfied: requests>=2.26.0 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from tiktoken) (2.28.2)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (2.1.1)\n", - "Requirement already satisfied: idna<4,>=2.5 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (3.4)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (1.26.14)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (2022.12.7)\n", - "Requirement already satisfied: termcolor in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (2.3.0)\n", - "Requirement already satisfied: openai in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (0.27.7)\n", - "Requirement already satisfied: requests>=2.20 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from openai) (2.28.2)\n", - "Requirement already satisfied: tqdm in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from openai) (4.64.1)\n", - "Requirement already satisfied: aiohttp in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from openai) (3.8.3)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests>=2.20->openai) (2.1.1)\n", - "Requirement already satisfied: idna<4,>=2.5 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests>=2.20->openai) (3.4)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests>=2.20->openai) (1.26.14)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests>=2.20->openai) (2022.12.7)\n", - "Requirement already satisfied: attrs>=17.3.0 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from aiohttp->openai) (22.2.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from aiohttp->openai) (6.0.4)\n", - "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from aiohttp->openai) (4.0.2)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from aiohttp->openai) (1.8.2)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from aiohttp->openai) (1.3.3)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from aiohttp->openai) (1.3.1)\n", - "Requirement already satisfied: requests in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (2.28.2)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests) (2.1.1)\n", - "Requirement already satisfied: idna<4,>=2.5 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests) (3.4)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests) (1.26.14)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from requests) (2022.12.7)\n", - "Requirement already satisfied: arxiv in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (1.4.7)\n", - "Requirement already satisfied: feedparser in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from arxiv) (6.0.10)\n", - "Requirement already satisfied: sgmllib3k in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from feedparser->arxiv) (1.0.0)\n", - "Requirement already satisfied: pandas in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (1.5.3)\n", - "Requirement already satisfied: python-dateutil>=2.8.1 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from pandas) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from pandas) (2022.7.1)\n", - "Requirement already satisfied: numpy>=1.21.0 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from pandas) (1.24.2)\n", - "Requirement already satisfied: six>=1.5 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)\n", - "Requirement already satisfied: PyPDF2 in /Users/colin.jarvis/Documents/dev/openai_scratchpad/openai_test/lib/python3.10/site-packages (3.0.1)\n" + "Requirement already satisfied: scipy in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (1.10.1)\n", + "Requirement already satisfied: numpy<1.27.0,>=1.19.5 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from scipy) (1.24.3)\n", + "Requirement already satisfied: tenacity in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (8.2.2)\n", + "Requirement already satisfied: tiktoken in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (0.4.0)\n", + "Requirement already satisfied: regex>=2022.1.18 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from tiktoken) (2023.6.3)\n", + "Requirement already satisfied: requests>=2.26.0 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from tiktoken) (2.31.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (3.1.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2.0.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2023.5.7)\n", + "Requirement already satisfied: termcolor in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (2.3.0)\n", + "Requirement already satisfied: openai in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (0.27.8)\n", + "Requirement already satisfied: requests>=2.20 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from openai) (2.31.0)\n", + "Requirement already satisfied: tqdm in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from openai) (4.65.0)\n", + "Requirement already satisfied: aiohttp in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from openai) (3.8.4)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests>=2.20->openai) (3.1.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests>=2.20->openai) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests>=2.20->openai) (2.0.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests>=2.20->openai) (2023.5.7)\n", + "Requirement already satisfied: attrs>=17.3.0 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from aiohttp->openai) (23.1.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from aiohttp->openai) (6.0.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from aiohttp->openai) (4.0.2)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from aiohttp->openai) (1.9.2)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from aiohttp->openai) (1.3.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from aiohttp->openai) (1.3.1)\n", + "Requirement already satisfied: requests in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (2.31.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests) (3.1.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests) (2.0.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/joe/.virtualenvs/openai-cookbook/lib/python3.9/site-packages (from requests) (2023.5.7)\n" ] } ], @@ -95,10 +82,7 @@ "!pip install tiktoken\n", "!pip install termcolor \n", "!pip install openai\n", - "!pip install requests\n", - "!pip install arxiv\n", - "!pip install pandas\n", - "!pip install PyPDF2" + "!pip install requests" ] }, { @@ -108,25 +92,13 @@ "metadata": {}, "outputs": [], "source": [ - "import arxiv\n", - "import ast\n", - "import concurrent\n", - "from csv import writer\n", - "from IPython.display import display, Markdown, Latex\n", "import json\n", "import openai\n", - "import os\n", - "import pandas as pd\n", - "from PyPDF2 import PdfReader\n", "import requests\n", - "from scipy import spatial\n", "from tenacity import retry, wait_random_exponential, stop_after_attempt\n", - "import tiktoken\n", - "from tqdm import tqdm\n", "from termcolor import colored\n", "\n", - "GPT_MODEL = \"gpt-3.5-turbo-0613\"\n", - "EMBEDDING_MODEL = \"text-embedding-ada-002\"\n" + "GPT_MODEL = \"gpt-3.5-turbo-0613\"" ] }, { @@ -134,7 +106,7 @@ "id": "69ee6a93", "metadata": {}, "source": [ - "## Utilities\n", + "### Utilities\n", "\n", "First let's define a few utilities for making calls to the Chat Completions API and for maintaining and keeping track of the conversation state." ] @@ -147,7 +119,7 @@ "outputs": [], "source": [ "@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n", - "def chat_completion_request(messages, functions=None, model=GPT_MODEL):\n", + "def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL):\n", " headers = {\n", " \"Content-Type\": \"application/json\",\n", " \"Authorization\": \"Bearer \" + openai.api_key,\n", @@ -155,6 +127,8 @@ " json_data = {\"model\": model, \"messages\": messages}\n", " if functions is not None:\n", " json_data.update({\"functions\": functions})\n", + " if function_call is not None:\n", + " json_data.update({\"function_call\": function_call})\n", " try:\n", " response = requests.post(\n", " \"https://api.openai.com/v1/chat/completions\",\n", @@ -175,28 +149,32 @@ "metadata": {}, "outputs": [], "source": [ - "class Conversation:\n", - " def __init__(self):\n", - " self.conversation_history = []\n", - "\n", - " def add_message(self, role, content):\n", - " message = {\"role\": role, \"content\": content}\n", - " self.conversation_history.append(message)\n", - "\n", - " def display_conversation(self, detailed=False):\n", - " role_to_color = {\n", - " \"system\": \"red\",\n", - " \"user\": \"green\",\n", - " \"assistant\": \"blue\",\n", - " \"function\": \"magenta\",\n", - " }\n", - " for message in self.conversation_history:\n", - " print(\n", - " colored(\n", - " f\"{message['role']}: {message['content']}\\n\\n\",\n", - " role_to_color[message[\"role\"]],\n", - " )\n", - " )" + "def pretty_print_conversation(messages):\n", + " role_to_color = {\n", + " \"system\": \"red\",\n", + " \"user\": \"green\",\n", + " \"assistant\": \"blue\",\n", + " \"function\": \"magenta\",\n", + " }\n", + " formatted_messages = []\n", + " for message in messages:\n", + " if message[\"role\"] == \"system\":\n", + " formatted_messages.append(f\"system: {message['content']}\\n\")\n", + " elif message[\"role\"] == \"user\":\n", + " formatted_messages.append(f\"user: {message['content']}\\n\")\n", + " elif message[\"role\"] == \"assistant\" and message.get(\"function_call\"):\n", + " formatted_messages.append(f\"assistant: {message['function_call']}\\n\")\n", + " elif message[\"role\"] == \"assistant\" and not message.get(\"function_call\"):\n", + " formatted_messages.append(f\"assistant: {message['content']}\\n\")\n", + " elif message[\"role\"] == \"function\":\n", + " formatted_messages.append(f\"function ({message['name']}): {message['content']}\\n\")\n", + " for formatted_message in formatted_messages:\n", + " print(\n", + " colored(\n", + " formatted_message,\n", + " role_to_color[messages[formatted_messages.index(formatted_message)][\"role\"]],\n", + " )\n", + " )" ] }, { @@ -204,9 +182,9 @@ "id": "29d4e02b", "metadata": {}, "source": [ - "## Basic concepts\n", + "### Basic concepts\n", "\n", - "Next we'll create a specification for a function called ```get_current_weather```. Later we'll pass this function specification to the API in order to generate function arguments that adhere to the specification." + "Let's create some function specifications to interface with a hypothetical weather API. We'll pass these function specification to the Chat Completions API in order to generate function arguments that adhere to the specification." ] }, { @@ -235,32 +213,92 @@ " },\n", " \"required\": [\"location\", \"format\"],\n", " },\n", - " }\n", - "]\n" + " },\n", + " {\n", + " \"name\": \"get_n_day_weather_forecast\",\n", + " \"description\": \"Get an N-day weather forecast\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", + " },\n", + " \"format\": {\n", + " \"type\": \"string\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", + " \"description\": \"The temperature unit to use. Infer this from the users location.\",\n", + " },\n", + " \"num_days\": {\n", + " \"type\": \"integer\",\n", + " \"description\": \"The number of days to forecast\",\n", + " }\n", + " },\n", + " \"required\": [\"location\", \"format\", \"num_days\"]\n", + " },\n", + " },\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "bfc39899", + "metadata": {}, + "source": [ + "If we prompt the model about the current weather, it will respond with some clarifying questions." ] }, { "cell_type": "code", "execution_count": 6, - "id": "131d9a33", - "metadata": {}, - "outputs": [], - "source": [ - "conversation = Conversation()\n", - "conversation.add_message(\"user\", \"what is the weather like today\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "12ed2515", + "id": "518d6827", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'role': 'assistant',\n", - " 'content': 'Sure, can you please provide me with your location or the city you want to know the weather for?'}" + " 'content': \"Sure, could you please tell me the city and state for which you'd like to know the weather?\"}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = []\n", + "messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n", + "messages.append({\"role\": \"user\", \"content\": \"What's the weather like today\"})\n", + "chat_response = chat_completion_request(\n", + " messages, functions=functions\n", + ")\n", + "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n", + "messages.append(assistant_message)\n", + "assistant_message\n" + ] + }, + { + "cell_type": "markdown", + "id": "4c999375", + "metadata": {}, + "source": [ + "Once we provide the missing information, it will generate the appropriate function arguments for us." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "23c42a6e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'role': 'assistant',\n", + " 'content': None,\n", + " 'function_call': {'name': 'get_current_weather',\n", + " 'arguments': '{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\"\\n}'}}" ] }, "execution_count": 7, @@ -269,30 +307,34 @@ } ], "source": [ - "# The model first prompts the user for the information it needs to use the weather function\n", + "messages.append({\"role\": \"user\", \"content\": \"I'm in Glasgow, Scotland.\"})\n", "chat_response = chat_completion_request(\n", - " conversation.conversation_history, functions=functions\n", + " messages, functions=functions\n", ")\n", "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n", - "conversation.add_message(assistant_message[\"role\"], assistant_message[\"content\"])\n", + "messages.append(assistant_message)\n", "assistant_message\n" ] }, + { + "cell_type": "markdown", + "id": "c14d4762", + "metadata": {}, + "source": [ + "By prompting it differently, we can get it to target the other function we've told it about." + ] + }, { "cell_type": "code", "execution_count": 8, - "id": "854b6e61", + "id": "fa232e54", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'index': 0,\n", - " 'message': {'role': 'assistant',\n", - " 'content': None,\n", - " 'function_call': {'name': 'get_current_weather',\n", - " 'arguments': '{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\"\\n}'}},\n", - " 'finish_reason': 'function_call'}" + "{'role': 'assistant',\n", + " 'content': 'Sure, I can help you with that. Just tell me how many days you would like to forecast for.'}" ] }, "execution_count": 8, @@ -301,24 +343,179 @@ } ], "source": [ - "# Once the user provides the required information, the model can generate the function arguments\n", - "conversation.add_message(\"user\", \"I'm in Glasgow, Scotland\")\n", + "messages = []\n", + "messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n", + "messages.append({\"role\": \"user\", \"content\": \"what is the weather going to be like in Glasgow, Scotland over the next x days\"})\n", "chat_response = chat_completion_request(\n", - " conversation.conversation_history, functions=functions\n", + " messages, functions=functions\n", + ")\n", + "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n", + "messages.append(assistant_message)\n", + "assistant_message\n" + ] + }, + { + "cell_type": "markdown", + "id": "6172ddac", + "metadata": {}, + "source": [ + "Once again, the model is asking us for clarification because it doesn't have enough information yet. In this case it already knows the location for the forecast, but it needs to know how many days are required in the forecast." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c7d8a543", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'index': 0,\n", + " 'message': {'role': 'assistant',\n", + " 'content': None,\n", + " 'function_call': {'name': 'get_n_day_weather_forecast',\n", + " 'arguments': '{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\",\\n \"num_days\": 5\\n}'}},\n", + " 'finish_reason': 'function_call'}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages.append({\"role\": \"user\", \"content\": \"5 days\"})\n", + "chat_response = chat_completion_request(\n", + " messages, functions=functions\n", ")\n", "chat_response.json()[\"choices\"][0]\n" ] }, + { + "cell_type": "markdown", + "id": "4b758a0a", + "metadata": {}, + "source": [ + "#### Forcing the use of specific functions or no function" + ] + }, + { + "cell_type": "markdown", + "id": "412f79ba", + "metadata": {}, + "source": [ + "We can force the model to use a specific function, for example get_n_day_weather_forecast by using the function_call argument. By doing so, we force the model to make assumptions about how to use it." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "559371b7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'role': 'assistant',\n", + " 'content': None,\n", + " 'function_call': {'name': 'get_n_day_weather_forecast',\n", + " 'arguments': '{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\",\\n \"num_days\": 1\\n}'}}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# in this cell we force the model to use get_n_day_weather_forecast\n", + "messages = []\n", + "messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n", + "messages.append({\"role\": \"user\", \"content\": \"Give me a weather report for Toronto, Canada.\"})\n", + "chat_response = chat_completion_request(\n", + " messages, functions=functions, function_call={\"name\": \"get_n_day_weather_forecast\"}\n", + ")\n", + "chat_response.json()[\"choices\"][0][\"message\"]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a7ab0f58", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'role': 'assistant',\n", + " 'content': None,\n", + " 'function_call': {'name': 'get_current_weather',\n", + " 'arguments': '{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}'}}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# if we don't force the model to use get_n_day_weather_forecast it may not\n", + "messages = []\n", + "messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n", + "messages.append({\"role\": \"user\", \"content\": \"Give me a weather report for Toronto, Canada.\"})\n", + "chat_response = chat_completion_request(\n", + " messages, functions=functions\n", + ")\n", + "chat_response.json()[\"choices\"][0][\"message\"]\n" + ] + }, + { + "cell_type": "markdown", + "id": "3bd70e48", + "metadata": {}, + "source": [ + "We can also force the model to not use a function at all. By doing so we prevent it from producing a proper function call." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "acfe54e6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'role': 'assistant',\n", + " 'content': '{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}'}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = []\n", + "messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n", + "messages.append({\"role\": \"user\", \"content\": \"Give me the current weather (use Celcius) for Toronto, Canada.\"})\n", + "chat_response = chat_completion_request(\n", + " messages, functions=functions, function_call=\"none\"\n", + ")\n", + "chat_response.json()[\"choices\"][0][\"message\"]\n" + ] + }, { "cell_type": "markdown", "id": "b4482aee", "metadata": {}, "source": [ - "## Integrating API calls with function execution\n", + "## How to call functions with model generated arguments\n", "\n", "In our next example, we'll demonstrate how to execute functions whose inputs are model-generated, and use this to implement an agent that can answer questions for us about a database. For simplicity we'll use the [Chinook sample database](https://www.sqlitetutorial.net/sqlite-sample-database/).\n", "\n", - "*Note:* SQL generation use cases are high-risk in a production environment - models can be unreliable when generating consistent SQL syntax. A more reliable way to solve this problem may be to build a query generation API that takes the desired columns as input from the model." + "*Note:* SQL generation can be high-risk in a production environment since models are not perfectly reliable at generating correct SQL." ] }, { @@ -326,14 +523,14 @@ "id": "f7654fef", "metadata": {}, "source": [ - "### Pull SQL Database Info\n", + "### Specifying a function to execute SQL queries\n", "\n", "First let's define some helpful utility functions to extract data from a SQLite database." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 13, "id": "30f6b60e", "metadata": {}, "outputs": [ @@ -349,12 +546,12 @@ "import sqlite3\n", "\n", "conn = sqlite3.connect(\"data/Chinook.db\")\n", - "print(\"Opened database successfully\")\n" + "print(\"Opened database successfully\")" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "id": "abec0214", "metadata": {}, "outputs": [], @@ -396,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "id": "0c0104cd", "metadata": {}, "outputs": [], @@ -407,7 +604,7 @@ " f\"Table: {table['table_name']}\\nColumns: {', '.join(table['column_names'])}\"\n", " for table in database_schema_dict\n", " ]\n", - ")\n" + ")" ] }, { @@ -420,7 +617,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "id": "0258813a", "metadata": {}, "outputs": [], @@ -428,7 +625,7 @@ "functions = [\n", " {\n", " \"name\": \"ask_database\",\n", - " \"description\": \"Use this function to answer user questions about music. Input should be a fully formed SQL query.\",\n", + " \"description\": \"Use this function to answer user questions about music. Output should be a fully formed SQL query.\",\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", @@ -445,7 +642,7 @@ " \"required\": [\"query\"],\n", " },\n", " }\n", - "]\n" + "]" ] }, { @@ -453,903 +650,126 @@ "id": "da08c121", "metadata": {}, "source": [ - "### SQL execution\n", + "### Executing SQL queries\n", "\n", - "Now let's implement the function that the agent will use to query the database. We also need to implement utilities to integrate the calls to the Chat Completions API with the function it is calling." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "65585e74", - "metadata": {}, - "outputs": [], - "source": [ - "def ask_database(conn, query):\n", - " \"\"\"Function to query SQLite database with provided SQL query.\"\"\"\n", - " try:\n", - " results = conn.execute(query).fetchall()\n", - " return results\n", - " except Exception as e:\n", - " raise Exception(f\"SQL error: {e}\")\n", - "\n", - "\n", - "def chat_completion_with_function_execution(messages, functions=None):\n", - " \"\"\"This function makes a ChatCompletion API call and if a function call is requested, executes the function\"\"\"\n", - " try:\n", - " response = chat_completion_request(messages, functions)\n", - " full_message = response.json()[\"choices\"][0]\n", - " if full_message[\"finish_reason\"] == \"function_call\":\n", - " print(f\"Function generation requested, calling function\")\n", - " return call_function(messages, full_message)\n", - " else:\n", - " print(f\"Function not required, responding to user\")\n", - " return response.json()\n", - " except Exception as e:\n", - " print(\"Unable to generate ChatCompletion response\")\n", - " print(f\"Exception: {e}\")\n", - " return response\n", - "\n", - "\n", - "def call_function(messages, full_message):\n", - " \"\"\"Executes function calls using model generated function arguments.\"\"\"\n", - "\n", - " # We'll add our one function here - this can be extended with any additional functions\n", - " if full_message[\"message\"][\"function_call\"][\"name\"] == \"ask_database\":\n", - " query = eval(full_message[\"message\"][\"function_call\"][\"arguments\"])\n", - " print(f\"Prepped query is {query}\")\n", - " try:\n", - " results = ask_database(conn, query[\"query\"])\n", - " except Exception as e:\n", - " print(e)\n", - "\n", - " # This following block tries to fix any issues in query generation with a subsequent call\n", - " messages.append(\n", - " {\n", - " \"role\": \"system\",\n", - " \"content\": f\"\"\"Query: {query['query']}\n", - "The previous query received the error {e}. \n", - "Please return a fixed SQL query in plain text.\n", - "Your response should consist of ONLY the SQL query with the separator sql_start at the beginning and sql_end at the end\"\"\",\n", - " }\n", - " )\n", - " response = chat_completion_request(messages, model=\"gpt-4-0613\")\n", - "\n", - " # Retrying with the fixed SQL query. If it fails a second time we exit.\n", - " try:\n", - " cleaned_query = response.json()[\"choices\"][0][\"message\"][\n", - " \"content\"\n", - " ].split(\"sql_start\")[1]\n", - " cleaned_query = cleaned_query.split(\"sql_end\")[0]\n", - " print(cleaned_query)\n", - " results = ask_database(conn, cleaned_query)\n", - " print(results)\n", - " print(\"Got on second try\")\n", - "\n", - " except Exception as e:\n", - " print(\"Second failure, exiting\")\n", - "\n", - " print(f\"Function execution failed\")\n", - " print(f\"Error message: {e}\")\n", - "\n", - " messages.append(\n", - " {\"role\": \"function\", \"name\": \"ask_database\", \"content\": str(results)}\n", - " )\n", - "\n", - " try:\n", - " response = chat_completion_request(messages)\n", - " return response.json()\n", - " except Exception as e:\n", - " print(type(e))\n", - " print(e)\n", - " raise Exception(\"Function chat request failed\")\n", - " else:\n", - " raise Exception(\"Function does not exist and cannot be called\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "38c55083", - "metadata": {}, - "outputs": [], - "source": [ - "agent_system_message = \"\"\"You are ChinookGPT, a helpful assistant who gets answers to user questions from the Chinook Music Database.\n", - "Provide as many details as possible to your users\n", - "Begin!\"\"\"\n", - "\n", - "sql_conversation = Conversation()\n", - "sql_conversation.add_message(\"system\", agent_system_message)\n", - "sql_conversation.add_message(\n", - " \"user\", \"Hi, who are the top 5 artists by number of tracks\"\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "a2e5338e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Function generation requested, calling function\n", - "Prepped query is {'query': 'SELECT a.Name AS Artist, COUNT(t.TrackId) AS NumTracks FROM Artist a JOIN Album al ON a.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId GROUP BY a.Name ORDER BY NumTracks DESC LIMIT 5'}\n", - "The top 5 artists in the Chinook Music Database based on the number of tracks they have are:\n", - "\n", - "1. Iron Maiden - 213 tracks\n", - "2. U2 - 135 tracks\n", - "3. Led Zeppelin - 114 tracks\n", - "4. Metallica - 112 tracks\n", - "5. Lost - 92 tracks\n" - ] - } - ], - "source": [ - "chat_response = chat_completion_with_function_execution(\n", - " sql_conversation.conversation_history, functions=functions\n", - ")\n", - "try:\n", - " assistant_message = chat_response[\"choices\"][0][\"message\"][\"content\"]\n", - " print(assistant_message)\n", - "except Exception as e:\n", - " print(e)\n", - " print(chat_response)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "28471ddc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[31msystem: You are ChinookGPT, a helpful assistant who gets answers to user questions from the Chinook Music Database.\n", - "Provide as many details as possible to your users\n", - "Begin!\n", - "\n", - "\u001b[0m\n", - "\u001b[32muser: Hi, who are the top 5 artists by number of tracks\n", - "\n", - "\u001b[0m\n", - "\u001b[35mfunction: [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n", - "\n", - "\u001b[0m\n", - "\u001b[34massistant: The top 5 artists in the Chinook Music Database based on the number of tracks they have are:\n", - "\n", - "1. Iron Maiden - 213 tracks\n", - "2. U2 - 135 tracks\n", - "3. Led Zeppelin - 114 tracks\n", - "4. Metallica - 112 tracks\n", - "5. Lost - 92 tracks\n", - "\n", - "\u001b[0m\n" - ] - } - ], - "source": [ - "sql_conversation.add_message(\"assistant\", assistant_message)\n", - "sql_conversation.display_conversation(detailed=True)\n" + "Now let's implement the function that will actually excute queries against the database." ] }, { "cell_type": "code", "execution_count": 17, - "id": "710481dc", + "id": "65585e74", "metadata": {}, "outputs": [], "source": [ - "sql_conversation.add_message(\n", - " \"user\", \"What is the name of the album with the most tracks\"\n", - ")\n" + "def ask_database(conn, query):\n", + " \"\"\"Function to query SQLite database with a provided SQL query.\"\"\"\n", + " try:\n", + " results = str(conn.execute(query).fetchall())\n", + " except Exception as e:\n", + " results = f\"query failed with error: {e}\"\n", + " return results\n", + "\n", + "def execute_function_call(message):\n", + " if message[\"function_call\"][\"name\"] == \"ask_database\":\n", + " query = eval(message[\"function_call\"][\"arguments\"])[\"query\"]\n", + " results = ask_database(conn, query)\n", + " else:\n", + " results = f\"Error: function {message['function_call']['name']} does not exist\"\n", + " return results" ] }, { "cell_type": "code", "execution_count": 18, - "id": "40f954e2", + "id": "38c55083", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Function generation requested, calling function\n", - "Prepped query is {'query': 'SELECT AlbumId, Title, COUNT(TrackId) AS TrackCount FROM Album GROUP BY AlbumId ORDER BY TrackCount DESC LIMIT 1;'}\n", - "SQL error: no such column: TrackId\n", - "\n", - "SELECT a.Title, COUNT(t.TrackId) as TrackCount\n", - "FROM Album a\n", - "JOIN Track t ON a.AlbumId = t.AlbumId\n", - "GROUP BY a.AlbumId, a.Title\n", - "ORDER BY TrackCount DESC\n", - "LIMIT 1;\n", - "\n", - "[('Greatest Hits', 57)]\n", - "Got on second try\n" + "\u001B[31msystem: Answer user questions by generating SQL queries against the Chinook Music Database.\n", + "\u001B[0m\n", + "\u001B[32muser: Hi, who are the top 5 artists by number of tracks?\n", + "\u001B[0m\n", + "\u001B[34massistant: {'name': 'ask_database', 'arguments': '{\\n \"query\": \"SELECT Artist.Name, COUNT(Track.TrackId) AS TrackCount FROM Artist INNER JOIN Album ON Artist.ArtistId = Album.ArtistId INNER JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.Name ORDER BY TrackCount DESC LIMIT 5;\"\\n}'}\n", + "\u001B[0m\n", + "\u001B[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n", + "\u001B[0m\n" ] - }, - { - "data": { - "text/plain": [ - "'The album with the most tracks in the Chinook Music Database is \"Greatest Hits\" with a total of 57 tracks.'" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "chat_response = chat_completion_with_function_execution(\n", - " sql_conversation.conversation_history, functions=functions\n", - ")\n", - "assistant_message = chat_response[\"choices\"][0][\"message\"][\"content\"]\n", - "assistant_message\n" + "messages = []\n", + "messages.append({\"role\": \"system\", \"content\": \"Answer user questions by generating SQL queries against the Chinook Music Database.\"})\n", + "messages.append({\"role\": \"user\", \"content\": \"Hi, who are the top 5 artists by number of tracks?\"})\n", + "chat_response = chat_completion_request(messages, functions)\n", + "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n", + "messages.append(assistant_message)\n", + "if assistant_message.get(\"function_call\"):\n", + " results = execute_function_call(assistant_message)\n", + " messages.append({\"role\": \"function\", \"name\": assistant_message[\"function_call\"][\"name\"], \"content\": results})\n", + "pretty_print_conversation(messages)" ] }, { "cell_type": "code", "execution_count": 19, - "id": "df2518ae", - "metadata": {}, - "outputs": [], - "source": [ - "sql_conversation.add_message(\"assistant\", assistant_message)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "13984dcb", - "metadata": {}, + "id": "710481dc", + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[31msystem: You are ChinookGPT, a helpful assistant who gets answers to user questions from the Chinook Music Database.\n", - "Provide as many details as possible to your users\n", - "Begin!\n", - "\n", - "\u001b[0m\n", - "\u001b[32muser: Hi, who are the top 5 artists by number of tracks\n", - "\n", - "\u001b[0m\n", - "\u001b[35mfunction: [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n", - "\n", - "\u001b[0m\n", - "\u001b[34massistant: The top 5 artists in the Chinook Music Database based on the number of tracks they have are:\n", - "\n", - "1. Iron Maiden - 213 tracks\n", - "2. U2 - 135 tracks\n", - "3. Led Zeppelin - 114 tracks\n", - "4. Metallica - 112 tracks\n", - "5. Lost - 92 tracks\n", - "\n", - "\u001b[0m\n", - "\u001b[32muser: What is the name of the album with the most tracks\n", - "\n", - "\u001b[0m\n", - "\u001b[31msystem: Query: SELECT AlbumId, Title, COUNT(TrackId) AS TrackCount FROM Album GROUP BY AlbumId ORDER BY TrackCount DESC LIMIT 1;\n", - "The previous query received the error SQL error: no such column: TrackId. \n", - "Please return a fixed SQL query in plain text.\n", - "Your response should consist of ONLY the SQL query with the separator sql_start at the beginning and sql_end at the end\n", - "\n", - "\u001b[0m\n", - "\u001b[35mfunction: [('Greatest Hits', 57)]\n", - "\n", - "\u001b[0m\n", - "\u001b[34massistant: The album with the most tracks in the Chinook Music Database is \"Greatest Hits\" with a total of 57 tracks.\n", - "\n", - "\u001b[0m\n" + "\u001B[31msystem: Answer user questions by generating SQL queries against the Chinook Music Database.\n", + "\u001B[0m\n", + "\u001B[32muser: Hi, who are the top 5 artists by number of tracks?\n", + "\u001B[0m\n", + "\u001B[34massistant: {'name': 'ask_database', 'arguments': '{\\n \"query\": \"SELECT Artist.Name, COUNT(Track.TrackId) AS TrackCount FROM Artist INNER JOIN Album ON Artist.ArtistId = Album.ArtistId INNER JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.Name ORDER BY TrackCount DESC LIMIT 5;\"\\n}'}\n", + "\u001B[0m\n", + "\u001B[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n", + "\u001B[0m\n", + "\u001B[32muser: What is the name of the album with the most tracks?\n", + "\u001B[0m\n", + "\u001B[34massistant: {'name': 'ask_database', 'arguments': '{\\n \"query\": \"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album INNER JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1;\"\\n}'}\n", + "\u001B[0m\n", + "\u001B[35mfunction (ask_database): [('Greatest Hits', 57)]\n", + "\u001B[0m\n" ] } ], "source": [ - "sql_conversation.display_conversation(detailed=True)\n" + "messages.append({\"role\": \"user\", \"content\": \"What is the name of the album with the most tracks?\"})\n", + "chat_response = chat_completion_request(messages, functions)\n", + "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n", + "messages.append(assistant_message)\n", + "if assistant_message.get(\"function_call\"):\n", + " results = execute_function_call(assistant_message)\n", + " messages.append({\"role\": \"function\", \"content\": results, \"name\": assistant_message[\"function_call\"][\"name\"]})\n", + "pretty_print_conversation(messages)" ] }, { "cell_type": "markdown", - "id": "c282b9e3", + "id": "2d89073c", "metadata": {}, "source": [ - "## Using Multiple Functions\n", + "## Next Steps\n", "\n", - "Now let's construct a scenario in which we provide a model with more than one function to call. We'll create an agent that uses data from arXiv to answer questions about academic subjects. It has two new functions at its disposal:\n", - "- **get_articles**: A function that gets arXiv articles on a subject and summarizes them for the user with links.\n", - "- **read_article_and_summarize**: This function takes one of the previously searched articles, reads it in its entirety and summarizes the core argument, evidence and conclusions.\n", - "\n", - "This will get you comfortable with a multi-function workflow that can choose from multiple services, and where some of the data from the first function is persisted to be used by the second." - ] - }, - { - "cell_type": "markdown", - "id": "f2e47962", - "metadata": {}, - "source": [ - "### arXiv search\n", - "\n", - "We'll first set up some utilities that will underpin our two functions.\n", - "\n", - "Downloaded papers will be stored in a directory (we use ```./data/papers``` here). We create a file ```arxiv_library.csv``` to store the embeddings and details for downloaded papers to retrieve against using ```summarize_text```." + "See our other [notebook](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_call_functions_for_knowledge_retrieval.ipynb) that demonstrates how to use the Chat Completions API and functions for knowledge retrieval to interact conversationally with a knowledge base." ] }, { "cell_type": "code", - "execution_count": 21, - "id": "2de5d32d", + "execution_count": null, + "id": "c83138e8", "metadata": {}, "outputs": [], - "source": [ - "# Set a directory to store downloaded papers\n", - "data_dir = os.path.join(os.curdir, \"data\", \"papers\")\n", - "paper_dir_filepath = \"./data/arxiv_library.csv\"\n", - "\n", - "# Generate a blank dataframe where we can store downloaded files\n", - "df = pd.DataFrame(list())\n", - "df.to_csv(paper_dir_filepath)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "57217b9d", - "metadata": {}, - "outputs": [], - "source": [ - "@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n", - "def embedding_request(text):\n", - " response = openai.Embedding.create(input=text, model=EMBEDDING_MODEL)\n", - " return response\n", - "\n", - "\n", - "def get_articles(query, library=paper_dir_filepath, top_k=5):\n", - " \"\"\"This function gets the top_k articles based on a user's query, sorted by relevance.\n", - " It also downloads the files and stores them in arxiv_library.csv to be retrieved by the read_article_and_summarize.\n", - " \"\"\"\n", - " search = arxiv.Search(\n", - " query=query, max_results=top_k, sort_by=arxiv.SortCriterion.Relevance\n", - " )\n", - " result_list = []\n", - " for result in search.results():\n", - " result_dict = {}\n", - " result_dict.update({\"title\": result.title})\n", - " result_dict.update({\"summary\": result.summary})\n", - "\n", - " # Taking the first url provided\n", - " result_dict.update({\"article_url\": [x.href for x in result.links][0]})\n", - " result_dict.update({\"pdf_url\": [x.href for x in result.links][1]})\n", - " result_list.append(result_dict)\n", - "\n", - " # Store references in library file\n", - " response = embedding_request(text=result.title)\n", - " file_reference = [\n", - " result.title,\n", - " result.download_pdf(data_dir),\n", - " response[\"data\"][0][\"embedding\"],\n", - " ]\n", - "\n", - " # Write to file\n", - " with open(library, \"a\") as f_object:\n", - " writer_object = writer(f_object)\n", - " writer_object.writerow(file_reference)\n", - " f_object.close()\n", - " return result_list\n" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "dda02bdb", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'title': 'Proximal Policy Optimization and its Dynamic Version for Sequence Generation',\n", - " 'summary': 'In sequence generation task, many works use policy gradient for model\\noptimization to tackle the intractable backpropagation issue when maximizing\\nthe non-differentiable evaluation metrics or fooling the discriminator in\\nadversarial learning. In this paper, we replace policy gradient with proximal\\npolicy optimization (PPO), which is a proved more efficient reinforcement\\nlearning algorithm, and propose a dynamic approach for PPO (PPO-dynamic). We\\ndemonstrate the efficacy of PPO and PPO-dynamic on conditional sequence\\ngeneration tasks including synthetic experiment and chit-chat chatbot. The\\nresults show that PPO and PPO-dynamic can beat policy gradient by stability and\\nperformance.',\n", - " 'article_url': 'http://arxiv.org/abs/1808.07982v1',\n", - " 'pdf_url': 'http://arxiv.org/pdf/1808.07982v1'}" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Test that the search is working\n", - "result_output = get_articles(\"ppo reinforcement learning\")\n", - "result_output[0]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "11675627", - "metadata": {}, - "outputs": [], - "source": [ - "def strings_ranked_by_relatedness(\n", - " query: str,\n", - " df: pd.DataFrame,\n", - " relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),\n", - " top_n: int = 100,\n", - ") -> list[str]:\n", - " \"\"\"Returns a list of strings and relatednesses, sorted from most related to least.\"\"\"\n", - " query_embedding_response = embedding_request(query)\n", - " query_embedding = query_embedding_response[\"data\"][0][\"embedding\"]\n", - " strings_and_relatednesses = [\n", - " (row[\"filepath\"], relatedness_fn(query_embedding, row[\"embedding\"]))\n", - " for i, row in df.iterrows()\n", - " ]\n", - " strings_and_relatednesses.sort(key=lambda x: x[1], reverse=True)\n", - " strings, relatednesses = zip(*strings_and_relatednesses)\n", - " return strings[:top_n]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "7211df2c", - "metadata": {}, - "outputs": [], - "source": [ - "def read_pdf(filepath):\n", - " \"\"\"Takes a filepath to a PDF and returns a string of the PDF's contents\"\"\"\n", - " # creating a pdf reader object\n", - " reader = PdfReader(filepath)\n", - " pdf_text = \"\"\n", - " page_number = 0\n", - " for page in reader.pages:\n", - " page_number += 1\n", - " pdf_text += page.extract_text() + f\"\\nPage Number: {page_number}\"\n", - " return pdf_text\n", - "\n", - "\n", - "# Split a text into smaller chunks of size n, preferably ending at the end of a sentence\n", - "def create_chunks(text, n, tokenizer):\n", - " \"\"\"Returns successive n-sized chunks from provided text.\"\"\"\n", - " tokens = tokenizer.encode(text)\n", - " i = 0\n", - " while i < len(tokens):\n", - " # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens\n", - " j = min(i + int(1.5 * n), len(tokens))\n", - " while j > i + int(0.5 * n):\n", - " # Decode the tokens and check for full stop or newline\n", - " chunk = tokenizer.decode(tokens[i:j])\n", - " if chunk.endswith(\".\") or chunk.endswith(\"\\n\"):\n", - " break\n", - " j -= 1\n", - " # If no end of sentence found, use n tokens as the chunk size\n", - " if j == i + int(0.5 * n):\n", - " j = min(i + n, len(tokens))\n", - " yield tokens[i:j]\n", - " i = j\n", - "\n", - "\n", - "def extract_chunk(content, template_prompt):\n", - " \"\"\"This function applies a prompt to some input content. In this case it returns a summarize chunk of text\"\"\"\n", - " prompt = template_prompt + content\n", - " response = openai.ChatCompletion.create(\n", - " model=GPT_MODEL, messages=[{\"role\": \"user\", \"content\": prompt}], temperature=0\n", - " )\n", - " return response[\"choices\"][0][\"message\"][\"content\"]\n", - "\n", - "\n", - "def summarize_text(query):\n", - " \"\"\"This function does the following:\n", - " - Reads in the arxiv_library.csv file in including the embeddings\n", - " - Finds the closest file to the user's query\n", - " - Scrapes the text out of the file and chunks it\n", - " - Summarizes each chunk in parallel\n", - " - Does one final summary and returns this to the user\"\"\"\n", - "\n", - " # A prompt to dictate how the recursive summarizations should approach the input paper\n", - " summary_prompt = \"\"\"Summarize this text from an academic paper. Extract any key points with reasoning.\\n\\nContent:\"\"\"\n", - "\n", - " # If the library is empty (no searches have been performed yet), we perform one and download the results\n", - " library_df = pd.read_csv(paper_dir_filepath).reset_index()\n", - " if len(library_df) == 0:\n", - " print(\"No papers searched yet, downloading first.\")\n", - " get_articles(query)\n", - " print(\"Papers downloaded, continuing\")\n", - " library_df = pd.read_csv(paper_dir_filepath).reset_index()\n", - " library_df.columns = [\"title\", \"filepath\", \"embedding\"]\n", - " library_df[\"embedding\"] = library_df[\"embedding\"].apply(ast.literal_eval)\n", - " strings = strings_ranked_by_relatedness(query, library_df, top_n=1)\n", - " print(\"Chunking text from paper\")\n", - " pdf_text = read_pdf(strings[0])\n", - "\n", - " # Initialise tokenizer\n", - " tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n", - " results = \"\"\n", - "\n", - " # Chunk up the document into 1500 token chunks\n", - " chunks = create_chunks(pdf_text, 1500, tokenizer)\n", - " text_chunks = [tokenizer.decode(chunk) for chunk in chunks]\n", - " print(\"Summarizing each chunk of text\")\n", - "\n", - " # Parallel process the summaries\n", - " with concurrent.futures.ThreadPoolExecutor(\n", - " max_workers=len(text_chunks)\n", - " ) as executor:\n", - " futures = [\n", - " executor.submit(extract_chunk, chunk, summary_prompt)\n", - " for chunk in text_chunks\n", - " ]\n", - " with tqdm(total=len(text_chunks)) as pbar:\n", - " for _ in concurrent.futures.as_completed(futures):\n", - " pbar.update(1)\n", - " for future in futures:\n", - " data = future.result()\n", - " results += data\n", - "\n", - " # Final summary\n", - " print(\"Summarizing into overall summary\")\n", - " response = openai.ChatCompletion.create(\n", - " model=GPT_MODEL,\n", - " messages=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": f\"\"\"Write a summary collated from this collection of key points extracted from an academic paper.\n", - " The summary should highlight the core argument, conclusions and evidence, and answer the user's query.\n", - " User query: {query}\n", - " The summary should be structured in bulleted lists following the headings Core Argument, Evidence, and Conclusions.\n", - " Key points:\\n{results}\\nSummary:\\n\"\"\",\n", - " }\n", - " ],\n", - " temperature=0,\n", - " )\n", - " return response\n" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "898b94d4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Chunking text from paper\n", - "Summarizing each chunk of text\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00, 1.33it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Summarizing into overall summary\n" - ] - } - ], - "source": [ - "# Test the summarize_text function works\n", - "chat_test_response = summarize_text(\"PPO reinforcement learning sequence generation\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "c715f60d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Core Argument:\n", - "- The paper discusses the use of Proximal Policy Optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.\n", - "- The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, which is commonly used in these tasks.\n", - "- They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n", - "\n", - "Evidence:\n", - "- PPO-dynamic achieves high precision scores in a synthetic counting task, comparable to other algorithms like REINFORCE and MIXER.\n", - "- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than REINFORCE and PPO.\n", - "- The learning curve of PPO-dynamic is more stable and faster than policy gradient.\n", - "\n", - "Conclusions:\n", - "- PPO is a better optimization method for sequence learning compared to policy gradient.\n", - "- PPO-dynamic further improves the optimization process by dynamically adjusting the hyperparameters.\n", - "- PPO can be used as a new optimization method for GAN-based sequence learning for better performance.\n" - ] - } - ], - "source": [ - "print(chat_test_response[\"choices\"][0][\"message\"][\"content\"])\n" - ] - }, - { - "cell_type": "markdown", - "id": "93a9f651", - "metadata": {}, - "source": [ - "### Configure Agent\n", - "\n", - "We'll now create 2 function specifications for functions that provide access to the arXiv data. We'll also create some more utilities to integrate Chat Completions API calls with function execution." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "a391cabe", - "metadata": {}, - "outputs": [], - "source": [ - "# Initiate our get_articles and read_article_and_summarize functions\n", - "arxiv_functions = [\n", - " {\n", - " \"name\": \"get_articles\",\n", - " \"description\": \"\"\"Use this function to get academic papers from arXiv to answer user questions.\"\"\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"query\": {\n", - " \"type\": \"string\",\n", - " \"description\": f\"\"\"\n", - " User query in JSON. Responses should be summarized and should include the article URL reference\n", - " \"\"\",\n", - " }\n", - " },\n", - " \"required\": [\"query\"],\n", - " },\n", - " \"name\": \"read_article_and_summarize\",\n", - " \"description\": \"\"\"Use this function to read whole papers and provide a summary for users.\n", - " You should NEVER call this function before get_articles has been called in the conversation.\"\"\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"query\": {\n", - " \"type\": \"string\",\n", - " \"description\": f\"\"\"\n", - " Description of the article in plain text based on the user's query\n", - " \"\"\",\n", - " }\n", - " },\n", - " \"required\": [\"query\"],\n", - " },\n", - " }\n", - "]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "6d5ccb27", - "metadata": {}, - "outputs": [], - "source": [ - "def chat_completion_with_function_execution(messages, functions=[None]):\n", - " \"\"\"This function makes a ChatCompletion API call with the option of adding functions\"\"\"\n", - " response = chat_completion_request(messages, functions)\n", - " full_message = response.json()[\"choices\"][0]\n", - " if full_message[\"finish_reason\"] == \"function_call\":\n", - " print(f\"Function generation requested, calling function\")\n", - " return call_arxiv_function(messages, full_message)\n", - " else:\n", - " print(f\"Function not required, responding to user\")\n", - " return response.json()\n", - "\n", - "\n", - "def call_arxiv_function(messages, full_message):\n", - " \"\"\"Function calling function which executes function calls when the model believes it is necessary.\n", - " Currently extended by adding clauses to this if statement.\"\"\"\n", - "\n", - " if full_message[\"message\"][\"function_call\"][\"name\"] == \"get_articles\":\n", - " try:\n", - " parsed_output = json.loads(\n", - " full_message[\"message\"][\"function_call\"][\"arguments\"]\n", - " )\n", - " print(\"Getting search results\")\n", - " results = get_articles(parsed_output[\"query\"])\n", - " except Exception as e:\n", - " print(parsed_output)\n", - " print(f\"Function execution failed\")\n", - " print(f\"Error message: {e}\")\n", - " messages.append(\n", - " {\n", - " \"role\": \"function\",\n", - " \"name\": full_message[\"message\"][\"function_call\"][\"name\"],\n", - " \"content\": str(results),\n", - " }\n", - " )\n", - " try:\n", - " print(\"Got search results, summarizing content\")\n", - " response = chat_completion_request(messages)\n", - " return response.json()\n", - " except Exception as e:\n", - " print(type(e))\n", - " raise Exception(\"Function chat request failed\")\n", - "\n", - " elif (\n", - " full_message[\"message\"][\"function_call\"][\"name\"] == \"read_article_and_summarize\"\n", - " ):\n", - " parsed_output = json.loads(\n", - " full_message[\"message\"][\"function_call\"][\"arguments\"]\n", - " )\n", - " print(\"Finding and reading paper\")\n", - " summary = summarize_text(parsed_output[\"query\"])\n", - " return summary\n", - "\n", - " else:\n", - " raise Exception(\"Function does not exist and cannot be called\")\n" - ] - }, - { - "cell_type": "markdown", - "id": "dd3e7868", - "metadata": {}, - "source": [ - "### arXiv conversation\n", - "\n", - "Let's test out our function in conversation" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "c39a1d80", - "metadata": {}, - "outputs": [], - "source": [ - "# Start with a system message\n", - "paper_system_message = \"\"\"You are arXivGPT, a helpful assistant pulls academic papers to answer user questions.\n", - "You summarize the papers clearly so the customer can decide which to read to answer their question.\n", - "You always provide the article_url and title so the user can understand the name of the paper and click through to access it.\n", - "Begin!\"\"\"\n", - "paper_conversation = Conversation()\n", - "paper_conversation.add_message(\"system\", paper_system_message)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "253fd0f7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Function generation requested, calling function\n", - "Finding and reading paper\n", - "Chunking text from paper\n", - "Summarizing each chunk of text\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████████████████████████████████████| 17/17 [00:04<00:00, 3.68it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Summarizing into overall summary\n" - ] - }, - { - "data": { - "text/markdown": [ - "Core Argument:\n", - "- The paper focuses on the theoretical analysis of the PPO-Clip algorithm in the context of deep reinforcement learning.\n", - "- The paper aims to establish the first global convergence rate guarantee for PPO-Clip under neural function approximation.\n", - "\n", - "Evidence:\n", - "- The authors identify challenges in analyzing PPO-Clip, including the lack of a closed-form expression for policy updates and the coupling between clipping behavior and neural function approximation.\n", - "- The authors propose two core ideas: reinterpreting PPO-Clip from the perspective of hinge loss and introducing a two-step policy improvement scheme.\n", - "- The paper provides theoretical proofs, lemmas, and analysis to support the convergence properties of PPO-Clip and Neural PPO-Clip.\n", - "- Experimental evaluations on reinforcement learning benchmark tasks validate the effectiveness of PPO-Clip.\n", - "\n", - "Conclusions:\n", - "- The paper establishes the global convergence of PPO-Clip and characterizes its convergence rate as O(1/sqrt(T)).\n", - "- The reinterpretation of PPO-Clip through hinge loss offers a framework for generalization.\n", - "- The paper provides insights into the interplay between convergence behavior and the clipping mechanism in PPO-Clip." - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Add a user message\n", - "paper_conversation.add_message(\"user\", \"Hi, how does PPO reinforcement learning work?\")\n", - "chat_response = chat_completion_with_function_execution(\n", - " paper_conversation.conversation_history, functions=arxiv_functions\n", - ")\n", - "assistant_message = chat_response[\"choices\"][0][\"message\"][\"content\"]\n", - "paper_conversation.add_message(\"assistant\", assistant_message)\n", - "display(Markdown(assistant_message))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "3ca3e18a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Function generation requested, calling function\n", - "Finding and reading paper\n", - "Chunking text from paper\n", - "Summarizing each chunk of text\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00, 1.33it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Summarizing into overall summary\n" - ] - }, - { - "data": { - "text/markdown": [ - "Core Argument:\n", - "The paper discusses the use of Proximal Policy Optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots. The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, which is commonly used in these tasks. They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n", - "\n", - "Evidence:\n", - "- PPO-dynamic achieves high precision scores in a synthetic counting task, comparable to other algorithms like REINFORCE and MIXER.\n", - "- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than REINFORCE and PPO.\n", - "- The learning curve of PPO-dynamic is more stable and faster than policy gradient.\n", - "\n", - "Conclusions:\n", - "- PPO is a better optimization method for sequence learning compared to policy gradient.\n", - "- PPO-dynamic further improves the optimization process by dynamically adjusting the hyperparameters.\n", - "- PPO can be used as a new optimization method for GAN-based sequence learning for better performance." - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Add another user message to induce our system to use the second tool\n", - "paper_conversation.add_message(\n", - " \"user\",\n", - " \"Can you read the PPO sequence generation paper for me and give me a summary\",\n", - ")\n", - "updated_response = chat_completion_with_function_execution(\n", - " paper_conversation.conversation_history, functions=arxiv_functions\n", - ")\n", - "display(Markdown(updated_response[\"choices\"][0][\"message\"][\"content\"]))\n" - ] + "source": [] } ], "metadata": {