mirror of
https://github.com/openai/openai-cookbook
synced 2024-11-15 18:13:18 +00:00
870 lines
43 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "3e67f200",
|
|
"metadata": {},
|
|
"source": [
|
|
"# How to use functions with a knowledge base\n",
|
|
"\n",
"This notebook builds on the concepts in the [argument generation](How_to_call_functions_with_chat_models.ipynb) notebook by creating an agent with access to a knowledge base and two functions that it can call depending on the user's request.\n",
|
|
"\n",
|
|
"We'll create an agent that uses data from arXiv to answer questions about academic subjects. It has two functions at its disposal:\n",
|
|
"- **get_articles**: A function that gets arXiv articles on a subject and summarizes them for the user with links.\n",
"- **read_article_and_summarize**: A function that takes one of the previously retrieved articles, reads it in its entirety, and summarizes the core argument, evidence, and conclusions.\n",
|
|
"\n",
|
|
"This will get you comfortable with a multi-function workflow that can choose from multiple services, and where some of the data from the first function is persisted to be used by the second.\n",
|
|
"\n",
|
|
"## Walkthrough\n",
|
|
"\n",
|
|
"This cookbook takes you through the following workflow:\n",
|
|
"\n",
|
|
"- **Search utilities:** Creating the two functions that access arXiv for answers.\n",
|
|
"- **Configure Agent:** Building up the Agent behaviour that will assess the need for a function and, if one is required, call that function and present results back to the agent.\n",
"- **arXiv conversation:** Putting all of this together in a live conversation.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "80e71f33",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"is_executing": true
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mRequirement already satisfied: scipy in /opt/homebrew/lib/python3.11/site-packages (1.10.1)\n",
|
|
"Requirement already satisfied: numpy<1.27.0,>=1.19.5 in /opt/homebrew/lib/python3.11/site-packages (from scipy) (1.24.3)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mRequirement already satisfied: tenacity in /opt/homebrew/lib/python3.11/site-packages (8.2.2)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mRequirement already satisfied: tiktoken in /opt/homebrew/lib/python3.11/site-packages (0.4.0)\n",
|
|
"Requirement already satisfied: regex>=2022.1.18 in /opt/homebrew/lib/python3.11/site-packages (from tiktoken) (2023.6.3)\n",
|
|
"Requirement already satisfied: requests>=2.26.0 in /opt/homebrew/lib/python3.11/site-packages (from tiktoken) (2.30.0)\n",
|
|
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.1.0)\n",
|
|
"Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.4)\n",
|
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (1.25.11)\n",
|
|
"Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2023.5.7)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mRequirement already satisfied: termcolor in /opt/homebrew/lib/python3.11/site-packages (2.3.0)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mRequirement already satisfied: openai in /opt/homebrew/lib/python3.11/site-packages (0.27.6)\n",
|
|
"Requirement already satisfied: requests>=2.20 in /opt/homebrew/lib/python3.11/site-packages (from openai) (2.30.0)\n",
|
|
"Requirement already satisfied: tqdm in /opt/homebrew/lib/python3.11/site-packages (from openai) (4.65.0)\n",
|
|
"Requirement already satisfied: aiohttp in /opt/homebrew/lib/python3.11/site-packages (from openai) (3.8.4)\n",
|
|
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (3.1.0)\n",
|
|
"Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (3.4)\n",
|
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (1.25.11)\n",
|
|
"Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (2023.5.7)\n",
|
|
"Requirement already satisfied: attrs>=17.3.0 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (23.1.0)\n",
|
|
"Requirement already satisfied: multidict<7.0,>=4.5 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (6.0.4)\n",
|
|
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (4.0.2)\n",
|
|
"Requirement already satisfied: yarl<2.0,>=1.0 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (1.9.2)\n",
|
|
"Requirement already satisfied: frozenlist>=1.1.1 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (1.3.3)\n",
|
|
"Requirement already satisfied: aiosignal>=1.1.2 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (1.3.1)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mRequirement already satisfied: requests in /opt/homebrew/lib/python3.11/site-packages (2.30.0)\n",
|
|
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests) (3.1.0)\n",
|
|
"Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests) (3.4)\n",
|
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests) (1.25.11)\n",
|
|
"Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests) (2023.5.7)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mRequirement already satisfied: arxiv in /opt/homebrew/lib/python3.11/site-packages (1.4.7)\n",
|
|
"Requirement already satisfied: feedparser in /opt/homebrew/lib/python3.11/site-packages (from arxiv) (6.0.10)\n",
|
|
"Requirement already satisfied: sgmllib3k in /opt/homebrew/lib/python3.11/site-packages (from feedparser->arxiv) (1.0.0)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[0mRequirement already satisfied: pandas in /opt/homebrew/lib/python3.11/site-packages (2.0.1)\n",
|
|
"Requirement already satisfied: python-dateutil>=2.8.2 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2.8.2)\n",
|
|
"Requirement already satisfied: pytz>=2020.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2023.3)\n",
|
|
"Requirement already satisfied: tzdata>=2022.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2023.3)\n",
|
|
"Requirement already satisfied: numpy>=1.21.0 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (1.24.3)\n",
|
|
"Requirement already satisfied: six>=1.5 in /opt/homebrew/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mRequirement already satisfied: PyPDF2 in /opt/homebrew/lib/python3.11/site-packages (3.0.1)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mRequirement already satisfied: tqdm in /opt/homebrew/lib/python3.11/site-packages (4.65.0)\n",
|
|
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
|
|
"\u001b[0m"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!pip install scipy\n",
|
|
"!pip install tenacity\n",
|
|
"!pip install tiktoken==0.3.3\n",
"!pip install termcolor\n",
|
|
"!pip install openai\n",
|
|
"!pip install requests\n",
|
|
"!pip install arxiv\n",
|
|
"!pip install pandas\n",
|
|
"!pip install PyPDF2\n",
|
|
"!pip install tqdm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "dab872c5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"\n",
|
|
"import arxiv\n",
|
|
"import ast\n",
"import concurrent.futures\n",
|
|
"from csv import writer\n",
|
|
"from IPython.display import display, Markdown, Latex\n",
|
|
"import json\n",
|
|
"import openai\n",
|
|
"import pandas as pd\n",
|
|
"from PyPDF2 import PdfReader\n",
|
|
"import requests\n",
|
|
"from scipy import spatial\n",
|
|
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
|
|
"import tiktoken\n",
|
|
"from tqdm import tqdm\n",
|
|
"from termcolor import colored\n",
|
|
"\n",
|
|
"GPT_MODEL = \"gpt-3.5-turbo-0613\"\n",
|
|
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "f2e47962",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Search utilities\n",
|
|
"\n",
|
|
"We'll first set up some utilities that will underpin our two functions.\n",
|
|
"\n",
"Downloaded papers will be stored in a directory (we use ```./data/papers``` here). We create a file ```arxiv_library.csv``` to store the details and embeddings of downloaded papers, so that ```summarize_text``` can retrieve the most relevant one."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "2de5d32d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"directory = './data/papers'\n",
"\n",
"# Check if the directory already exists\n",
"if not os.path.exists(directory):\n",
"    # If the directory doesn't exist, create it and any necessary intermediate directories\n",
"    os.makedirs(directory)\n",
"    print(f\"Directory '{directory}' created successfully.\")\n",
"else:\n",
"    # If the directory already exists, print a message indicating it\n",
"    print(f\"Directory '{directory}' already exists.\")"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ae5cb7a1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Set a directory to store downloaded papers\n",
|
|
"data_dir = os.path.join(os.curdir, \"data\", \"papers\")\n",
|
|
"paper_dir_filepath = \"./data/arxiv_library.csv\"\n",
|
|
"\n",
|
|
"# Generate a blank dataframe where we can store downloaded files\n",
|
|
"df = pd.DataFrame(list())\n",
|
|
"df.to_csv(paper_dir_filepath)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "57217b9d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def embedding_request(text):\n",
"    response = openai.Embedding.create(input=text, model=EMBEDDING_MODEL)\n",
"    return response\n",
"\n",
"\n",
"def get_articles(query, library=paper_dir_filepath, top_k=5):\n",
"    \"\"\"This function gets the top_k articles based on a user's query, sorted by relevance.\n",
"    It also downloads the files and stores their references in arxiv_library.csv to be retrieved by read_article_and_summarize.\n",
"    \"\"\"\n",
"    search = arxiv.Search(\n",
"        query=query, max_results=top_k, sort_by=arxiv.SortCriterion.Relevance\n",
"    )\n",
"    result_list = []\n",
"    for result in search.results():\n",
"        result_dict = {}\n",
"        result_dict.update({\"title\": result.title})\n",
"        result_dict.update({\"summary\": result.summary})\n",
"\n",
"        # Taking the first url provided\n",
"        result_dict.update({\"article_url\": [x.href for x in result.links][0]})\n",
"        result_dict.update({\"pdf_url\": [x.href for x in result.links][1]})\n",
"        result_list.append(result_dict)\n",
"\n",
"        # Store references in library file\n",
"        response = embedding_request(text=result.title)\n",
"        file_reference = [\n",
"            result.title,\n",
"            result.download_pdf(data_dir),\n",
"            response[\"data\"][0][\"embedding\"],\n",
"        ]\n",
"\n",
"        # Write to file (the with block closes the file on exit)\n",
"        with open(library, \"a\") as f_object:\n",
"            writer_object = writer(f_object)\n",
"            writer_object.writerow(file_reference)\n",
"    return result_list\n"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "dda02bdb",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'title': 'Proximal Policy Optimization and its Dynamic Version for Sequence Generation',\n",
|
|
" 'summary': 'In sequence generation task, many works use policy gradient for model\\noptimization to tackle the intractable backpropagation issue when maximizing\\nthe non-differentiable evaluation metrics or fooling the discriminator in\\nadversarial learning. In this paper, we replace policy gradient with proximal\\npolicy optimization (PPO), which is a proved more efficient reinforcement\\nlearning algorithm, and propose a dynamic approach for PPO (PPO-dynamic). We\\ndemonstrate the efficacy of PPO and PPO-dynamic on conditional sequence\\ngeneration tasks including synthetic experiment and chit-chat chatbot. The\\nresults show that PPO and PPO-dynamic can beat policy gradient by stability and\\nperformance.',\n",
|
|
" 'article_url': 'http://arxiv.org/abs/1808.07982v1',\n",
|
|
" 'pdf_url': 'http://arxiv.org/pdf/1808.07982v1'}"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Test that the search is working\n",
|
|
"result_output = get_articles(\"ppo reinforcement learning\")\n",
|
|
"result_output[0]\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "11675627",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"def strings_ranked_by_relatedness(\n",
"    query: str,\n",
"    df: pd.DataFrame,\n",
"    relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),\n",
"    top_n: int = 100,\n",
") -> list[str]:\n",
"    \"\"\"Returns the filepaths of the top_n papers, sorted from most related to least.\"\"\"\n",
"    query_embedding_response = embedding_request(query)\n",
"    query_embedding = query_embedding_response[\"data\"][0][\"embedding\"]\n",
"    strings_and_relatednesses = [\n",
"        (row[\"filepath\"], relatedness_fn(query_embedding, row[\"embedding\"]))\n",
"        for i, row in df.iterrows()\n",
"    ]\n",
"    strings_and_relatednesses.sort(key=lambda x: x[1], reverse=True)\n",
"    strings, relatednesses = zip(*strings_and_relatednesses)\n",
"    return list(strings[:top_n])\n"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "7211df2c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"def read_pdf(filepath):\n",
"    \"\"\"Takes a filepath to a PDF and returns a string of the PDF's contents\"\"\"\n",
"    # creating a pdf reader object\n",
"    reader = PdfReader(filepath)\n",
"    pdf_text = \"\"\n",
"    page_number = 0\n",
"    for page in reader.pages:\n",
"        page_number += 1\n",
"        pdf_text += page.extract_text() + f\"\\nPage Number: {page_number}\"\n",
"    return pdf_text\n",
"\n",
"\n",
"# Split a text into smaller chunks of size n, preferably ending at the end of a sentence\n",
"def create_chunks(text, n, tokenizer):\n",
"    \"\"\"Yields successive chunks of tokens from the provided text, each roughly n tokens long.\"\"\"\n",
"    tokens = tokenizer.encode(text)\n",
"    i = 0\n",
"    while i < len(tokens):\n",
"        # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens\n",
"        j = min(i + int(1.5 * n), len(tokens))\n",
"        while j > i + int(0.5 * n):\n",
"            # Decode the tokens and check for full stop or newline\n",
"            chunk = tokenizer.decode(tokens[i:j])\n",
"            if chunk.endswith(\".\") or chunk.endswith(\"\\n\"):\n",
"                break\n",
"            j -= 1\n",
"        # If no end of sentence found, use n tokens as the chunk size\n",
"        if j == i + int(0.5 * n):\n",
"            j = min(i + n, len(tokens))\n",
"        yield tokens[i:j]\n",
"        i = j\n",
"\n",
"\n",
"def extract_chunk(content, template_prompt):\n",
"    \"\"\"This function applies a prompt to some input content. In this case it returns a summarized chunk of text\"\"\"\n",
"    prompt = template_prompt + content\n",
"    response = openai.ChatCompletion.create(\n",
"        model=GPT_MODEL, messages=[{\"role\": \"user\", \"content\": prompt}], temperature=0\n",
"    )\n",
"    return response[\"choices\"][0][\"message\"][\"content\"]\n",
"\n",
"\n",
"def summarize_text(query):\n",
"    \"\"\"This function does the following:\n",
"    - Reads in the arxiv_library.csv file, including the embeddings\n",
"    - Finds the closest file to the user's query\n",
"    - Scrapes the text out of the file and chunks it\n",
"    - Summarizes each chunk in parallel\n",
"    - Does one final summary and returns this to the user\"\"\"\n",
"\n",
"    # A prompt to dictate how the recursive summarizations should approach the input paper\n",
"    summary_prompt = \"\"\"Summarize this text from an academic paper. Extract any key points with reasoning.\\n\\nContent:\"\"\"\n",
"\n",
"    # If the library is empty (no searches have been performed yet), we perform one and download the results\n",
"    library_df = pd.read_csv(paper_dir_filepath).reset_index()\n",
"    if len(library_df) == 0:\n",
"        print(\"No papers searched yet, downloading first.\")\n",
"        get_articles(query)\n",
"        print(\"Papers downloaded, continuing\")\n",
"        library_df = pd.read_csv(paper_dir_filepath).reset_index()\n",
"    library_df.columns = [\"title\", \"filepath\", \"embedding\"]\n",
"    library_df[\"embedding\"] = library_df[\"embedding\"].apply(ast.literal_eval)\n",
"    strings = strings_ranked_by_relatedness(query, library_df, top_n=1)\n",
"    print(\"Chunking text from paper\")\n",
"    pdf_text = read_pdf(strings[0])\n",
"\n",
"    # Initialise tokenizer\n",
"    tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n",
"    results = \"\"\n",
"\n",
"    # Chunk up the document into 1500 token chunks\n",
"    chunks = create_chunks(pdf_text, 1500, tokenizer)\n",
"    text_chunks = [tokenizer.decode(chunk) for chunk in chunks]\n",
"    print(\"Summarizing each chunk of text\")\n",
"\n",
"    # Parallel process the summaries\n",
"    with concurrent.futures.ThreadPoolExecutor(\n",
"        max_workers=len(text_chunks)\n",
"    ) as executor:\n",
"        futures = [\n",
"            executor.submit(extract_chunk, chunk, summary_prompt)\n",
"            for chunk in text_chunks\n",
"        ]\n",
"        with tqdm(total=len(text_chunks)) as pbar:\n",
"            for _ in concurrent.futures.as_completed(futures):\n",
"                pbar.update(1)\n",
"        for future in futures:\n",
"            data = future.result()\n",
"            results += data\n",
"\n",
"    # Final summary\n",
"    print(\"Summarizing into overall summary\")\n",
"    response = openai.ChatCompletion.create(\n",
"        model=GPT_MODEL,\n",
"        messages=[\n",
"            {\n",
"                \"role\": \"user\",\n",
"                \"content\": f\"\"\"Write a summary collated from this collection of key points extracted from an academic paper.\n",
"                        The summary should highlight the core argument, conclusions and evidence, and answer the user's query.\n",
"                        User query: {query}\n",
"                        The summary should be structured in bulleted lists following the headings Core Argument, Evidence, and Conclusions.\n",
"                        Key points:\\n{results}\\nSummary:\\n\"\"\",\n",
"            }\n",
"        ],\n",
"        temperature=0,\n",
"    )\n",
"    return response\n"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "898b94d4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Chunking text from paper\n",
|
|
"Summarizing each chunk of text\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.19s/it]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Summarizing into overall summary\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Test the summarize_text function works\n",
|
|
"chat_test_response = summarize_text(\"PPO reinforcement learning sequence generation\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "c715f60d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Core Argument:\n",
|
|
"- The paper discusses the use of Proximal Policy Optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.\n",
|
|
"- The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, commonly used in text generation tasks.\n",
|
|
"- They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n",
|
|
"\n",
|
|
"Evidence:\n",
|
|
"- PPO-dynamic achieves high precision scores comparable to other algorithms in a synthetic counting task.\n",
|
|
"- PPO-dynamic shows faster progress and more stable learning curves compared to PPO in the synthetic counting task.\n",
|
|
"- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than other algorithms.\n",
|
|
"- PPO and PPO-dynamic have more stable learning curves and converge faster than policy gradient.\n",
|
|
"\n",
|
|
"Conclusions:\n",
|
|
"- PPO is a better optimization method for sequence learning compared to policy gradient.\n",
|
|
"- PPO-dynamic further improves the optimization process by dynamically adjusting hyperparameters.\n",
|
|
"- PPO can be used as a new optimization method for GAN-based sequence learning for better performance.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(chat_test_response[\"choices\"][0][\"message\"][\"content\"])\n"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "dab07e98",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Configure Agent\n",
|
|
"\n",
|
|
"We'll create our agent in this step, including a ```Conversation``` class to support multiple turns with the API, and some Python functions to enable interaction between the ```ChatCompletion``` API and our knowledge base functions."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "77a6fb4f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def chat_completion_request(messages, functions=None, model=GPT_MODEL):\n",
"    headers = {\n",
"        \"Content-Type\": \"application/json\",\n",
"        \"Authorization\": \"Bearer \" + openai.api_key,\n",
"    }\n",
"    json_data = {\"model\": model, \"messages\": messages}\n",
"    if functions is not None:\n",
"        json_data.update({\"functions\": functions})\n",
"    try:\n",
"        response = requests.post(\n",
"            \"https://api.openai.com/v1/chat/completions\",\n",
"            headers=headers,\n",
"            json=json_data,\n",
"        )\n",
"        return response\n",
"    except Exception as e:\n",
"        print(\"Unable to generate ChatCompletion response\")\n",
"        print(f\"Exception: {e}\")\n",
"        return e\n"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "73f7672d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"class Conversation:\n",
"    def __init__(self):\n",
"        self.conversation_history = []\n",
"\n",
"    def add_message(self, role, content):\n",
"        message = {\"role\": role, \"content\": content}\n",
"        self.conversation_history.append(message)\n",
"\n",
"    def display_conversation(self, detailed=False):\n",
"        role_to_color = {\n",
"            \"system\": \"red\",\n",
"            \"user\": \"green\",\n",
"            \"assistant\": \"blue\",\n",
"            \"function\": \"magenta\",\n",
"        }\n",
"        for message in self.conversation_history:\n",
"            print(\n",
"                colored(\n",
"                    f\"{message['role']}: {message['content']}\\n\\n\",\n",
"                    role_to_color[message[\"role\"]],\n",
"                )\n",
"            )"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "978b7877",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Define the specifications for our get_articles and read_article_and_summarize functions\n",
"arxiv_functions = [\n",
"    {\n",
"        \"name\": \"get_articles\",\n",
"        \"description\": \"\"\"Use this function to get academic papers from arXiv to answer user questions.\"\"\",\n",
"        \"parameters\": {\n",
"            \"type\": \"object\",\n",
"            \"properties\": {\n",
"                \"query\": {\n",
"                    \"type\": \"string\",\n",
"                    \"description\": \"\"\"\n",
"                            User query in JSON. Responses should be summarized and should include the article URL reference\n",
"                            \"\"\",\n",
"                }\n",
"            },\n",
"            \"required\": [\"query\"],\n",
"        },\n",
"    },\n",
"    {\n",
"        \"name\": \"read_article_and_summarize\",\n",
"        \"description\": \"\"\"Use this function to read whole papers and provide a summary for users.\n",
"        You should NEVER call this function before get_articles has been called in the conversation.\"\"\",\n",
"        \"parameters\": {\n",
"            \"type\": \"object\",\n",
"            \"properties\": {\n",
"                \"query\": {\n",
"                    \"type\": \"string\",\n",
"                    \"description\": \"\"\"\n",
"                            Description of the article in plain text based on the user's query\n",
"                            \"\"\",\n",
"                }\n",
"            },\n",
"            \"required\": [\"query\"],\n",
"        },\n",
"    }\n",
"]\n"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "0c88ae15",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def chat_completion_with_function_execution(messages, functions=[None]):\n",
|
|
" \"\"\"This function makes a ChatCompletion API call with the option of adding functions\"\"\"\n",
|
|
" response = chat_completion_request(messages, functions)\n",
|
|
" full_message = response.json()[\"choices\"][0]\n",
|
|
" if full_message[\"finish_reason\"] == \"function_call\":\n",
|
|
" print(f\"Function generation requested, calling function\")\n",
|
|
" return call_arxiv_function(messages, full_message)\n",
|
|
" else:\n",
|
|
" print(f\"Function not required, responding to user\")\n",
|
|
" return response.json()\n",
|
|
"\n",
|
|
"\n",
|
|
"def call_arxiv_function(messages, full_message):\n",
|
|
" \"\"\"Function calling function which executes function calls when the model believes it is necessary.\n",
|
|
" Currently extended by adding clauses to this if statement.\"\"\"\n",
|
|
"\n",
|
|
" if full_message[\"message\"][\"function_call\"][\"name\"] == \"get_articles\":\n",
|
|
" try:\n",
|
|
" parsed_output = json.loads(\n",
|
|
" full_message[\"message\"][\"function_call\"][\"arguments\"]\n",
|
|
" )\n",
|
|
" print(\"Getting search results\")\n",
|
|
" results = get_articles(parsed_output[\"query\"])\n",
|
|
" except Exception as e:\n",
|
|
" print(parsed_output)\n",
|
|
" print(f\"Function execution failed\")\n",
|
|
" print(f\"Error message: {e}\")\n",
|
|
" messages.append(\n",
|
|
" {\n",
|
|
" \"role\": \"function\",\n",
|
|
" \"name\": full_message[\"message\"][\"function_call\"][\"name\"],\n",
|
|
" \"content\": str(results),\n",
|
|
" }\n",
|
|
" )\n",
|
|
" try:\n",
|
|
" print(\"Got search results, summarizing content\")\n",
|
|
" response = chat_completion_request(messages)\n",
|
|
" return response.json()\n",
|
|
" except Exception as e:\n",
|
|
" print(type(e))\n",
|
|
" raise Exception(\"Function chat request failed\")\n",
|
|
"\n",
|
|
" elif (\n",
|
|
" full_message[\"message\"][\"function_call\"][\"name\"] == \"read_article_and_summarize\"\n",
|
|
" ):\n",
|
|
" parsed_output = json.loads(\n",
|
|
" full_message[\"message\"][\"function_call\"][\"arguments\"]\n",
|
|
" )\n",
|
|
" print(\"Finding and reading paper\")\n",
|
|
" summary = summarize_text(parsed_output[\"query\"])\n",
|
|
" return summary\n",
|
|
"\n",
|
|
" else:\n",
|
|
" raise Exception(\"Function does not exist and cannot be called\")\n"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "dd3e7868",
|
|
"metadata": {},
|
|
"source": [
|
|
"## arXiv conversation\n",
|
|
"\n",
|
|
"Let's put this all together by testing our functions out in conversation."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "c39a1d80",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Start with a system message\n",
|
|
"paper_system_message = \"\"\"You are arXivGPT, a helpful assistant pulls academic papers to answer user questions.\n",
|
|
"You summarize the papers clearly so the customer can decide which to read to answer their question.\n",
|
|
"You always provide the article_url and title so the user can understand the name of the paper and click through to access it.\n",
|
|
"Begin!\"\"\"\n",
|
|
"paper_conversation = Conversation()\n",
|
|
"paper_conversation.add_message(\"system\", paper_system_message)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "253fd0f7",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Function generation requested, calling function\n",
|
|
"Finding and reading paper\n",
|
|
"Chunking text from paper\n",
|
|
"Summarizing each chunk of text\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [00:06<00:00, 2.65it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Summarizing into overall summary\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/markdown": [
|
|
"Core Argument:\n",
|
|
"- The paper focuses on the theoretical analysis of the PPO-Clip algorithm in the context of deep reinforcement learning.\n",
|
|
"- The authors propose two core ideas: reinterpreting PPO-Clip from the perspective of hinge loss and introducing a two-step policy improvement scheme.\n",
|
|
"- The paper establishes the global convergence of PPO-Clip and characterizes its convergence rate.\n",
|
|
"\n",
|
|
"Evidence:\n",
|
|
"- The paper addresses the challenges posed by the clipping mechanism and neural function approximation.\n",
|
|
"- The authors provide theoretical proofs, lemmas, and mathematical analysis to support their arguments.\n",
|
|
"- The paper presents empirical experiments on various reinforcement learning benchmark tasks to validate the effectiveness of PPO-Clip.\n",
|
|
"\n",
|
|
"Conclusions:\n",
|
|
"- The paper offers theoretical insights into the performance of PPO-Clip and provides a framework for analyzing its convergence properties.\n",
|
|
"- PPO-Clip is shown to have a global convergence rate of O(1/sqrt(T)), where T is the number of iterations.\n",
|
|
"- The hinge loss reinterpretation of PPO-Clip allows for variants with comparable empirical performance.\n",
|
|
"- The paper contributes to a better understanding of PPO-Clip in the reinforcement learning community."
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.Markdown object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Add a user message\n",
|
|
"paper_conversation.add_message(\"user\", \"Hi, how does PPO reinforcement learning work?\")\n",
|
|
"chat_response = chat_completion_with_function_execution(\n",
|
|
" paper_conversation.conversation_history, functions=arxiv_functions\n",
|
|
")\n",
|
|
"assistant_message = chat_response[\"choices\"][0][\"message\"][\"content\"]\n",
|
|
"paper_conversation.add_message(\"assistant\", assistant_message)\n",
|
|
"display(Markdown(assistant_message))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "3ca3e18a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Function generation requested, calling function\n",
|
|
"Finding and reading paper\n",
|
|
"Chunking text from paper\n",
|
|
"Summarizing each chunk of text\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.08s/it]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Summarizing into overall summary\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/markdown": [
|
|
"Core Argument:\n",
|
|
"- The paper discusses the use of proximal policy optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.\n",
|
|
"- The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, which is commonly used in text generation tasks.\n",
|
|
"- They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n",
|
|
"\n",
|
|
"Evidence:\n",
|
|
"- The authors derive the constraints for PPO-dynamic and provide the pseudo code for both PPO and PPO-dynamic.\n",
|
|
"- They compare the performance of PPO-dynamic with other algorithms, including REINFORCE, MIXER, and SeqGAN, on a synthetic counting task and a chit-chat chatbot task using the OpenSubtitles dataset.\n",
|
|
"- In the synthetic counting task, PPO-dynamic achieves a high precision score comparable to REINFORCE and MIXER, with a faster learning curve compared to PPO.\n",
|
|
"- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than REINFORCE and PPO, with a more stable and faster learning curve than policy gradient.\n",
|
|
"\n",
|
|
"Conclusions:\n",
|
|
"- The results suggest that PPO is a better optimization method for sequence learning compared to policy gradient.\n",
|
|
"- PPO-dynamic further improves the optimization process by dynamically adjusting the hyperparameters.\n",
|
|
"- The authors conclude that PPO can be used as a new optimization method for GAN-based sequence learning for better performance."
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.Markdown object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Add another user message to induce our system to use the second tool\n",
|
|
"paper_conversation.add_message(\n",
|
|
" \"user\",\n",
|
|
" \"Can you read the PPO sequence generation paper for me and give me a summary\",\n",
|
|
")\n",
|
|
"updated_response = chat_completion_with_function_execution(\n",
|
|
" paper_conversation.conversation_history, functions=arxiv_functions\n",
|
|
")\n",
|
|
"display(Markdown(updated_response[\"choices\"][0][\"message\"][\"content\"]))\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "tua_test",
|
|
"language": "python",
|
|
"name": "tua_test"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|