From 354e2c28cb5bc1de21714ec3c61c80a612cac3be Mon Sep 17 00:00:00 2001 From: simonpfish Date: Thu, 15 Jun 2023 11:49:19 -0700 Subject: [PATCH 01/24] add search augmentation notebook --- examples/search_augmentation.ipynb | 472 +++++++++++++++++++++++++++++ 1 file changed, 472 insertions(+) create mode 100644 examples/search_augmentation.ipynb diff --git a/examples/search_augmentation.ipynb b/examples/search_augmentation.ipynb new file mode 100644 index 00000000..03796919 --- /dev/null +++ b/examples/search_augmentation.ipynb @@ -0,0 +1,472 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Search Augmentation\n", + "\n", + "### with Multiple Query Generation and Semantic Reranking\n", + "\n", + "Searching for information can be challenging. We can leverage the completions API and embeddings to help us sift through the noise. In this notebook, we will use the completions API to generate search queries given a user's question, and then rerank the results using semantic similarity to a hypothetical answer.\n", + "\n", + "We can break down this process into three steps:\n", + "\n", + "**1. Search**\n", + "\n", + "- User asks a question\n", + "- Model generates a list of queries\n", + "- Search queries are executed in parallel\n", + "\n", + "**2. Re-rank**\n", + "\n", + "- Model generates an ideal answer by hallucination\n", + "- Search results are ranked based on semantic similarity to the ideal answer\n", + "\n", + "**3. Answer**\n", + "\n", + "- Given the top search results, the model attempts to answer the user question, including references and links.\n", + "\n", + "Let's dive into it! We will use recent news articles as an example domain to search over.\n", + "\n", + "## Setup\n", + "\n", + "Once you have your News API and OpenAI API keys, you can set them as environment variables in your `.env` file in the same directory as this notebook. 
The `.env` file should look like this:\n", + "\n", + "```\n", + "NEWS_API_KEY=your_api_key\n", + "OPENAI_API_KEY=your_openai_api_key\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Dependencies\n", + "import openai\n", + "from tqdm import tqdm\n", + "import os\n", + "import dotenv\n", + "import requests\n", + "import json\n", + "from datetime import date, timedelta\n", + "from numpy import dot\n", + "from IPython import display\n", + "\n", + "\n", + "# Load environment variables\n", + "dotenv.load_dotenv()\n", + "\n", + "news_api_key = os.getenv(\"NEWS_API_KEY\")\n", + "\n", + "\n", + "# Helper functions\n", + "def json_gpt(prompt):\n", + " completion = openai.ChatCompletion.create(\n", + " model=\"gpt-3.5-turbo\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"Output only valid JSON\"},\n", + " {\"role\": \"user\", \"content\": prompt},\n", + " ],\n", + " temperature=1,\n", + " )\n", + "\n", + " text = completion.choices[0].message.content\n", + " parsed = json.loads(text)\n", + "\n", + " return parsed\n", + "\n", + "\n", + "def embedding(input):\n", + " response = openai.Embedding.create(\n", + " model=\"text-embedding-ada-002\", input=input)\n", + " return [data.embedding for data in response.data]\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
Search\n", + "\n", + "Let's first generate a set of queries given a user question.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['San Francisco recent events',\n", + " 'news in San Francisco',\n", + " 'San Francisco happenings',\n", + " 'recent developments in San Francisco',\n", + " 'Happening in San Francisco today']" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# User asks a question\n", + "USER_QUESTION = \"What has happened recently in San Francisco?\"\n", + "\n", + "# Model generates a list of queries\n", + "PROMPT = f\"\"\"\n", + "Generate an array of search queries that are relevant to this question. \n", + "Use a variation of related keywords for the queries, trying to be as general as possible.\n", + "Include as many queries as you can think of, including and excluding terms. \n", + "For example, include queries like ['keyword_1 keyword_2', 'keyword_1', 'keyword_2']. \n", + "Be creative. 
The more queries you include, the more likely you are to find relevant results.\n", + "\n", + "User question: {USER_QUESTION}\n", + "\n", + "Format: {{\"queries\": [\"query_1\", \"query_2\", \"query_3\"]}}\n", + "Maximum 5 queries\n", + "\"\"\"\n", + "\n", + "queries = json_gpt(PROMPT)[\"queries\"]\n", + "\n", + "queries\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The queries look good, let's run the search!\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5/5 [00:02<00:00, 2.30it/s]\n" + ] + } + ], + "source": [ + "def search_news(query: str):\n", + " # get date 1 week ago\n", + " one_week_ago = (date.today() - timedelta(weeks=1)).strftime(\"%Y-%m-%d\")\n", + "\n", + " response = requests.get(\n", + " \"https://newsapi.org/v2/everything\",\n", + " params={\n", + " \"q\": query,\n", + " \"apiKey\": news_api_key,\n", + " \"pageSize\": 50,\n", + " \"sortBy\": \"relevancy\",\n", + " \"from\": one_week_ago,\n", + " },\n", + " )\n", + "\n", + " return response.json()\n", + "\n", + "\n", + "articles = []\n", + "\n", + "for query in tqdm(queries):\n", + " articles = articles + search_news(query)[\"articles\"]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of articles: 185\n", + "Top 5 articles: \n", + "\n", + "Title: Samsung confirms Galaxy Z Flip 5, Fold 5 launch details\n", + "Description: Samsung has announced when and where the next Galaxy Unpacked event will take place. It's here where the company will unveil its next foldables.\n", + "Content: