Merge branch 'main' into katia/graphdb

jhills20 committed 3 months ago (commit 26c6fa64d8)

@ -6,14 +6,14 @@
The magic of large language models is that by being trained to minimize this prediction error over vast quantities of text, the models end up learning concepts useful for these predictions. For example, they learn:
* how to spell
* how grammar works
* how to paraphrase
* how to answer questions
* how to hold a conversation
* how to write in many languages
* how to code
* etc.
- how to spell
- how grammar works
- how to paraphrase
- how to answer questions
- how to hold a conversation
- how to write in many languages
- how to code
- etc.
They do this by “reading” a large amount of existing text and learning how words tend to appear in context with other words. The model then uses what it has learned to predict the next most likely word that might appear in response to a user request, and each subsequent word after that.
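To make this concrete, here is a minimal sketch of next-word prediction via the API (assuming the `openai` v1 Python SDK with `OPENAI_API_KEY` set; the prompt text is just an illustration). We ask the model for a single token and inspect the top candidates it weighed:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

response = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt="The first person to walk on the moon was",
    max_tokens=1,  # predict just the single next token
    logprobs=5,    # also return the 5 most likely candidates with log probabilities
)

print(response.choices[0].text)                      # the chosen next token
print(response.choices[0].logprobs.top_logprobs[0])  # the competing candidates
```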
@ -25,12 +25,12 @@ Of all the inputs to a large language model, by far the most influential is the
Large language models can be prompted to produce output in a few ways:
* **Instruction**: Tell the model what you want
* **Completion**: Induce the model to complete the beginning of what you want
* **Scenario**: Give the model a situation to play out
* **Demonstration**: Show the model what you want, with either:
* A few examples in the prompt
* Many hundreds or thousands of examples in a fine-tuning training dataset
- **Instruction**: Tell the model what you want
- **Completion**: Induce the model to complete the beginning of what you want
- **Scenario**: Give the model a situation to play out
- **Demonstration**: Show the model what you want, with either:
- A few examples in the prompt
- Many hundreds or thousands of examples in a fine-tuning training dataset
An example of each is shown below.
@ -77,6 +77,7 @@ Output:
Giving the model a scenario to follow or role to play out can be helpful for complex queries or when seeking imaginative responses. When using a hypothetical prompt, you set up a situation, problem, or story, and then ask the model to respond as if it were a character in that scenario or an expert on the topic.
Example scenario prompt:
```text
Your role is to extract the name of the author from any given text
@ -141,11 +142,11 @@ Large language models aren't only great at text - they can be great at code too.
GPT-4 powers [numerous innovative products][OpenAI Customer Stories], including:
* [GitHub Copilot] (autocompletes code in Visual Studio and other IDEs)
* [Replit](https://replit.com/) (can complete, explain, edit and generate code)
* [Cursor](https://cursor.sh/) (build software faster in an editor designed for pair-programming with AI)
- [GitHub Copilot] (autocompletes code in Visual Studio and other IDEs)
- [Replit](https://replit.com/) (can complete, explain, edit and generate code)
- [Cursor](https://cursor.sh/) (build software faster in an editor designed for pair-programming with AI)
GPT-4 is more advanced than previous models like `text-davinci-002`. But, to get the best out of GPT-4 for coding tasks, it's still important to give clear and specific instructions. As a result, designing good prompts can take more care.
GPT-4 is more advanced than previous models like `gpt-3.5-turbo-instruct`. But, to get the best out of GPT-4 for coding tasks, it's still important to give clear and specific instructions. As a result, designing good prompts can take more care.
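As a minimal sketch of what "clear and specific" can look like for a coding task (assuming the `openai` v1 Python SDK; the function spec is only an illustration), the prompt pins down the signature, the behavior, and the output format rather than just asking for "a slug function":

```python
from openai import OpenAI

client = OpenAI()

response = client.chat.completions.create(
    model="gpt-4",
    messages=[
        {
            "role": "user",
            "content": (
                "Write a Python function `slugify(title: str) -> str` that lowercases the "
                "input, replaces runs of non-alphanumeric characters with a single hyphen, "
                "and strips leading/trailing hyphens. Include a docstring and return only "
                "the code."
            ),
        }
    ],
    temperature=0,  # keep output stable for a well-specified task
)

print(response.choices[0].message.content)
```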
### More prompt advice
@ -153,12 +154,10 @@ For more prompt examples, visit [OpenAI Examples][OpenAI Examples].
In general, the input prompt is the best lever for improving model outputs. You can try tricks like:
* **Be more specific.** E.g., if you want the output to be a comma-separated list, ask it to return a comma-separated list. If you want it to say "I don't know" when it doesn't know the answer, tell it 'Say "I don't know" if you do not know the answer.' The more specific your instructions, the better the model can respond.
* **Provide context.** Help the model understand the bigger picture of your request. This could be background information, examples/demonstrations of what you want, or an explanation of the purpose of your task.
* **Ask the model to answer as if it were an expert.** Explicitly asking the model to produce high-quality output, or output as if it were written by an expert, can induce the model to give higher-quality answers that it thinks an expert would write. Phrases like "Explain in detail" or "Describe step-by-step" can be effective.
* **Prompt the model to write down the series of steps explaining its reasoning.** If understanding the 'why' behind an answer is important, prompt the model to include its reasoning. This can be done by simply adding a line like "[Let's think step by step](https://arxiv.org/abs/2205.11916)" before each answer.
- **Be more specific.** E.g., if you want the output to be a comma-separated list, ask it to return a comma-separated list. If you want it to say "I don't know" when it doesn't know the answer, tell it 'Say "I don't know" if you do not know the answer.' The more specific your instructions, the better the model can respond.
- **Provide context.** Help the model understand the bigger picture of your request. This could be background information, examples/demonstrations of what you want, or an explanation of the purpose of your task.
- **Ask the model to answer as if it were an expert.** Explicitly asking the model to produce high-quality output, or output as if it were written by an expert, can induce the model to give higher-quality answers that it thinks an expert would write. Phrases like "Explain in detail" or "Describe step-by-step" can be effective.
- **Prompt the model to write down the series of steps explaining its reasoning.** If understanding the 'why' behind an answer is important, prompt the model to include its reasoning. This can be done by simply adding a line like "[Let's think step by step](https://arxiv.org/abs/2205.11916)" before each answer.
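As a concrete illustration of the first and last tips above, here is a minimal sketch (assuming the `openai` v1 Python SDK; the question is only an illustration) that asks for a specific output format, tells the model how to handle uncertainty, and requests step-by-step reasoning:

```python
from openai import OpenAI

client = OpenAI()

prompt = (
    "How many prime numbers are there between 10 and 30? "
    "Let's think step by step, then give the final answer on its own line "
    'as "Answer: <number>". Say "I don\'t know" if you do not know the answer.'
)

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": prompt}],
)

print(response.choices[0].message.content)
```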
[Fine Tuning Docs]: https://platform.openai.com/docs/guides/fine-tuning
[OpenAI Customer Stories]: https://openai.com/customer-stories

@ -14,25 +14,25 @@ If you were asked to multiply 13 by 17, would the answer pop immediately into yo
Similarly, if you give GPT-3 a task that's too complex to do in the time it takes to calculate its next token, it may confabulate an incorrect guess. Yet, akin to humans, that doesn't necessarily mean the model is incapable of the task. With some time and space to reason things out, the model still may be able to answer reliably.
As an example, if you ask `text-davinci-002` the following math problem about juggling balls, it answers incorrectly:
As an example, if you ask `gpt-3.5-turbo-instruct` the following math problem about juggling balls, it answers incorrectly:
```text-davinci-002
```gpt-3.5-turbo-instruct
Q: A juggler has 16 balls. Half of the balls are golf balls and half of the golf balls are blue. How many blue golf balls are there?
A:
```
```text-davinci-002
```gpt-3.5-turbo-instruct
There are 8 blue golf balls.
```
Does this mean that GPT-3 cannot do simple math problems? No; in fact, it turns out that by prompting the model with `Let's think step by step`, the model solves the problem reliably:
```text-davinci-002
```gpt-3.5-turbo-instruct
Q: A juggler has 16 balls. Half of the balls are golf balls and half of the golf balls are blue. How many blue golf balls are there?
A: Let's think step by step.
```
```text-davinci-002
```gpt-3.5-turbo-instruct
There are 16 balls in total.
Half of the balls are golf balls.
That means that there are 8 golf balls.
@ -64,9 +64,9 @@ The rest of this article shares techniques for improving reliability of large la
One way to give a model more time and space to think is to break tasks into simpler pieces.
As an example, consider a task where we ask the model a multiple-choice question about some text - in this case, a game of Clue. When asked directly, `text-davinci-002` isn't able to put clues 3 & 5 together, and answers incorrectly:
As an example, consider a task where we ask the model a multiple-choice question about some text - in this case, a game of Clue. When asked directly, `gpt-3.5-turbo-instruct` isn't able to put clues 3 & 5 together, and answers incorrectly:
```text-davinci-002
```gpt-3.5-turbo-instruct
Use the following clues to answer the following multiple-choice question.
Clues:
@ -84,7 +84,7 @@ Question: Was Colonel Mustard in the observatory with the candlestick?
Solution:
```
```text-davinci-002
```gpt-3.5-turbo-instruct
(c) Unknown; there is not enough information to determine whether Colonel Mustard was in the observatory with the candlestick
```
@ -96,7 +96,7 @@ However, instead of asking for the answer directly, we can split the task into t
- Second, combine the relevant clues to reason out the answer to the question
- Third, write the final answer: either (a), (b), or (c)
```text-davinci-002
```gpt-3.5-turbo-instruct
Use the following clues to answer the following multiple-choice question, using the following procedure:
(1) First, go through the clues one by one and consider whether the clue is potentially relevant
(2) Second, combine the relevant clues to reason out the answer to the question
@ -118,7 +118,7 @@ Solution:
(1) First, go through the clues one by one and consider whether the clue is potentially relevant:
```
```text-davinci-002
```gpt-3.5-turbo-instruct
Clue 1: Miss Scarlett was the only person in the lounge. This clue is not relevant.
Clue 2: The person with the pipe was in the kitchen. This clue is not relevant.
Clue 3: Colonel Mustard was the only person in the observatory. This clue is relevant.
@ -136,9 +136,9 @@ By giving the model more time and space to think, and guiding it along a reasoni
Another benefit of splitting complex instructions into smaller subtasks is that it can help keep the model focused on each subtask.
For example, if we ask `text-davinci-002` to summarize a text in its original language, the model can lapse back into English:
For example, if we ask `gpt-3.5-turbo-instruct` to summarize a text in its original language, the model can lapse back into English:
```text-davinci-002
```gpt-3.5-turbo-instruct
Summarize the text using the original language of the text. The summary should be one sentence long.
Text:
@ -149,13 +149,13 @@ La estadística (la forma femenina del término alemán Statistik, derivado a su
Summary:
```
```text-davinci-002
```gpt-3.5-turbo-instruct
The text explains that statistics is a science that studies the variability, collection, organization, analysis, interpretation, and presentation of data, as well as the random process that generates them following the laws of probability.
```
However, if we first ask the model to identify the language of the text, and then summarize the text, it becomes more reliable:
```text-davinci-002
```gpt-3.5-turbo-instruct
First, identify the language of the text. Second, summarize the text using the original language of the text. The summary should be one sentence long.
Text:
@ -166,7 +166,7 @@ La estadística (la forma femenina del término alemán Statistik, derivado a su
Language:
```
```text-davinci-002
```gpt-3.5-turbo-instruct
Spanish
La estadística es una ciencia que estudia la variabilidad, colección, organización, análisis, interpretación, y presentación de los datos, así como el proceso aleatorio que los genera siguiendo las leyes de la probabilidad.
@ -203,7 +203,7 @@ To learn more, read the [full paper](https://arxiv.org/abs/2205.11916).
If you apply this technique to your own tasks, don't be afraid to experiment with customizing the instruction. `Let's think step by step` is rather generic, so you may find better performance with instructions that hew to a stricter format customized to your use case. For example, you can try more structured variants like `First, think step by step about why X might be true. Second, think step by step about why Y might be true. Third, think step by step about whether X or Y makes more sense.`. And you can even give the model an example format to help keep it on track, e.g.:
```text-davinci-002
```gpt-3.5-turbo-instruct
Using the IRS guidance below, answer the following questions using this format:
(1) For each criterion, determine whether it is met by the vehicle purchase
- {Criterion} Let's think step by step. {explanation} {yes or no, or if the question does not apply then N/A}.
@ -229,7 +229,7 @@ Solution:
- Does the vehicle have at least four wheels? Let's think step by step.
```
```text-davinci-002
```gpt-3.5-turbo-instruct
The Toyota Prius Prime has four wheels, so the answer is yes.
- Does the vehicle weigh less than 14,000 pounds? Let's think step by step. The Toyota Prius Prime weighs less than 14,000 pounds, so the answer is yes.
- Does the vehicle draw energy from a battery with at least 4 kilowatt hours that may be recharged from an external source? Let's think step by step. The Toyota Prius Prime has a battery with at least 4 kilowatt hours that may be recharged from an external source, so the answer is yes.

@ -8,8 +8,8 @@ Embeddings can be used for semantic search, recommendations, cluster analysis, n
For more information, read OpenAI's blog post announcements:
* [Introducing Text and Code Embeddings (Jan 2022)](https://openai.com/blog/introducing-text-and-code-embeddings/)
* [New and Improved Embedding Model (Dec 2022)](https://openai.com/blog/new-and-improved-embedding-model/)
- [Introducing Text and Code Embeddings (Jan 2022)](https://openai.com/blog/introducing-text-and-code-embeddings/)
- [New and Improved Embedding Model (Dec 2022)](https://openai.com/blog/new-and-improved-embedding-model/)
For comparison with other embedding models, see [Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard)
@ -19,14 +19,14 @@ Embeddings can be used for search either by themselves or as a feature in a larg
The simplest way to use embeddings for search is as follows:
* Before the search (precompute):
* Split your text corpus into chunks smaller than the token limit (8,191 tokens for `text-embedding-ada-002`)
* Embed each chunk of text
* Store those embeddings in your own database or in a vector search provider like [Pinecone](https://www.pinecone.io), [Weaviate](https://weaviate.io) or [Qdrant](https://qdrant.tech)
* At the time of the search (live compute):
* Embed the search query
* Find the closest embeddings in your database
* Return the top results
- Before the search (precompute):
- Split your text corpus into chunks smaller than the token limit (8,191 tokens for `text-embedding-3-small`)
- Embed each chunk of text
- Store those embeddings in your own database or in a vector search provider like [Pinecone](https://www.pinecone.io), [Weaviate](https://weaviate.io) or [Qdrant](https://qdrant.tech)
- At the time of the search (live compute):
- Embed the search query
- Find the closest embeddings in your database
- Return the top results
An example of how to use embeddings for search is shown in [Semantic_text_search_using_embeddings.ipynb](../examples/Semantic_text_search_using_embeddings.ipynb).
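For a minimal in-line sketch of the same precompute / live-compute flow (assuming the `openai` v1 Python SDK, with an in-memory list standing in for a vector database; the corpus and query strings are illustrative):

```python
import numpy as np
from openai import OpenAI

client = OpenAI()

def embed(text: str) -> np.ndarray:
    response = client.embeddings.create(model="text-embedding-3-small", input=text)
    return np.array(response.data[0].embedding)

# Before the search (precompute): embed each chunk of the corpus
chunks = [
    "The cat sat on the mat.",
    "GDP grew 3% in the last quarter.",
    "Photosynthesis converts light into chemical energy.",
]
chunk_embeddings = [embed(chunk) for chunk in chunks]

# At the time of the search (live compute): embed the query, rank by cosine similarity
query = embed("How do plants make food?")
scores = [
    float(e @ query / (np.linalg.norm(e) * np.linalg.norm(query)))
    for e in chunk_embeddings
]
for score, chunk in sorted(zip(scores, chunks), reverse=True):
    print(f"{score:.3f}  {chunk}")
```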

@ -53,6 +53,11 @@ prakul:
website: "https://www.linkedin.com/in/prakulagarwal"
avatar: "https://media.licdn.com/dms/image/D5603AQEUug83qKgRBg/profile-displayphoto-shrink_800_800/0/1675384960197?e=1706140800&v=beta&t=qxkDbBr-Bk2ASpcwbR5JVPD6yS-vzmIwNHAa8ApyDq4"
gaborcselle:
name: "Gabor Cselle"
website: "https://www.linkedin.com/in/gaborcselle"
avatar: "https://avatars.githubusercontent.com/u/54096?s=96&v=4"
nghiauet:
name: "Nghia Pham"
website: "https://www.linkedin.com/in/deptraicucmanh/"

@ -191,8 +191,10 @@
],
"source": [
"from openai import OpenAI\n",
"import os\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"client = OpenAI()\n",
"\n",
"assistant = client.beta.assistants.create(\n",
" name=\"Math Tutor\",\n",
@ -586,7 +588,7 @@
"\n",
"MATH_ASSISTANT_ID = assistant.id # or a hard-coded ID like \"asst-...\"\n",
"\n",
"client = OpenAI()\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"def submit_message(assistant_id, thread, user_message):\n",
" client.beta.threads.messages.create(\n",

@ -1,189 +0,0 @@
from typing import List, Tuple, Union
from smokey import Smokey
import openai

def get_candidates(
prompt: str,
stop: List[str],
temperature: float,
priming_prefix: str,
engine: str,
n: int = 5,
) -> List[str]:
"""
Generate N candidate completions based on the prompt, generated with a specific temperature.
:param prompt: The prompt to start the conversation with.
:param stop: A list of tokens that indicate the end of the generation.
:param temperature: The temperature of the generation.
:param priming_prefix: The prefix to use for the priming.
:param engine: The engine to use for the generation.
:param n: The number of completions to generate.
:return: A list of completions.
"""
response = openai.Completion.create(
engine=engine,
prompt=prompt,
temperature=temperature,
max_tokens=150,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
stop=stop,
n=n,
)
responses = [priming_prefix + choice.text for choice in response.choices]
return responses

def rindex(lst: List, value: str) -> int:
"""
Return the index of the last occurrence of a value in a list.
:param lst: The list to search in.
:param value: The value to search for.
:return: The index of the last occurrence of the value.
"""
try:
return len(lst) - lst[::-1].index(value) - 1
except ValueError:
raise ValueError(f"Answer start token `{value}` not found in the eval template")

def eval_candidate(
candidate_answer: str,
original_instruction: str,
eval_template: str,
answer_start_token: str,
engine: str,
) -> float:
"""
Evaluate a candidate answer by calculating the average log probability
of the original instruction, given the candidate answer with a specific
evaluation template, aimed at reconstructing the original instruction.
:param candidate_answer: The candidate answer to evaluate.
:param original_instruction: The original instruction.
:param eval_template: The template to use for the evaluation.
:param answer_start_token: The token to use to indicate the start of the answer.
:param engine: The engine to use for the evaluation.
:return: The evaluation of the candidate answer.
"""
response = openai.Completion.create(
engine=engine,
prompt=eval_template.format(candidate_answer, original_instruction),
temperature=0,
max_tokens=0,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
logprobs=1,
echo=True,
)
answer_start = rindex(
response["choices"][0]["logprobs"]["tokens"], answer_start_token
)
logprobs = response["choices"][0]["logprobs"]["token_logprobs"][answer_start + 1 :]
return sum(logprobs) / len(logprobs)

def backtranslation(
prompt_template: str,
additional_info: str,
instruction: str,
eval_template: str,
priming_prefix: str = "SELECT",
stop1: List[str] = ["#", ";"],
answer_start_token: str = "--",
n: int = 5,
temperature: float = 0.5,
return_all_results: bool = False,
engine: str = "davinci-codex",
) -> Union[str, List[Tuple[str, float]]]:
"""
Generate a number of SQL queries given a natural language instruction,
and pick the best one based on the average log probability of explaining the
candidate SQL query with the exact original instruction, when prompted for
a natural language explanation of the candidate SQL query.
:param prompt_template: The template to use for the prompt to generate SQL.
:param additional_info: Additional information to include in the prompt
(SQL Tables, and their properties).
:param instruction: The instruction in natural language.
:param eval_template: The template to use for the evaluation.
:param priming_prefix: The prefix to use for the priming of the SQL query.
:param stop1: A list of tokens that indicate the end of the generation.
:param answer_start_token: The token to use to indicate the start of the
natural answer.
:param n: The number of candidates to generate.
:param temperature: The temperature of the generation.
:param return_all_results: Whether to return all results or just the best one.
:param engine: The engine to use for the generation and evaluation.
:return: The best SQL query, or a list of all scored generated SQL queries.
"""
prompt_template = prompt_template.format(
additional_info, instruction, priming_prefix
)
candidates = []
responses = get_candidates(
prompt_template, stop1, temperature, priming_prefix, engine=engine, n=n
)
for i in range(n):
quality = eval_candidate(
responses[i],
instruction,
eval_template,
answer_start_token,
engine=engine,
)
candidates.append((responses[i], quality))
candidates.sort(key=lambda x: x[1], reverse=True)
if return_all_results:
return candidates
return candidates[0][0]

def main(
nl_query: str = "Return the name of each department that had more than 10 employees in June 2021",
eval_template: str = "{};\n-- Explanation of the above query in human readable format\n-- {}",
table_definitions: str = "# Employee(id, name, department_id)\n# Department(id, name, address)\n# Salary_Payments(id, employee_id, amount, date)\n",
prompt_template: str = "### Postgres SQL tables, with their properties:\n#\n{}#\n### {}\n{}",
n: int = 3,
temperature: float = 0.3,
engine: str = "davinci-codex",
):
"""
Generate a number of SQL queries given a natural language instruction,
and pick the best one based on the highest backtranslation score.
:param nl_query: The natural language query.
:param eval_template: The template to use for the evaluation.
:param table_definitions: The definitions of the tables used in the query.
:param prompt_template: The template to use for the prompt to generate SQL.
:param n: The number of candidates to generate.
:param temperature: The temperature of the generation.
:param engine: The engine to use for the generation and evaluation.
:return: The best SQL query, or a list of all scored generated SQL queries.
"""
result = backtranslation(
prompt_template,
table_definitions,
nl_query,
eval_template,
priming_prefix="SELECT",
temperature=temperature,
n=n,
engine=engine,
)
print(result)

if __name__ == "__main__":
Smokey(main)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -58,8 +58,7 @@
"import os\n",
"from ast import literal_eval\n",
"\n",
"#openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"client = OpenAI()\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"COMPLETIONS_MODEL = \"gpt-3.5-turbo\"\n",
"\n",
"# This path leads to a file with data and precomputed embeddings\n",
@ -399,7 +398,7 @@
" .sample(transactions_per_cluster, random_state=42)\n",
" .values\n",
" )\n",
" completion = client.chat.completions.create(\n",
" response = client.chat.completions.create(\n",
" model=COMPLETIONS_MODEL,\n",
" # We'll include a prompt to instruct the model what sort of description we're looking for\n",
" messages=[\n",
@ -413,7 +412,7 @@
" frequency_penalty=0,\n",
" presence_penalty=0,\n",
" )\n",
" print(completion.choices[0].message.content.replace(\"\\n\", \"\"))\n",
" print(response.choices[0].message.content.replace(\"\\n\", \"\"))\n",
" print(\"\\n\")\n",
"\n",
" sample_cluster_rows = embedding_df[embedding_df.Cluster == i].sample(transactions_per_cluster, random_state=42)\n",

@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -110,15 +110,15 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of .py files: 57\n",
"Total number of functions extracted: 118\n"
"Total number of .py files: 51\n",
"Total number of functions extracted: 97\n"
]
}
],
@ -138,12 +138,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have our content, we can pass the data to the text-embedding-ada-002 endpoint to get back our vector embeddings."
"Now that we have our content, we can pass the data to the `text-embedding-3-small` model and get back our vector embeddings."
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -179,35 +179,35 @@
" <td>def _console_log_level():\\n if openai.log i...</td>\n",
" <td>_console_log_level</td>\n",
" <td>openai/util.py</td>\n",
" <td>[0.033906757831573486, -0.00418944051489234, 0...</td>\n",
" <td>[0.005937571171671152, 0.05450401455163956, 0....</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>def log_debug(message, **params):\\n msg = l...</td>\n",
" <td>log_debug</td>\n",
" <td>openai/util.py</td>\n",
" <td>[-0.004059609025716782, 0.004895503632724285, ...</td>\n",
" <td>[0.017557814717292786, 0.05647840350866318, -0...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>def log_info(message, **params):\\n msg = lo...</td>\n",
" <td>log_info</td>\n",
" <td>openai/util.py</td>\n",
" <td>[0.0048639848828315735, 0.0033139237202703953,...</td>\n",
" <td>[0.022524144500494003, 0.06219055876135826, -0...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>def log_warn(message, **params):\\n msg = lo...</td>\n",
" <td>log_warn</td>\n",
" <td>openai/util.py</td>\n",
" <td>[0.0024026145692914724, -0.010721310041844845,...</td>\n",
" <td>[0.030524108558893204, 0.0667714849114418, -0....</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>def logfmt(props):\\n def fmt(key, val):\\n ...</td>\n",
" <td>logfmt</td>\n",
" <td>openai/util.py</td>\n",
" <td>[0.01664826273918152, 0.01730910874903202, 0.0...</td>\n",
" <td>[0.05337328091263771, 0.03697286546230316, -0....</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
@ -222,14 +222,14 @@
"4 def logfmt(props):\\n def fmt(key, val):\\n ... logfmt \n",
"\n",
" filepath code_embedding \n",
"0 openai/util.py [0.033906757831573486, -0.00418944051489234, 0... \n",
"1 openai/util.py [-0.004059609025716782, 0.004895503632724285, ... \n",
"2 openai/util.py [0.0048639848828315735, 0.0033139237202703953,... \n",
"3 openai/util.py [0.0024026145692914724, -0.010721310041844845,... \n",
"4 openai/util.py [0.01664826273918152, 0.01730910874903202, 0.0... "
"0 openai/util.py [0.005937571171671152, 0.05450401455163956, 0.... \n",
"1 openai/util.py [0.017557814717292786, 0.05647840350866318, -0... \n",
"2 openai/util.py [0.022524144500494003, 0.06219055876135826, -0... \n",
"3 openai/util.py [0.030524108558893204, 0.0667714849114418, -0.... \n",
"4 openai/util.py [0.05337328091263771, 0.03697286546230316, -0.... "
]
},
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@ -238,7 +238,7 @@
"from utils.embeddings_utils import get_embedding\n",
"\n",
"df = pd.DataFrame(all_funcs)\n",
"df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, model='text-embedding-ada-002'))\n",
"df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, model='text-embedding-3-small'))\n",
"df['filepath'] = df['filepath'].map(lambda x: Path(x).relative_to(code_root))\n",
"df.to_csv(\"data/code_search_openai-python.csv\", index=False)\n",
"df.head()"
@ -255,21 +255,21 @@
"\n",
"We define a search_functions method that takes our data that contains our embeddings, a query string, and some other configuration options. The process of searching our database works like such:\n",
"\n",
"1. We first embed our query string (code_query) with text-embedding-ada-002. The reasoning here is that a query string like 'a function that reverses a string' and a function like 'def reverse(string): return string[::-1]' will be very similar when embedded.\n",
"1. We first embed our query string (code_query) with `text-embedding-3-small`. The reasoning here is that a query string like 'a function that reverses a string' and a function like 'def reverse(string): return string[::-1]' will be very similar when embedded.\n",
"2. We then calculate the cosine similarity between our query string embedding and all data points in our database. This gives a distance between each point and our query.\n",
"3. We finally sort all of our data points by their distance to our query string and return the number of results requested in the function parameters. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from utils.embeddings_utils import cosine_similarity\n",
"\n",
"def search_functions(df, code_query, n=3, pprint=True, n_lines=7):\n",
" embedding = get_embedding(code_query, model='text-embedding-ada-002')\n",
" embedding = get_embedding(code_query, model='text-embedding-3-small')\n",
" df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding))\n",
"\n",
" res = df.sort_values('similarities', ascending=False).head(n)\n",
@ -285,14 +285,14 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"openai/validators.py:format_inferrer_validator score=0.751\n",
"openai/validators.py:format_inferrer_validator score=0.453\n",
"def format_inferrer_validator(df):\n",
" \"\"\"\n",
" This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.\n",
@ -301,16 +301,7 @@
" ft_type = infer_task_type(df)\n",
" immediate_msg = None\n",
"----------------------------------------------------------------------\n",
"openai/validators.py:get_validators score=0.748\n",
"def get_validators():\n",
" return [\n",
" num_examples_validator,\n",
" lambda x: necessary_column_validator(x, \"prompt\"),\n",
" lambda x: necessary_column_validator(x, \"completion\"),\n",
" additional_column_validator,\n",
" non_empty_field_validator,\n",
"----------------------------------------------------------------------\n",
"openai/validators.py:infer_task_type score=0.739\n",
"openai/validators.py:infer_task_type score=0.37\n",
"def infer_task_type(df):\n",
" \"\"\"\n",
" Infer the likely fine-tuning task type from the data\n",
@ -318,6 +309,15 @@
" CLASSIFICATION_THRESHOLD = 3 # min_average instances of each class\n",
" if sum(df.prompt.str.len()) == 0:\n",
" return \"open-ended generation\"\n",
"----------------------------------------------------------------------\n",
"openai/validators.py:apply_validators score=0.369\n",
"def apply_validators(\n",
" df,\n",
" fname,\n",
" remediation,\n",
" validators,\n",
" auto_accept,\n",
" write_out_file_func,\n",
"----------------------------------------------------------------------\n"
]
}
@ -328,14 +328,14 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"openai/validators.py:get_common_xfix score=0.794\n",
"openai/validators.py:get_common_xfix score=0.487\n",
"def get_common_xfix(series, xfix=\"suffix\"):\n",
" \"\"\"\n",
" Finds the longest common suffix or prefix of all the values in a series\n",
@ -347,7 +347,7 @@
" if xfix == \"suffix\"\n",
" else series.str[: len(common_xfix) + 1]\n",
"----------------------------------------------------------------------\n",
"openai/validators.py:common_completion_suffix_validator score=0.778\n",
"openai/validators.py:common_completion_suffix_validator score=0.449\n",
"def common_completion_suffix_validator(df):\n",
" \"\"\"\n",
" This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.\n",
@ -368,14 +368,14 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"openai/cli.py:tools_register score=0.78\n",
"openai/cli.py:tools_register score=0.391\n",
"def tools_register(parser):\n",
" subparsers = parser.add_subparsers(\n",
" title=\"Tools\", help=\"Convenience client side tools\"\n",
@ -421,7 +421,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.11.5"
},
"orig_nbformat": 4
},

@ -25,13 +25,14 @@
"source": [
"from IPython.display import display, Image\n",
"from openai import OpenAI\n",
"import os\n",
"import pandas as pd\n",
"import json\n",
"import io\n",
"from PIL import Image\n",
"import requests\n",
"\n",
"client = OpenAI()\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"#Lets import some helper functions for assistants from https://cookbook.openai.com/examples/assistants_api_overview_python\n",
"def show_json(obj):\n",

@ -78,7 +78,7 @@
"source": [
"# input parameters\n",
"embedding_cache_path = \"data/snli_embedding_cache.pkl\" # embeddings will be saved/loaded here\n",
"default_embedding_engine = \"babbage-similarity\" # text-embedding-ada-002 is recommended\n",
"default_embedding_engine = \"text-embedding-3-small\"\n",
"num_pairs_to_embed = 1000 # 1000 is arbitrary\n",
"local_dataset_path = \"data/snli_1.0_train_2k.csv\" # download from: https://nlp.stanford.edu/projects/snli/\n",
"\n",

@ -38,9 +38,12 @@
"import mwclient # for downloading example Wikipedia articles\n",
"import mwparserfromhell # for splitting Wikipedia articles into sections\n",
"import openai # for generating embeddings\n",
"import os # for environment variables\n",
"import pandas as pd # for DataFrames to store article sections and embeddings\n",
"import re # for cutting <ref> links out of Wikipedia articles\n",
"import tiktoken # for counting tokens\n"
"import tiktoken # for counting tokens\n",
"\n",
"client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -603,8 +606,7 @@
}
],
"source": [
"# calculate embeddings\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\" # OpenAI's best embeddings as of Apr 2023\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"BATCH_SIZE = 1000 # you can submit up to 2048 embedding inputs per request\n",
"\n",
"embeddings = []\n",
@ -612,10 +614,10 @@
" batch_end = batch_start + BATCH_SIZE\n",
" batch = wikipedia_strings[batch_start:batch_end]\n",
" print(f\"Batch {batch_start} to {batch_end-1}\")\n",
" response = openai.Embedding.create(model=EMBEDDING_MODEL, input=batch)\n",
" for i, be in enumerate(response[\"data\"]):\n",
" assert i == be[\"index\"] # double check embeddings are in same order as input\n",
" batch_embeddings = [e[\"embedding\"] for e in response[\"data\"]]\n",
" response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)\n",
" for i, be in enumerate(response.data):\n",
" assert i == be.index # double check embeddings are in same order as input\n",
" batch_embeddings = [e.embedding for e in response.data]\n",
" embeddings.extend(batch_embeddings)\n",
"\n",
"df = pd.DataFrame({\"text\": wikipedia_strings, \"embedding\": embeddings})\n"

@ -9,7 +9,7 @@
"\n",
"OpenAI's embedding models cannot embed text that exceeds a maximum length. The maximum length varies by model, and is measured by _tokens_, not string length. If you are unfamiliar with tokenization, check out [How to count tokens with tiktoken](How_to_count_tokens_with_tiktoken.ipynb).\n",
"\n",
"This notebook shows how to handle texts that are longer than a model's maximum context length. We'll demonstrate using embeddings from `text-embedding-ada-002`, but the same ideas can be applied to other models and tasks. To learn more about embeddings, check out the [OpenAI Embeddings Guide](https://beta.openai.com/docs/guides/embeddings).\n"
"This notebook shows how to handle texts that are longer than a model's maximum context length. We'll demonstrate using embeddings from `text-embedding-3-small`, but the same ideas can be applied to other models and tasks. To learn more about embeddings, check out the [OpenAI Embeddings Guide](https://beta.openai.com/docs/guides/embeddings).\n"
]
},
{
@ -24,41 +24,44 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"import os\n",
"import openai\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_not_exception_type\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"EMBEDDING_MODEL = 'text-embedding-ada-002'\n",
"EMBEDDING_MODEL = 'text-embedding-3-small'\n",
"EMBEDDING_CTX_LENGTH = 8191\n",
"EMBEDDING_ENCODING = 'cl100k_base'\n",
"\n",
"# let's make sure to not retry on an invalid request, because that is what we want to demonstrate\n",
"@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(openai.InvalidRequestError))\n",
"@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(openai.BadRequestError))\n",
"def get_embedding(text_or_tokens, model=EMBEDDING_MODEL):\n",
" return openai.Embedding.create(input=text_or_tokens, model=model)[\"data\"][0][\"embedding\"]"
" return client.embeddings.create(input=text_or_tokens, model=model).data[0].embedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `text-embedding-ada-002` model has a context length of 8191 tokens with the `cl100k_base` encoding, and we can see that going over that limit causes an error."
"The `text-embedding-3-small` model has a context length of 8191 tokens with the `cl100k_base` encoding, and we can see that going over that limit causes an error."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"This model's maximum context length is 8191 tokens, however you requested 10001 tokens (10001 in your prompt; 0 for the completion). Please reduce your prompt; or completion length.\n"
"Error code: 400 - {'error': {'message': \"This model's maximum context length is 8192 tokens, however you requested 10001 tokens (10001 in your prompt; 0 for the completion). Please reduce your prompt; or completion length.\", 'type': 'invalid_request_error', 'param': None, 'code': None}}\n"
]
}
],
@ -66,7 +69,7 @@
"long_text = 'AGI ' * 5000\n",
"try:\n",
" get_embedding(long_text)\n",
"except openai.InvalidRequestError as e:\n",
"except openai.BadRequestError as e:\n",
" print(e)"
]
},
@ -267,7 +270,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.11.5"
},
"vscode": {
"interpreter": {

@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -53,6 +53,8 @@
"import openai\n",
"import tiktoken\n",
"\n",
"client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"# Extract the raw text from each PDF using textract\n",
"text = textract.process('data/fia_f1_power_unit_financial_regulations_issue_1_-_2022-08-16.pdf', method='pdfminer').decode('utf-8')\n",
"clean_text = text.replace(\" \", \" \").replace(\"\\n\", \"; \").replace(';',' ')"
@ -67,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 20,
"metadata": {},
"outputs": [
{
@ -96,13 +98,13 @@
"template_prompt=f'''Extract key pieces of information from this regulation document.\n",
"If a particular piece of information is not present, output \\\"Not specified\\\".\n",
"When you extract a key piece of information, include the closest page number.\n",
"Use the following format:\\n0. Who is the author\\n1. What is the amount of the \"Power Unit Cost Cap\" in USD, GBP and EUR\\n2. What is the value of External Manufacturing Costs in USD\\n3. What is the Capital Expenditure Limit in USD\\n\\nDocument: \\\"\\\"\\\"{document}\\\"\\\"\\\"\\n\\n0. Who is the author: Tom Anderson (Page 1)\\n1.'''\n",
"Use the following format:\\n0. Who is the author\\n1. What is the amount of the \"Power Unit Cost Cap\" in USD, GBP and EUR\\n2. What is the value of External Manufacturing Costs in USD\\n3. What is the Capital Expenditure Limit in USD\\n\\nDocument: \\\"\\\"\\\"<document>\\\"\\\"\\\"\\n\\n0. Who is the author: Tom Anderson (Page 1)\\n1.'''\n",
"print(template_prompt)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@ -127,19 +129,23 @@
" i = j\n",
"\n",
"def extract_chunk(document,template_prompt):\n",
" \n",
" prompt=template_prompt.replace('<document>',document)\n",
" prompt = template_prompt.replace('<document>',document)\n",
"\n",
" response = openai.Completion.create(\n",
" model='text-davinci-003', \n",
" prompt=prompt,\n",
" temperature=0,\n",
" max_tokens=1500,\n",
" top_p=1,\n",
" frequency_penalty=0,\n",
" presence_penalty=0\n",
" )\n",
" return \"1.\" + response['choices'][0]['text']"
" messages = [\n",
" {\"role\": \"system\", \"content\": \"You help extract information from documents.\"},\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ]\n",
"\n",
" response = client.chat.completions.create(\n",
" model='gpt-4', \n",
" messages=messages,\n",
" temperature=0,\n",
" max_tokens=1500,\n",
" top_p=1,\n",
" frequency_penalty=0,\n",
" presence_penalty=0\n",
" )\n",
" return \"1.\" + response.choices[0].message.content"
]
},
{
@ -228,7 +234,7 @@
"template_prompt=f'''Extract key pieces of information from this regulation document.\n",
"If a particular piece of information is not present, output \\\"Not specified\\\".\n",
"When you extract a key piece of information, include the closest page number.\n",
"Use the following format:\\n0. Who is the author\\n1. How is a Minor Overspend Breach calculated\\n2. How is a Major Overspend Breach calculated\\n3. Which years do these financial regulations apply to\\n\\nDocument: \\\"\\\"\\\"{document}\\\"\\\"\\\"\\n\\n0. Who is the author: Tom Anderson (Page 1)\\n1.'''\n",
"Use the following format:\\n0. Who is the author\\n1. How is a Minor Overspend Breach calculated\\n2. How is a Major Overspend Breach calculated\\n3. Which years do these financial regulations apply to\\n\\nDocument: \\\"\\\"\\\"<document>\\\"\\\"\\\"\\n\\n0. Who is the author: Tom Anderson (Page 1)\\n1.'''\n",
"print(template_prompt)"
]
},
@ -306,9 +312,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "embed_retrieve",
"display_name": "Python 3",
"language": "python",
"name": "embed_retrieve"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -320,12 +326,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"vscode": {
"interpreter": {
"hash": "5997d090960a54cd76552f75eca12ec3b416cf9d01a1a5af08ae48cf90878791"
}
"version": "3.11.3"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

@ -21,7 +21,7 @@
"metadata": {},
"source": [
"For context, from the function calling notebook above:\n",
"> `functions` is an optional parameter in the Chat Completion API which can be used to provide function specifications. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications. Note that the API will not actually execute any function calls. It is up to developers to execute function calls using model outputs."
"> `tools` is an optional parameter in the Chat Completion API which can be used to provide function specifications. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications. Note that the API will not actually execute any function calls. It is up to developers to execute function calls using model outputs."
]
},
{
@ -33,7 +33,7 @@
"invocations, and incorrect invocations).\\\n",
"Before fine tuning for function calling, it's best to begin with:\n",
"- Improvements to the function definitions. Make them more clear, and more distinct from one another.\n",
"- Experiment with prompt engineering: often a more detailed prompt can help the model call the correct function.\\\n",
"- Experiment with prompt engineering: often a more detailed prompt can help the model call the correct function.\n",
"\n",
"*If* the steps above fail to improve function calling to a satisfactory level, then you can try fine tuning for function calling."
]
@ -76,17 +76,16 @@
"outputs": [],
"source": [
"# !pip install tenacity\n",
"# !pip insta openai\n",
"# !pip install typing\n"
"# !pip install openai\n",
"# !pip install typing"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import numpy as np\n",
"import json\n",
"import os\n",
@ -95,7 +94,8 @@
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from typing import Any, Dict, List, Generator\n",
"import ast\n",
"client = OpenAI()\n"
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -114,17 +114,17 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def get_chat_completion(\n",
" messages: list[dict[str, str]],\n",
" model: str = \"gpt-4\",\n",
" model: str = \"gpt-3.5-turbo\",\n",
" max_tokens=500,\n",
" temperature=1.0,\n",
" stop=None,\n",
" functions=None,\n",
" tools=None,\n",
") -> str:\n",
" params = {\n",
" 'model': model,\n",
@ -132,9 +132,8 @@
" 'max_tokens': max_tokens,\n",
" 'temperature': temperature,\n",
" 'stop': stop,\n",
" 'tools': tools,\n",
" }\n",
" if functions:\n",
" params['tools'] = functions\n",
"\n",
" completion = client.chat.completions.create(**params)\n",
" return completion.choices[0].message\n"
@ -158,7 +157,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -176,7 +175,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -468,7 +467,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -507,7 +506,7 @@
" messages = []\n",
" messages.append({\"role\": \"system\", \"content\": DRONE_SYSTEM_PROMPT})\n",
" messages.append({\"role\": \"user\", \"content\": prompt})\n",
" completion = get_chat_completion(model=\"gpt-3.5-turbo\",messages=messages,functions=function_list)\n",
" completion = get_chat_completion(model=\"gpt-3.5-turbo\",messages=messages,tools=function_list)\n",
" print(prompt)\n",
" print(completion.tool_calls[0].function,'\\n')\n"
]
@ -532,26 +531,6 @@
" \"Change drone's paint job color\"]\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_6rwNKuIgO7l8Izw5qoFcuSuk', function=Function(arguments='{}', name='reject_request'), type='function')])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"completion\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
@ -567,17 +546,17 @@
"\n",
"\n",
"Initiate live-streaming on social media\n",
"Function(arguments='{\\n\"mode\": \"video\"\\n}', name='control_camera') \n",
"Function(arguments='{\\n\"mode\": \"video\",\\n\"duration\": 0\\n}', name='control_camera') \n",
"\n",
"\n",
"\n",
"Scan environment for heat signatures\n",
"Function(arguments='{\\n \"mode\": \"photo\"\\n}', name='control_camera') \n",
"Function(arguments='{ \"mode\": \"photo\" }', name='control_camera') \n",
"\n",
"\n",
"\n",
"Enable stealth mode\n",
"Function(arguments='{\\n \"pattern\": \"solid\",\\n \"color\": \"black\"\\n}', name='configure_led_display') \n",
"Function(arguments='{\\n \"mode\": \"off\"\\n}', name='set_drone_lighting') \n",
"\n",
"\n",
"\n",
@ -594,7 +573,7 @@
" messages = []\n",
" messages.append({\"role\": \"system\", \"content\": DRONE_SYSTEM_PROMPT})\n",
" messages.append({\"role\": \"user\", \"content\": prompt})\n",
" completion = get_chat_completion(model=\"gpt-3.5-turbo\",messages=messages,functions=function_list)\n",
" completion = get_chat_completion(model=\"gpt-3.5-turbo\",messages=messages,tools=function_list)\n",
" print(prompt)\n",
" try:\n",
" print(completion.tool_calls[0].function,'\\n')\n",
@ -609,7 +588,7 @@
"metadata": {},
"source": [
"Now we run into some problems.\n",
"The model here should reject all of these requests, as they are impossible given the functions, however instead the model calls functions that are somewhat related to the request, but incorrect. The model sets the camera to video when asked to begin 'live streaming to social media', and changes the LED's to blue when asked to 'change the paint color'...\\\n",
"The model here should reject all of these requests, as they are impossible given the functions, however instead the model calls functions that are somewhat related to the request, but incorrect. The model sets the camera to video when asked to begin 'live streaming to social media', and changes the LED's to blue when asked to 'change the paint color'...\n",
"<br>\n",
"In this simple case, more prompt engineering may resolve some of these issues, but for the purpose of this example we will demonstrate how fine tuning can be used to improve performance. Additionally, while this case is relatively straightforward, as the number of and complexity of the functions increases, fine tuning becomes more and more impactful."
]
@ -843,7 +822,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@ -860,7 +839,7 @@
" \"arguments\": arguments\n",
" }\n",
" messages = [{\"role\": \"user\", \"content\": INVOCATION_FILLER_PROMPT.format(invocation=input_object,function=function)}]\n",
" input_object = get_chat_completion(model='gpt-4',messages=messages, max_tokens = 200,temperature=.1).content\n",
" input_object = get_chat_completion(model='gpt-4', messages=messages, max_tokens = 200, temperature=.1).content\n",
" else:\n",
" input_object = {\n",
" \"name\": func_name,\n",
@ -879,7 +858,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -904,7 +883,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 23,
"metadata": {},
"outputs": [
{
@ -928,7 +907,7 @@
"\u001b[34m12.3% complete\u001b[0m\n",
"{'name': 'control_drone_movement', 'arguments': {'direction': 'right', 'distance': 10}}\n",
"\u001b[34m14.0% complete\u001b[0m\n",
"{'name': 'control_drone_movement', 'arguments': {'direction': 'up', 'distance': 10}}\n",
"{'name': 'control_drone_movement', 'arguments': {'direction': 'up', 'distance': 20}}\n",
"\u001b[34m15.8% complete\u001b[0m\n",
"{'name': 'control_drone_movement', 'arguments': {'direction': 'down', 'distance': 10}}\n",
"\u001b[34m17.5% complete\u001b[0m\n",
@ -946,7 +925,7 @@
"\u001b[34m28.1% complete\u001b[0m\n",
"{'name': 'control_camera', 'arguments': {'mode': 'panorama', 'duration': 0}}\n",
"\u001b[34m29.8% complete\u001b[0m\n",
"{'name': 'control_gimbal', 'arguments': {'tilt': 45, 'pan': 90}}\n",
"{'name': 'control_gimbal', 'arguments': {'tilt': 45, 'pan': 30}}\n",
"\u001b[34m31.6% complete\u001b[0m\n",
"{'name': 'set_drone_lighting', 'arguments': {'mode': 'on'}}\n",
"\u001b[34m33.3% complete\u001b[0m\n",
@ -1029,7 +1008,7 @@
}
],
"source": [
"training_examples_unformatted = create_commands(input_objects)\n"
"training_examples_unformatted = create_commands(input_objects)"
]
},
{
@ -1041,7 +1020,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@ -1069,7 +1048,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@ -1092,7 +1071,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@ -1115,7 +1094,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
@ -1124,7 +1103,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
@ -1171,7 +1150,6 @@
" file_id = file.id\n",
" print(file_id)\n",
" ft = client.fine_tuning.jobs.create(\n",
" # model=\"gpt-4-0613\",\n",
" model=\"gpt-3.5-turbo\",\n",
" training_file=file_id,\n",
")\n"
@ -1224,7 +1202,7 @@
" messages = []\n",
" messages.append({\"role\": \"system\", \"content\": DRONE_SYSTEM_PROMPT})\n",
" messages.append({\"role\": \"user\", \"content\": eval_question})\n",
" completion = get_chat_completion(model=\"ft:gpt-3.5-turbo-0613:openai-internal::8DloQKS2\",messages=messages,functions=function_list)\n",
" completion = get_chat_completion(model=\"ft:gpt-3.5-turbo-0613:openai-internal::8DloQKS2\",messages=messages,tools=function_list)\n",
" print(eval_question)\n",
" print(completion.tool_calls[0].function.name,'\\n')\n"
]
@ -1267,7 +1245,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.5"
}
},
"nbformat": 4,

@ -58,12 +58,11 @@
"outputs": [],
"source": [
"import json\n",
"import openai\n",
"from openai import OpenAI\n",
"import os\n",
"import requests\n",
"\n",
"client = OpenAI()\n"
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -228,7 +227,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@ -252,33 +251,36 @@
" {\"role\": \"user\", \"content\": user_input}\n",
"],\n",
" temperature=0,\n",
" functions=[\n",
" tools=[\n",
" {\n",
" \"name\": \"call_google_places_api\",\n",
" \"description\": \"This function calls the Google Places API to find the top places of a specified type near a specific location. It can be used when a user expresses a need (e.g., feeling hungry or tired) or wants to find a certain type of place (e.g., restaurant or hotel).\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"place_type\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The type of place to search for.\"\n",
" \"type\": \"function\",\n",
" \"function\" : {\n",
" \"name\": \"call_google_places_api\",\n",
" \"description\": \"This function calls the Google Places API to find the top places of a specified type near a specific location. It can be used when a user expresses a need (e.g., feeling hungry or tired) or wants to find a certain type of place (e.g., restaurant or hotel).\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"place_type\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The type of place to search for.\"\n",
" }\n",
" }\n",
" },\n",
" \"result\": {\n",
" \"type\": \"array\",\n",
" \"items\": {\n",
" \"type\": \"string\"\n",
" }\n",
" }\n",
" },\n",
" \"result\": {\n",
" \"type\": \"array\",\n",
" \"items\": {\n",
" \"type\": \"string\"\n",
" }\n",
" }\n",
" }\n",
" ],\n",
" )\n",
"\n",
" print(response.choices[0].message.function_call)\n",
" print(response.choices[0].message.tool_calls)\n",
"\n",
" if response.choices[0].finish_reason=='function_call':\n",
" function_call = response.choices[0].message.function_call\n",
" if response.choices[0].finish_reason=='tool_calls':\n",
" function_call = response.choices[0].message.tool_calls[0].function\n",
" if function_call.name == \"call_google_places_api\":\n",
" place_type = json.loads(function_call.arguments)[\"place_type\"]\n",
" places = call_google_places_api(user_id, place_type, food_preference)\n",
@ -302,15 +304,15 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FunctionCall(arguments='{\\n \"place_type\": \"restaurant\"\\n}', name='call_google_places_api')\n",
"I couldn't find any places of interest nearby.\n"
"[ChatCompletionMessageToolCall(id='call_Q1mXIi7D6GhobfE4tkruX7nB', function=Function(arguments='{\\n \"place_type\": \"restaurant\"\\n}', name='call_google_places_api'), type='function')]\n",
"Here are some places you might be interested in: Sotto Mare is a restaurant located at 552 Green Street. It has a rating of 4.6 based on 3765 user reviews. Mona Lisa Restaurant is a restaurant located at 353 Columbus Avenue #3907. It has a rating of 4.4 based on 1888 user reviews.\n"
]
}
],
@ -338,7 +340,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.5"
},
"orig_nbformat": 4
},

File diff suppressed because one or more lines are too long

@ -27,7 +27,7 @@
"import os\n",
"import requests\n",
"\n",
"client = OpenAI()"
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{

@ -26,32 +26,30 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from utils.embeddings_utils import get_embedding\n"
"from utils.embeddings_utils import get_embedding"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# embedding model parameters\n",
"embedding_model = \"text-embedding-ada-002\"\n",
"embedding_encoding = \"cl100k_base\" # this the encoding for text-embedding-ada-002\n",
"max_tokens = 8000 # the maximum for text-embedding-ada-002 is 8191\n"
"embedding_model = \"text-embedding-3-small\"\n",
"embedding_encoding = \"cl100k_base\"\n",
"max_tokens = 8000 # the maximum for text-embedding-3-small is 8191"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 17,
"metadata": {},
"outputs": [
{
@ -127,7 +125,7 @@
"1 Title: Arrived in pieces; Content: Not pleased... "
]
},
"execution_count": 8,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@ -141,12 +139,12 @@
"df[\"combined\"] = (\n",
" \"Title: \" + df.Summary.str.strip() + \"; Content: \" + df.Text.str.strip()\n",
")\n",
"df.head(2)\n"
"df.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 18,
"metadata": {},
"outputs": [
{
@ -155,7 +153,7 @@
"1000"
]
},
"execution_count": 9,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
@ -171,7 +169,7 @@
"# omit reviews that are too long to embed\n",
"df[\"n_tokens\"] = df.combined.apply(lambda x: len(encoding.encode(x)))\n",
"df = df[df.n_tokens <= max_tokens].tail(top_n)\n",
"len(df)\n"
"len(df)"
]
},
{
@ -184,7 +182,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@ -192,7 +190,16 @@
"\n",
"# This may take a few minutes\n",
"df[\"embedding\"] = df.combined.apply(lambda x: get_embedding(x, model=embedding_model))\n",
"df.to_csv(\"data/fine_food_reviews_with_embeddings_1k.csv\")\n"
"df.to_csv(\"data/fine_food_reviews_with_embeddings_1k.csv\")"
]
}
],
@ -212,7 +219,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {

@ -662,7 +662,7 @@
"import wget\n",
"\n",
"# Here is a URL to a zip archive containing the transcribed podcasts\n",
"# Note that this data has already been split into chunks and embeddings from OpenAI's text-embedding-ada-002 embedding model are included\n",
"# Note that this data has already been split into chunks and embeddings from OpenAI's `text-embedding-3-small` embedding model are included\n",
"content_url = 'https://cdn.openai.com/API/examples/data/sysk_podcast_transcripts_embedded.json.zip'\n",
"\n",
"# Download the file (it is ~541 MB so this will take some time)\n",

@ -39,94 +39,47 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: scipy in /opt/homebrew/lib/python3.11/site-packages (1.10.1)\n",
"Requirement already satisfied: numpy<1.27.0,>=1.19.5 in /opt/homebrew/lib/python3.11/site-packages (from scipy) (1.24.3)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: tenacity in /opt/homebrew/lib/python3.11/site-packages (8.2.2)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: tiktoken in /opt/homebrew/lib/python3.11/site-packages (0.4.0)\n",
"Requirement already satisfied: regex>=2022.1.18 in /opt/homebrew/lib/python3.11/site-packages (from tiktoken) (2023.6.3)\n",
"Requirement already satisfied: requests>=2.26.0 in /opt/homebrew/lib/python3.11/site-packages (from tiktoken) (2.30.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.1.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (1.25.11)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2023.5.7)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: termcolor in /opt/homebrew/lib/python3.11/site-packages (2.3.0)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: openai in /opt/homebrew/lib/python3.11/site-packages (0.27.6)\n",
"Requirement already satisfied: requests>=2.20 in /opt/homebrew/lib/python3.11/site-packages (from openai) (2.30.0)\n",
"Requirement already satisfied: tqdm in /opt/homebrew/lib/python3.11/site-packages (from openai) (4.65.0)\n",
"Requirement already satisfied: aiohttp in /opt/homebrew/lib/python3.11/site-packages (from openai) (3.8.4)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (3.1.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (1.25.11)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests>=2.20->openai) (2023.5.7)\n",
"Requirement already satisfied: attrs>=17.3.0 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (23.1.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (4.0.2)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (1.3.3)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /opt/homebrew/lib/python3.11/site-packages (from aiohttp->openai) (1.3.1)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: requests in /opt/homebrew/lib/python3.11/site-packages (2.30.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/homebrew/lib/python3.11/site-packages (from requests) (3.1.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/lib/python3.11/site-packages (from requests) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/lib/python3.11/site-packages (from requests) (1.25.11)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/lib/python3.11/site-packages (from requests) (2023.5.7)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: arxiv in /opt/homebrew/lib/python3.11/site-packages (1.4.7)\n",
"Requirement already satisfied: feedparser in /opt/homebrew/lib/python3.11/site-packages (from arxiv) (6.0.10)\n",
"Requirement already satisfied: sgmllib3k in /opt/homebrew/lib/python3.11/site-packages (from feedparser->arxiv) (1.0.0)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[0mRequirement already satisfied: pandas in /opt/homebrew/lib/python3.11/site-packages (2.0.1)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: numpy>=1.21.0 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (1.24.3)\n",
"Requirement already satisfied: six>=1.5 in /opt/homebrew/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: PyPDF2 in /opt/homebrew/lib/python3.11/site-packages (3.0.1)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: tqdm in /opt/homebrew/lib/python3.11/site-packages (4.65.0)\n",
"\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Skipping /opt/homebrew/lib/python3.11/site-packages/PyYAML-6.0-py3.11.egg-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n",
"\u001b[0m"
"Requirement already satisfied: scipy in /usr/local/lib/python3.11/site-packages (1.12.0)\n",
"Requirement already satisfied: numpy<1.29.0,>=1.22.4 in /usr/local/lib/python3.11/site-packages (from scipy) (1.26.3)\n",
"Requirement already satisfied: tenacity in /usr/local/lib/python3.11/site-packages (8.2.3)\n",
"Requirement already satisfied: tiktoken==0.3.3 in /usr/local/lib/python3.11/site-packages (0.3.3)\n",
"Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.11/site-packages (from tiktoken==0.3.3) (2023.12.25)\n",
"Requirement already satisfied: requests>=2.26.0 in /usr/local/lib/python3.11/site-packages (from tiktoken==0.3.3) (2.31.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken==0.3.3) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken==0.3.3) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken==0.3.3) (2.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken==0.3.3) (2023.11.17)\n",
"Requirement already satisfied: termcolor in /usr/local/lib/python3.11/site-packages (2.4.0)\n",
"Requirement already satisfied: openai in /usr/local/lib/python3.11/site-packages (1.10.0)\n",
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.11/site-packages (from openai) (4.2.0)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.11/site-packages (from openai) (1.9.0)\n",
"Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.11/site-packages (from openai) (0.26.0)\n",
"Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.11/site-packages (from openai) (2.5.3)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.11/site-packages (from openai) (1.3.0)\n",
"Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.11/site-packages (from openai) (4.66.1)\n",
"Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.11/site-packages (from openai) (4.9.0)\n",
"Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.11/site-packages (from anyio<5,>=3.5.0->openai) (3.6)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.11/site-packages (from httpx<1,>=0.23.0->openai) (2023.11.17)\n",
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/site-packages (from httpx<1,>=0.23.0->openai) (1.0.2)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/site-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai) (0.14.0)\n",
"Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.11/site-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n",
"Requirement already satisfied: pydantic-core==2.14.6 in /usr/local/lib/python3.11/site-packages (from pydantic<3,>=1.9.0->openai) (2.14.6)\n",
"Requirement already satisfied: arxiv in /usr/local/lib/python3.11/site-packages (2.1.0)\n",
"Requirement already satisfied: feedparser==6.0.10 in /usr/local/lib/python3.11/site-packages (from arxiv) (6.0.10)\n",
"Requirement already satisfied: requests==2.31.0 in /usr/local/lib/python3.11/site-packages (from arxiv) (2.31.0)\n",
"Requirement already satisfied: sgmllib3k in /usr/local/lib/python3.11/site-packages (from feedparser==6.0.10->arxiv) (1.0.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/site-packages (from requests==2.31.0->arxiv) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/site-packages (from requests==2.31.0->arxiv) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/site-packages (from requests==2.31.0->arxiv) (2.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/site-packages (from requests==2.31.0->arxiv) (2023.11.17)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.11/site-packages (2.2.0)\n",
"Requirement already satisfied: numpy<2,>=1.23.2 in /usr/local/lib/python3.11/site-packages (from pandas) (1.26.3)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/site-packages (from pandas) (2023.3.post1)\n",
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/site-packages (from pandas) (2023.4)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Requirement already satisfied: PyPDF2 in /usr/local/lib/python3.11/site-packages (3.0.1)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.11/site-packages (4.66.1)\n"
]
}
],
@ -136,7 +89,6 @@
"!pip install tiktoken==0.3.3\n",
"!pip install termcolor \n",
"!pip install openai\n",
"!pip install requests\n",
"!pip install arxiv\n",
"!pip install pandas\n",
"!pip install PyPDF2\n",
@ -151,26 +103,25 @@
"outputs": [],
"source": [
"import os\n",
"\n",
"import arxiv\n",
"import ast\n",
"import concurrent\n",
"from csv import writer\n",
"from IPython.display import display, Markdown, Latex\n",
"import json\n",
"import openai\n",
"import os\n",
"import pandas as pd\n",
"import tiktoken\n",
"from csv import writer\n",
"from IPython.display import display, Markdown, Latex\n",
"from openai import OpenAI\n",
"from PyPDF2 import PdfReader\n",
"import requests\n",
"from scipy import spatial\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"import tiktoken\n",
"from tqdm import tqdm\n",
"from termcolor import colored\n",
"\n",
"GPT_MODEL = \"gpt-3.5-turbo-0613\"\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n"
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"client = OpenAI()"
]
},
{
@ -191,7 +142,15 @@
"execution_count": 3,
"id": "2de5d32d",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Directory './data/papers' already exists.\n"
]
}
],
"source": [
"directory = './data/papers'\n",
"\n",
@ -207,7 +166,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "ae5cb7a1",
"metadata": {},
"outputs": [],
@ -223,26 +182,30 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "57217b9d",
"metadata": {},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def embedding_request(text):\n",
" response = openai.Embedding.create(input=text, model=EMBEDDING_MODEL)\n",
" response = client.embeddings.create(input=text, model=EMBEDDING_MODEL)\n",
" return response\n",
"\n",
"\n",
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def get_articles(query, library=paper_dir_filepath, top_k=5):\n",
" \"\"\"This function gets the top_k articles based on a user's query, sorted by relevance.\n",
" It also downloads the files and stores them in arxiv_library.csv to be retrieved by the read_article_and_summarize.\n",
" \"\"\"\n",
" client = arxiv.Client()\n",
" search = arxiv.Search(\n",
" query=query, max_results=top_k, sort_by=arxiv.SortCriterion.Relevance\n",
" query = \"quantum\",\n",
" max_results = 10,\n",
" sort_by = arxiv.SortCriterion.SubmittedDate\n",
" )\n",
" result_list = []\n",
" for result in search.results():\n",
" for result in client.results(search):\n",
" result_dict = {}\n",
" result_dict.update({\"title\": result.title})\n",
" result_dict.update({\"summary\": result.summary})\n",
@ -257,7 +220,7 @@
" file_reference = [\n",
" result.title,\n",
" result.download_pdf(data_dir),\n",
" response[\"data\"][0][\"embedding\"],\n",
" response.data[0].embedding,\n",
" ]\n",
"\n",
" # Write to file\n",
@ -270,20 +233,20 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "dda02bdb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': 'Proximal Policy Optimization and its Dynamic Version for Sequence Generation',\n",
" 'summary': 'In sequence generation task, many works use policy gradient for model\\noptimization to tackle the intractable backpropagation issue when maximizing\\nthe non-differentiable evaluation metrics or fooling the discriminator in\\nadversarial learning. In this paper, we replace policy gradient with proximal\\npolicy optimization (PPO), which is a proved more efficient reinforcement\\nlearning algorithm, and propose a dynamic approach for PPO (PPO-dynamic). We\\ndemonstrate the efficacy of PPO and PPO-dynamic on conditional sequence\\ngeneration tasks including synthetic experiment and chit-chat chatbot. The\\nresults show that PPO and PPO-dynamic can beat policy gradient by stability and\\nperformance.',\n",
" 'article_url': 'http://arxiv.org/abs/1808.07982v1',\n",
" 'pdf_url': 'http://arxiv.org/pdf/1808.07982v1'}"
"{'title': 'Entanglement entropy and deconfined criticality: emergent SO(5) symmetry and proper lattice bipartition',\n",
" 'summary': \"We study the R\\\\'enyi entanglement entropy (EE) of the two-dimensional $J$-$Q$\\nmodel, the emblematic quantum spin model of deconfined criticality at the phase\\ntransition between antiferromagnetic and valence-bond-solid ground states.\\nQuantum Monte Carlo simulations with an improved EE scheme reveal critical\\ncorner contributions that scale logarithmically with the system size, with a\\ncoefficient in remarkable agreement with the form expected from a large-$N$\\nconformal field theory with SO($N=5$) symmetry. However, details of the\\nbipartition of the lattice are crucial in order to observe this behavior. If\\nthe subsystem for the reduced density matrix does not properly accommodate\\nvalence-bond fluctuations, logarithmic contributions appear even for\\ncorner-less bipartitions. We here use a $45^\\\\circ$ tilted cut on the square\\nlattice. Beyond supporting an SO($5$) deconfined quantum critical point, our\\nresults for both the regular and tilted cuts demonstrate important microscopic\\naspects of the EE that are not captured by conformal field theory.\",\n",
" 'article_url': 'http://arxiv.org/abs/2401.14396v1',\n",
" 'pdf_url': 'http://arxiv.org/pdf/2401.14396v1'}"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -296,7 +259,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "11675627",
"metadata": {},
"outputs": [],
@ -309,7 +272,7 @@
") -> list[str]:\n",
" \"\"\"Returns a list of strings and relatednesses, sorted from most related to least.\"\"\"\n",
" query_embedding_response = embedding_request(query)\n",
" query_embedding = query_embedding_response[\"data\"][0][\"embedding\"]\n",
" query_embedding = query_embedding_response.data[0].embedding\n",
" strings_and_relatednesses = [\n",
" (row[\"filepath\"], relatedness_fn(query_embedding, row[\"embedding\"]))\n",
" for i, row in df.iterrows()\n",
@ -321,7 +284,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "7211df2c",
"metadata": {},
"outputs": [],
@ -362,10 +325,10 @@
"def extract_chunk(content, template_prompt):\n",
" \"\"\"This function applies a prompt to some input content. In this case it returns a summarized chunk of text\"\"\"\n",
" prompt = template_prompt + content\n",
" response = openai.ChatCompletion.create(\n",
" response = client.chat.completions.create(\n",
" model=GPT_MODEL, messages=[{\"role\": \"user\", \"content\": prompt}], temperature=0\n",
" )\n",
" return response[\"choices\"][0][\"message\"][\"content\"]\n",
" return response.choices[0].message.content\n",
"\n",
"\n",
"def summarize_text(query):\n",
@ -418,7 +381,7 @@
"\n",
" # Final summary\n",
" print(\"Summarizing into overall summary\")\n",
" response = openai.ChatCompletion.create(\n",
" response = client.chat.completions.create(\n",
" model=GPT_MODEL,\n",
" messages=[\n",
" {\n",
@ -437,7 +400,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "898b94d4",
"metadata": {},
"outputs": [
@ -453,7 +416,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.19s/it]\n"
"100%|██████████| 15/15 [00:08<00:00, 1.76it/s]\n"
]
},
{
@ -471,7 +434,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "c715f60d",
"metadata": {},
"outputs": [
@ -479,26 +442,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Core Argument:\n",
"- The paper discusses the use of Proximal Policy Optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.\n",
"- The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, commonly used in text generation tasks.\n",
"- They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n",
"\n",
"Evidence:\n",
"- PPO-dynamic achieves high precision scores comparable to other algorithms in a synthetic counting task.\n",
"- PPO-dynamic shows faster progress and more stable learning curves compared to PPO in the synthetic counting task.\n",
"- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than other algorithms.\n",
"- PPO and PPO-dynamic have more stable learning curves and converge faster than policy gradient.\n",
"\n",
"Conclusions:\n",
"- PPO is a better optimization method for sequence learning compared to policy gradient.\n",
"- PPO-dynamic further improves the optimization process by dynamically adjusting hyperparameters.\n",
"- PPO can be used as a new optimization method for GAN-based sequence learning for better performance.\n"
"The academic paper discusses the unique decomposition of generators of completely positive dynamical semigroups in infinite dimensions. The main result of the paper is that for any separable complex Hilbert space, any trace-class operator B that does not have a purely imaginary trace, and any generator L of a norm-continuous one-parameter semigroup of completely positive maps, there exists a unique bounded operator K and a unique completely positive map Φ such that L=K(·) + (·)K+ Φ. The paper also introduces a modified version of the Choi formalism, which relates completely positive maps to positive semi-definite operators, and characterizes when this correspondence is injective and surjective. The paper concludes by discussing the challenges and questions that arise when generalizing the results to non-separable Hilbert spaces.\n"
]
}
],
"source": [
"print(chat_test_response[\"choices\"][0][\"message\"][\"content\"])\n"
"print(chat_test_response.choices[0].message.content)\n"
]
},
{
@ -514,25 +463,18 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "77a6fb4f",
"metadata": {},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def chat_completion_request(messages, functions=None, model=GPT_MODEL):\n",
" headers = {\n",
" \"Content-Type\": \"application/json\",\n",
" \"Authorization\": \"Bearer \" + openai.api_key,\n",
" }\n",
" json_data = {\"model\": model, \"messages\": messages}\n",
" if functions is not None:\n",
" json_data.update({\"functions\": functions})\n",
" try:\n",
" response = requests.post(\n",
" \"https://api.openai.com/v1/chat/completions\",\n",
" headers=headers,\n",
" json=json_data,\n",
" response = client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" functions=functions,\n",
" )\n",
" return response\n",
" except Exception as e:\n",
@ -543,7 +485,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "73f7672d",
"metadata": {},
"outputs": [],
@ -574,7 +516,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "978b7877",
"metadata": {},
"outputs": [],
@ -619,7 +561,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "0c88ae15",
"metadata": {},
"outputs": [],
@ -627,23 +569,23 @@
"def chat_completion_with_function_execution(messages, functions=[None]):\n",
" \"\"\"This function makes a ChatCompletion API call with the option of adding functions\"\"\"\n",
" response = chat_completion_request(messages, functions)\n",
" full_message = response.json()[\"choices\"][0]\n",
" if full_message[\"finish_reason\"] == \"function_call\":\n",
" full_message = response.choices[0]\n",
" if full_message.finish_reason == \"function_call\":\n",
" print(f\"Function generation requested, calling function\")\n",
" return call_arxiv_function(messages, full_message)\n",
" else:\n",
" print(f\"Function not required, responding to user\")\n",
" return response.json()\n",
" return response\n",
"\n",
"\n",
"def call_arxiv_function(messages, full_message):\n",
" \"\"\"Function calling function which executes function calls when the model believes it is necessary.\n",
" Currently extended by adding clauses to this if statement.\"\"\"\n",
"\n",
" if full_message[\"message\"][\"function_call\"][\"name\"] == \"get_articles\":\n",
" if full_message.message.function_call.name == \"get_articles\":\n",
" try:\n",
" parsed_output = json.loads(\n",
" full_message[\"message\"][\"function_call\"][\"arguments\"]\n",
" full_message.message.function_call.arguments\n",
" )\n",
" print(\"Getting search results\")\n",
" results = get_articles(parsed_output[\"query\"])\n",
@ -654,23 +596,23 @@
" messages.append(\n",
" {\n",
" \"role\": \"function\",\n",
" \"name\": full_message[\"message\"][\"function_call\"][\"name\"],\n",
" \"name\": full_message.message.function_call.name,\n",
" \"content\": str(results),\n",
" }\n",
" )\n",
" try:\n",
" print(\"Got search results, summarizing content\")\n",
" response = chat_completion_request(messages)\n",
" return response.json()\n",
" return response\n",
" except Exception as e:\n",
" print(type(e))\n",
" raise Exception(\"Function chat request failed\")\n",
"\n",
" elif (\n",
" full_message[\"message\"][\"function_call\"][\"name\"] == \"read_article_and_summarize\"\n",
" full_message.message.function_call.name == \"read_article_and_summarize\"\n",
" ):\n",
" parsed_output = json.loads(\n",
" full_message[\"message\"][\"function_call\"][\"arguments\"]\n",
" full_message.message.function_call.arguments\n",
" )\n",
" print(\"Finding and reading paper\")\n",
" summary = summarize_text(parsed_output[\"query\"])\n",
@ -693,7 +635,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "c39a1d80",
"metadata": {},
"outputs": [],
@ -709,7 +651,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "253fd0f7",
"metadata": {},
"outputs": [
@ -718,43 +660,30 @@
"output_type": "stream",
"text": [
"Function generation requested, calling function\n",
"Finding and reading paper\n",
"Chunking text from paper\n",
"Summarizing each chunk of text\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [00:06<00:00, 2.65it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Summarizing into overall summary\n"
"Getting search results\n",
"Got search results, summarizing content\n"
]
},
{
"data": {
"text/markdown": [
"Core Argument:\n",
"- The paper focuses on the theoretical analysis of the PPO-Clip algorithm in the context of deep reinforcement learning.\n",
"- The authors propose two core ideas: reinterpreting PPO-Clip from the perspective of hinge loss and introducing a two-step policy improvement scheme.\n",
"- The paper establishes the global convergence of PPO-Clip and characterizes its convergence rate.\n",
"PPO (Proximal Policy Optimization) is a reinforcement learning algorithm used in training agents to make sequential decisions in dynamic environments. It belongs to the family of policy optimization algorithms and addresses the challenge of optimizing policies in a stable and sample-efficient manner. \n",
"\n",
"PPO works by iteratively collecting a batch of data from interacting with the environment, computing advantages to estimate the quality of actions, and then performing multiple policy updates using a clipped surrogate objective. This objective function helps prevent excessive policy updates that could lead to policy divergence and instability. \n",
"\n",
"By iteratively updating the policy using the collected data, PPO seeks to maximize the expected cumulative rewards obtained by the agent. It has been used successfully in a variety of reinforcement learning tasks, including robotic control, game playing, and simulated environments. \n",
"\n",
"To learn more about PPO reinforcement learning, you can read the following papers:\n",
"\n",
"Evidence:\n",
"- The paper addresses the challenges posed by the clipping mechanism and neural function approximation.\n",
"- The authors provide theoretical proofs, lemmas, and mathematical analysis to support their arguments.\n",
"- The paper presents empirical experiments on various reinforcement learning benchmark tasks to validate the effectiveness of PPO-Clip.\n",
"1. Title: \"Proximal Policy Optimization Algorithms\"\n",
" Article URL: [arxiv.org/abs/1707.06347v2](http://arxiv.org/abs/1707.06347v2)\n",
" Summary: This paper introduces PPO and presents two versions of the algorithm: PPO-Penalty and PPO-Clip. It provides a detailed description of PPO's update rule and compares its performance against other popular reinforcement learning algorithms.\n",
"\n",
"Conclusions:\n",
"- The paper offers theoretical insights into the performance of PPO-Clip and provides a framework for analyzing its convergence properties.\n",
"- PPO-Clip is shown to have a global convergence rate of O(1/sqrt(T)), where T is the number of iterations.\n",
"- The hinge loss reinterpretation of PPO-Clip allows for variants with comparable empirical performance.\n",
"- The paper contributes to a better understanding of PPO-Clip in the reinforcement learning community."
"2. Title: \"Emergent Properties of PPO Reinforcement Learning in Resource-Limited Environments\"\n",
" Article URL: [arxiv.org/abs/2001.14342v1](http://arxiv.org/abs/2001.14342v1)\n",
" Summary: This paper explores the emergent properties of PPO reinforcement learning algorithms in resource-limited environments. It discusses the impact of varying the resource constraints and agent population sizes on the learning process and performance.\n",
"\n",
"Reading these papers will give you a deeper understanding of PPO reinforcement learning and its applications in different domains."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
@ -770,14 +699,14 @@
"chat_response = chat_completion_with_function_execution(\n",
" paper_conversation.conversation_history, functions=arxiv_functions\n",
")\n",
"assistant_message = chat_response[\"choices\"][0][\"message\"][\"content\"]\n",
"assistant_message = chat_response.choices[0].message.content\n",
"paper_conversation.add_message(\"assistant\", assistant_message)\n",
"display(Markdown(assistant_message))\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "3ca3e18a",
"metadata": {},
"outputs": [
@ -795,7 +724,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.08s/it]\n"
"100%|██████████| 15/15 [00:09<00:00, 1.67it/s]\n"
]
},
{
@ -808,21 +737,7 @@
{
"data": {
"text/markdown": [
"Core Argument:\n",
"- The paper discusses the use of proximal policy optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.\n",
"- The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, which is commonly used in text generation tasks.\n",
"- They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.\n",
"\n",
"Evidence:\n",
"- The authors derive the constraints for PPO-dynamic and provide the pseudo code for both PPO and PPO-dynamic.\n",
"- They compare the performance of PPO-dynamic with other algorithms, including REINFORCE, MIXER, and SeqGAN, on a synthetic counting task and a chit-chat chatbot task using the OpenSubtitles dataset.\n",
"- In the synthetic counting task, PPO-dynamic achieves a high precision score comparable to REINFORCE and MIXER, with a faster learning curve compared to PPO.\n",
"- In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than REINFORCE and PPO, with a more stable and faster learning curve than policy gradient.\n",
"\n",
"Conclusions:\n",
"- The results suggest that PPO is a better optimization method for sequence learning compared to policy gradient.\n",
"- PPO-dynamic further improves the optimization process by dynamically adjusting the hyperparameters.\n",
"- The authors conclude that PPO can be used as a new optimization method for GAN-based sequence learning for better performance."
"The paper discusses the unique decomposition of generators of completely positive dynamical semigroups in infinite dimensions. The main result is that for any separable complex Hilbert space, any trace-class operator B that does not have a purely imaginary trace, and any generator L of a norm-continuous one-parameter semigroup of completely positive maps, there exists a unique bounded operator K and a unique completely positive map Φ such that L=K(·) + (·)K+ Φ. The paper also introduces a modified version of the Choi formalism and characterizes when this correspondence is injective and surjective. The paper concludes by discussing the challenges and questions that arise when generalizing the results to non-separable Hilbert spaces."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
@ -841,15 +756,15 @@
"updated_response = chat_completion_with_function_execution(\n",
" paper_conversation.conversation_history, functions=arxiv_functions\n",
")\n",
"display(Markdown(updated_response[\"choices\"][0][\"message\"][\"content\"]))\n"
"display(Markdown(updated_response.choices[0].message.content))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tua_test",
"display_name": "Python 3",
"language": "python",
"name": "tua_test"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -861,7 +776,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.12.1"
}
},
"nbformat": 4,

@ -33,37 +33,68 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "80e71f33",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: scipy in /usr/local/lib/python3.11/site-packages (1.12.0)\n",
"Requirement already satisfied: numpy<1.29.0,>=1.22.4 in /usr/local/lib/python3.11/site-packages (from scipy) (1.26.3)\n",
"Requirement already satisfied: tenacity in /usr/local/lib/python3.11/site-packages (8.2.3)\n",
"Requirement already satisfied: tiktoken in /usr/local/lib/python3.11/site-packages (0.3.3)\n",
"Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.11/site-packages (from tiktoken) (2023.12.25)\n",
"Requirement already satisfied: requests>=2.26.0 in /usr/local/lib/python3.11/site-packages (from tiktoken) (2.31.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2023.11.17)\n",
"Requirement already satisfied: termcolor in /usr/local/lib/python3.11/site-packages (2.4.0)\n",
"Requirement already satisfied: openai in /usr/local/lib/python3.11/site-packages (1.10.0)\n",
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.11/site-packages (from openai) (4.2.0)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.11/site-packages (from openai) (1.9.0)\n",
"Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.11/site-packages (from openai) (0.26.0)\n",
"Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.11/site-packages (from openai) (2.5.3)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.11/site-packages (from openai) (1.3.0)\n",
"Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.11/site-packages (from openai) (4.66.1)\n",
"Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.11/site-packages (from openai) (4.9.0)\n",
"Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.11/site-packages (from anyio<5,>=3.5.0->openai) (3.6)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.11/site-packages (from httpx<1,>=0.23.0->openai) (2023.11.17)\n",
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/site-packages (from httpx<1,>=0.23.0->openai) (1.0.2)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/site-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai) (0.14.0)\n",
"Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.11/site-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)\n",
"Requirement already satisfied: pydantic-core==2.14.6 in /usr/local/lib/python3.11/site-packages (from pydantic<3,>=1.9.0->openai) (2.14.6)\n"
]
}
],
"source": [
"!pip install scipy\n",
"!pip install tenacity\n",
"!pip install tiktoken\n",
"!pip install termcolor \n",
"!pip install openai\n",
"!pip install requests"
"!pip install openai"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "dab872c5",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import openai\n",
"import requests\n",
"from openai import OpenAI\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from termcolor import colored\n",
"from termcolor import colored \n",
"\n",
"GPT_MODEL = \"gpt-3.5-turbo-0613\""
"GPT_MODEL = \"gpt-3.5-turbo-0613\"\n",
"client = OpenAI()"
]
},
{
@ -79,27 +110,19 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "745ceec5",
"metadata": {},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))\n",
"def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):\n",
" headers = {\n",
" \"Content-Type\": \"application/json\",\n",
" \"Authorization\": \"Bearer \" + openai.api_key,\n",
" }\n",
" json_data = {\"model\": model, \"messages\": messages}\n",
" if tools is not None:\n",
" json_data.update({\"tools\": tools})\n",
" if tool_choice is not None:\n",
" json_data.update({\"tool_choice\": tool_choice})\n",
" try:\n",
" response = requests.post(\n",
" \"https://api.openai.com/v1/chat/completions\",\n",
" headers=headers,\n",
" json=json_data,\n",
" response = client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" tools=tools,\n",
" tool_choice=tool_choice,\n",
" )\n",
" return response\n",
" except Exception as e:\n",
@ -110,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "c4d1c99f",
"metadata": {},
"outputs": [],
@ -120,7 +143,7 @@
" \"system\": \"red\",\n",
" \"user\": \"green\",\n",
" \"assistant\": \"blue\",\n",
" \"tool\": \"magenta\",\n",
" \"function\": \"magenta\",\n",
" }\n",
" \n",
" for message in messages:\n",
@ -132,7 +155,7 @@
" print(colored(f\"assistant: {message['function_call']}\\n\", role_to_color[message[\"role\"]]))\n",
" elif message[\"role\"] == \"assistant\" and not message.get(\"function_call\"):\n",
" print(colored(f\"assistant: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
" elif message[\"role\"] == \"tool\":\n",
" elif message[\"role\"] == \"function\":\n",
" print(colored(f\"function ({message['name']}): {message['content']}\\n\", role_to_color[message[\"role\"]]))\n"
]
},
@ -149,7 +172,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "d2e25069",
"metadata": {},
"outputs": [],
@ -217,18 +240,17 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "518d6827",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': 'Sure, I can help you with that. Could you please tell me the city and state you are in or the location you want to know the weather for?'}"
"ChatCompletionMessage(content='Sure, I can help you with that. Could you please provide me with your location?', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -240,7 +262,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
@ -256,22 +278,17 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "23c42a6e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': None,\n",
" 'tool_calls': [{'id': 'call_o7uyztQLeVIoRdjcDkDJY3ni',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_current_weather',\n",
" 'arguments': '{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\"\\n}'}}]}"
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_qOYhFO7fKaU6wpG2f1XzkDjW', function=Function(arguments='{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\"\\n}', name='get_current_weather'), type='function')])"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -281,7 +298,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
@ -297,18 +314,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "fa232e54",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': 'Sure, I can help you with that. Please let me know the value for x.'}"
"ChatCompletionMessage(content='Sure! Please provide the number of days you would like to know the weather forecast for.', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@ -320,7 +336,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message = chat_response.choices[0].message\n",
"messages.append(assistant_message)\n",
"assistant_message\n"
]
@ -336,24 +352,17 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "c7d8a543",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'index': 0,\n",
" 'message': {'role': 'assistant',\n",
" 'content': None,\n",
" 'tool_calls': [{'id': 'call_drz2YpGPWEMVySzYgsWYY249',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_n_day_weather_forecast',\n",
" 'arguments': '{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\",\\n \"num_days\": 5\\n}'}}]},\n",
" 'finish_reason': 'tool_calls'}"
"Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_HwWHsNZsmkZUroPj6glmEgA5', function=Function(arguments='{\\n \"location\": \"Glasgow, Scotland\",\\n \"format\": \"celsius\",\\n \"num_days\": 5\\n}', name='get_n_day_weather_forecast'), type='function')]), internal_metrics=[{'cached_prompt_tokens': 0, 'total_accepted_tokens': 0, 'total_batched_tokens': 269, 'total_predicted_tokens': 0, 'total_rejected_tokens': 0, 'total_tokens_in_completion': 270, 'cached_embeddings_bytes': 0, 'cached_embeddings_n': 0, 'uncached_embeddings_bytes': 0, 'uncached_embeddings_n': 0, 'fetched_embeddings_bytes': 0, 'fetched_embeddings_n': 0, 'n_evictions': 0, 'sampling_steps': 40, 'sampling_steps_with_predictions': 0, 'batcher_ttft': 0.055008649826049805, 'batcher_initial_queue_time': 0.00098419189453125}])"
]
},
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -363,7 +372,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"chat_response.json()[\"choices\"][0]\n"
"chat_response.choices[0]\n"
]
},
{
@ -386,22 +395,17 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "559371b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': None,\n",
" 'tool_calls': [{'id': 'call_jdmoJQ4lqsu4mBWcVBYtt5cU',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_n_day_weather_forecast',\n",
" 'arguments': '{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\",\\n \"num_days\": 1\\n}'}}]}"
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_240XQedt4Gi8VZsUwOvFpQfZ', function=Function(arguments='{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\",\\n \"num_days\": 1\\n}', name='get_n_day_weather_forecast'), type='function')])"
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@ -414,27 +418,22 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools, tool_choice={\"type\": \"function\", \"function\": {\"name\": \"get_n_day_weather_forecast\"}}\n",
")\n",
"chat_response.json()[\"choices\"][0][\"message\"]\n"
"chat_response.choices[0].message"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "a7ab0f58",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': None,\n",
" 'tool_calls': [{'id': 'call_RYXaDjxpUCfWmpXU7BZEYVqS',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_current_weather',\n",
" 'arguments': '{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}'}}]}"
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_lQhrFlzIVPpeYG1QrSv7e3H3', function=Function(arguments='{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}', name='get_current_weather'), type='function')])"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@ -447,7 +446,7 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools\n",
")\n",
"chat_response.json()[\"choices\"][0][\"message\"]\n"
"chat_response.choices[0].message"
]
},
{
@ -461,18 +460,17 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "acfe54e6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'content': '{ \"location\": \"Toronto, Canada\", \"format\": \"celsius\" }'}"
"ChatCompletionMessage(content='{\\n \"location\": \"Toronto, Canada\",\\n \"format\": \"celsius\"\\n}', role='assistant', function_call=None, tool_calls=None)"
]
},
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@ -484,11 +482,12 @@
"chat_response = chat_completion_request(\n",
" messages, tools=tools, tool_choice=\"none\"\n",
")\n",
"chat_response.json()[\"choices\"][0][\"message\"]\n"
"chat_response.choices[0].message\n"
]
},
{
"cell_type": "markdown",
"id": "b616353b",
"metadata": {},
"source": [
"### Parallel Function Calling\n",
@ -498,23 +497,18 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "380eeb68",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'id': 'call_fLsKR5vGllhbWxvpqsDT3jBj',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_n_day_weather_forecast',\n",
" 'arguments': '{\"location\": \"San Francisco, CA\", \"format\": \"celsius\", \"num_days\": 4}'}},\n",
" {'id': 'call_CchlsGE8OE03QmeyFbg7pkDz',\n",
" 'type': 'function',\n",
" 'function': {'name': 'get_n_day_weather_forecast',\n",
" 'arguments': '{\"location\": \"Glasgow\", \"format\": \"celsius\", \"num_days\": 4}'}}]"
"[ChatCompletionMessageToolCall(id='call_q8k4geh0uGPRtIfOXYPB0yM8', function=Function(arguments='{\"location\": \"San Francisco, CA\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function'),\n",
" ChatCompletionMessageToolCall(id='call_Hdl7Py7aLswCBPptrD4y5BD3', function=Function(arguments='{\"location\": \"Glasgow\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function')]"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@ -527,8 +521,7 @@
" messages, tools=tools, model='gpt-3.5-turbo-1106'\n",
")\n",
"\n",
"chat_response.json()\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]['tool_calls']\n",
"assistant_message = chat_response.choices[0].message.tool_calls\n",
"assistant_message"
]
},
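{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each returned tool call can then be executed separately, with its result appended to the conversation as a `tool` message. A minimal sketch (here `get_n_day_weather_forecast` stands in for a real implementation of the requested function):\n",
"\n",
"```python\n",
"for tool_call in assistant_message:\n",
"    arguments = json.loads(tool_call.function.arguments)\n",
"    forecast = get_n_day_weather_forecast(**arguments)  # hypothetical: call your real function here\n",
"    messages.append(\n",
"        {\"role\": \"tool\", \"tool_call_id\": tool_call.id, \"name\": tool_call.function.name, \"content\": str(forecast)}\n",
"    )\n",
"```"
]
},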
@ -558,7 +551,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "30f6b60e",
"metadata": {},
"outputs": [
@ -579,7 +572,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "abec0214",
"metadata": {},
"outputs": [],
@ -622,7 +615,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "0c0104cd",
"metadata": {},
"outputs": [],
@ -647,7 +640,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "0258813a",
"metadata": {},
"outputs": [],
@ -691,7 +684,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"id": "65585e74",
"metadata": {},
"outputs": [],
@ -705,17 +698,17 @@
" return results\n",
"\n",
"def execute_function_call(message):\n",
" if message[\"tool_calls\"][0][\"function\"][\"name\"] == \"ask_database\":\n",
" query = json.loads(message[\"tool_calls\"][0][\"function\"][\"arguments\"])[\"query\"]\n",
" if message.tool_calls[0].function.name == \"ask_database\":\n",
" query = json.loads(message.tool_calls[0].function.arguments)[\"query\"]\n",
" results = ask_database(conn, query)\n",
" else:\n",
" results = f\"Error: function {message['tool_calls'][0]['function']['name']} does not exist\"\n",
" results = f\"Error: function {message.tool_calls[0].function.name} does not exist\"\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"id": "38c55083",
"metadata": {},
"outputs": [
@ -727,7 +720,7 @@
"\u001b[0m\n",
"\u001b[32muser: Hi, who are the top 5 artists by number of tracks?\n",
"\u001b[0m\n",
"\u001b[34massistant: {'name': 'ask_database', 'arguments': '{\\n \"query\": \"SELECT Artist.Name, COUNT(Track.TrackId) AS TrackCount FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.Name ORDER BY TrackCount DESC LIMIT 5\"\\n}'}\n",
"\u001b[34massistant: Function(arguments='{\\n \"query\": \"SELECT artist.Name, COUNT(track.TrackId) AS num_tracks FROM artist JOIN album ON artist.ArtistId = album.ArtistId JOIN track ON album.AlbumId = track.AlbumId GROUP BY artist.ArtistId ORDER BY num_tracks DESC LIMIT 5\"\\n}', name='ask_database')\n",
"\u001b[0m\n",
"\u001b[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n",
"\u001b[0m\n"
@ -739,18 +732,18 @@
"messages.append({\"role\": \"system\", \"content\": \"Answer user questions by generating SQL queries against the Chinook Music Database.\"})\n",
"messages.append({\"role\": \"user\", \"content\": \"Hi, who are the top 5 artists by number of tracks?\"})\n",
"chat_response = chat_completion_request(messages, tools)\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message['content'] = str(assistant_message[\"tool_calls\"][0][\"function\"])\n",
"messages.append(assistant_message)\n",
"if assistant_message.get(\"tool_calls\"):\n",
"assistant_message = chat_response.choices[0].message\n",
"assistant_message.content = str(assistant_message.tool_calls[0].function)\n",
"messages.append({\"role\": assistant_message.role, \"content\": assistant_message.content})\n",
"if assistant_message.tool_calls:\n",
" results = execute_function_call(assistant_message)\n",
" messages.append({\"role\": \"tool\", \"tool_call_id\": assistant_message[\"tool_calls\"][0]['id'], \"name\": assistant_message[\"tool_calls\"][0][\"function\"][\"name\"], \"content\": results})\n",
" messages.append({\"role\": \"function\", \"tool_call_id\": assistant_message.tool_calls[0].id, \"name\": assistant_message.tool_calls[0].function.name, \"content\": results})\n",
"pretty_print_conversation(messages)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"id": "710481dc",
"metadata": {
"scrolled": true
@ -764,13 +757,13 @@
"\u001b[0m\n",
"\u001b[32muser: Hi, who are the top 5 artists by number of tracks?\n",
"\u001b[0m\n",
"\u001b[34massistant: {'name': 'ask_database', 'arguments': '{\\n \"query\": \"SELECT Artist.Name, COUNT(Track.TrackId) AS TrackCount FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Artist.Name ORDER BY TrackCount DESC LIMIT 5\"\\n}'}\n",
"\u001b[34massistant: Function(arguments='{\\n \"query\": \"SELECT artist.Name, COUNT(track.TrackId) AS num_tracks FROM artist JOIN album ON artist.ArtistId = album.ArtistId JOIN track ON album.AlbumId = track.AlbumId GROUP BY artist.ArtistId ORDER BY num_tracks DESC LIMIT 5\"\\n}', name='ask_database')\n",
"\u001b[0m\n",
"\u001b[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Lost', 92)]\n",
"\u001b[0m\n",
"\u001b[32muser: What is the name of the album with the most tracks?\n",
"\u001b[0m\n",
"\u001b[34massistant: {'name': 'ask_database', 'arguments': '{\\n \"query\": \"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1\"\\n}'}\n",
"\u001b[34massistant: Function(arguments='{\\n \"query\": \"SELECT album.Title, COUNT(track.TrackId) AS num_tracks FROM album JOIN track ON album.AlbumId = track.AlbumId GROUP BY album.AlbumId ORDER BY num_tracks DESC LIMIT 1\"\\n}', name='ask_database')\n",
"\u001b[0m\n",
"\u001b[35mfunction (ask_database): [('Greatest Hits', 57)]\n",
"\u001b[0m\n"
@ -780,12 +773,12 @@
"source": [
"messages.append({\"role\": \"user\", \"content\": \"What is the name of the album with the most tracks?\"})\n",
"chat_response = chat_completion_request(messages, tools)\n",
"assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"assistant_message['content'] = str(assistant_message[\"tool_calls\"][0][\"function\"])\n",
"messages.append(assistant_message)\n",
"if assistant_message.get(\"tool_calls\"):\n",
"assistant_message = chat_response.choices[0].message\n",
"assistant_message.content = str(assistant_message.tool_calls[0].function)\n",
"messages.append({\"role\": assistant_message.role, \"content\": assistant_message.content})\n",
"if assistant_message.tool_calls:\n",
" results = execute_function_call(assistant_message)\n",
" messages.append({\"role\": \"tool\", \"tool_call_id\": assistant_message[\"tool_calls\"][0]['id'], \"name\": assistant_message[\"tool_calls\"][0][\"function\"][\"name\"], \"content\": results})\n",
" messages.append({\"role\": \"function\", \"tool_call_id\": assistant_message.tool_calls[0].id, \"name\": assistant_message.tool_calls[0].function.name, \"content\": results})\n",
"pretty_print_conversation(messages)"
]
},
@ -799,6 +792,12 @@
"\n",
"See our other [notebook](How_to_call_functions_for_knowledge_retrieval.ipynb) that demonstrates how to use the Chat Completions API and functions for knowledge retrieval to interact conversationally with a knowledge base."
]
}
],
"metadata": {
@ -817,7 +816,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.12.1"
}
},
"nbformat": 4,

@ -22,7 +22,7 @@
"\n",
"| Encoding name | OpenAI models |\n",
"|-------------------------|-----------------------------------------------------|\n",
"| `cl100k_base` | `gpt-4`, `gpt-3.5-turbo`, `text-embedding-ada-002` |\n",
"| `cl100k_base` | `gpt-4`, `gpt-3.5-turbo`, `text-embedding-ada-002`, `text-embedding-3-small`, `text-embedding-3-large` |\n",
"| `p50k_base` | Codex models, `text-davinci-002`, `text-davinci-003`|\n",
"| `r50k_base` (or `gpt2`) | GPT-3 models like `davinci` |\n",
"\n",
@ -89,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -433,7 +433,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -480,7 +480,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -519,7 +519,10 @@
"source": [
"# let's verify the function above matches the OpenAI API response\n",
"\n",
"import openai\n",
"from openai import OpenAI\n",
"import os\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"example_messages = [\n",
" {\n",
@ -564,15 +567,20 @@
" # example token count from the function defined above\n",
" print(f\"{num_tokens_from_messages(example_messages, model)} prompt tokens counted by num_tokens_from_messages().\")\n",
" # example token count from the OpenAI API\n",
" response = openai.ChatCompletion.create(\n",
" model=model,\n",
" messages=example_messages,\n",
" temperature=0,\n",
" max_tokens=1, # we're only counting input tokens here, so let's not waste tokens on the output\n",
" )\n",
" print(f'{response[\"usage\"][\"prompt_tokens\"]} prompt tokens counted by the OpenAI API.')\n",
" response = client.chat.completions.create(model=model,\n",
" messages=example_messages,\n",
" temperature=0,\n",
" max_tokens=1)\n",
" print(f'{response.usage.prompt_tokens} prompt tokens counted by the OpenAI API.')\n",
" print()\n"
]
}
],
"metadata": {
@ -591,7 +599,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.11.5"
},
"vscode": {
"interpreter": {

@ -53,8 +53,7 @@
"import pandas as pd\n",
"from pprint import pprint\n",
"\n",
"OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\", \"\")\n",
"openai.api_key = OPENAI_API_KEY"
"client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -421,14 +420,14 @@
],
"source": [
"with open(training_file_name, \"rb\") as training_fd:\n",
" training_response = openai.files.create(\n",
" training_response = client.files.create(\n",
" file=training_fd, purpose=\"fine-tune\"\n",
" )\n",
"\n",
"training_file_id = training_response.id\n",
"\n",
"with open(validation_file_name, \"rb\") as validation_fd:\n",
" validation_response = openai.files.create(\n",
" validation_response = client.files.create(\n",
" file=validation_fd, purpose=\"fine-tune\"\n",
" )\n",
"validation_file_id = validation_response.id\n",
@ -465,7 +464,7 @@
}
],
"source": [
"response = openai.fine_tuning.jobs.create(\n",
"response = client.fine_tuning.jobs.create(\n",
" training_file=training_file_id,\n",
" validation_file=validation_file_id,\n",
" model=\"gpt-3.5-turbo\",\n",
@ -507,7 +506,7 @@
}
],
"source": [
"response = openai.fine_tuning.jobs.retrieve(job_id)\n",
"response = client.fine_tuning.jobs.retrieve(job_id)\n",
"\n",
"print(\"Job ID:\", response.id)\n",
"print(\"Status:\", response.status)\n",
@ -556,7 +555,7 @@
}
],
"source": [
"response = openai.fine_tuning.jobs.list_events(job_id)\n",
"response = client.fine_tuning.jobs.list_events(job_id)\n",
"\n",
"events = response.data\n",
"events.reverse()\n",
@ -588,7 +587,7 @@
}
],
"source": [
"response = openai.fine_tuning.jobs.retrieve(job_id)\n",
"response = client.fine_tuning.jobs.retrieve(job_id)\n",
"fine_tuned_model_id = response.fine_tuned_model\n",
"\n",
"if fine_tuned_model_id is None: \n",
@ -664,7 +663,7 @@
}
],
"source": [
"response = openai.chat.completions.create(\n",
"response = client.chat.completions.create(\n",
" model=fine_tuned_model_id, messages=test_messages, temperature=0, max_tokens=500\n",
")\n",
"print(response.choices[0].message.content)"

@ -36,12 +36,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# import the OpenAI Python library for calling the OpenAI API\n",
"import openai\n"
"from openai import OpenAI\n",
"import os\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -49,17 +52,37 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. An example chat API call\n",
"## 2. An example chat completion API call\n",
"\n",
"A chat API call has two required inputs:\n",
"- `model`: the name of the model you want to use (e.g., `gpt-3.5-turbo`, `gpt-4`, `gpt-3.5-turbo-0613`, `gpt-3.5-turbo-16k-0613`)\n",
"A chat completion API call parameters,\n",
"**Required**\n",
"- `model`: the name of the model you want to use (e.g., `gpt-3.5-turbo`, `gpt-4`, `gpt-3.5-turbo-16k-1106`)\n",
"- `messages`: a list of message objects, where each object has two required fields:\n",
" - `role`: the role of the messenger (either `system`, `user`, or `assistant`)\n",
" - `role`: the role of the messenger (either `system`, `user`, `assistant` or `tool`)\n",
" - `content`: the content of the message (e.g., `Write me a beautiful poem`)\n",
"\n",
"Messages can also contain an optional `name` field, which give the messenger a name. E.g., `example-user`, `Alice`, `BlackbeardBot`. Names may not contain spaces.\n",
"\n",
"As of June 2023, you can also optionally submit a list of `functions` that tell GPT whether it can generate JSON to feed into a function. For details, see the [documentation](https://platform.openai.com/docs/guides/gpt/function-calling), [API reference](https://platform.openai.com/docs/api-reference/chat), or the Cookbook guide [How to call functions with chat models](How_to_call_functions_with_chat_models.ipynb).\n",
"**Optional**\n",
"- `frequency_penalty`: Penalizes tokens based on their frequency, reducing repetition.\n",
"- `logit_bias`: Modifies likelihood of specified tokens with bias values.\n",
"- `logprobs`: Returns log probabilities of output tokens if true.\n",
"- `top_logprobs`: Specifies the number of most likely tokens to return at each position.\n",
"- `max_tokens`: Sets the maximum number of generated tokens in chat completion.\n",
"- `n`: Generates a specified number of chat completion choices for each input.\n",
"- `presence_penalty`: Penalizes new tokens based on their presence in the text.\n",
"- `response_format`: Specifies the output format, e.g., JSON mode.\n",
"- `seed`: Ensures deterministic sampling with a specified seed.\n",
"- `stop`: Specifies up to 4 sequences where the API should stop generating tokens.\n",
"- `stream`: Sends partial message deltas as tokens become available.\n",
"- `temperature`: Sets the sampling temperature between 0 and 2.\n",
"- `top_p`: Uses nucleus sampling; considers tokens with top_p probability mass.\n",
"- `tools`: Lists functions the model may call.\n",
"- `tool_choice`: Controls the model's function calls (none/auto/function).\n",
"- `user`: Unique identifier for end-user monitoring and abuse detection.\n",
"\n",
"\n",
"As of January 2024, you can also optionally submit a list of `functions` that tell GPT whether it can generate JSON to feed into a function. For details, see the [documentation](https://platform.openai.com/docs/guides/function-calling), [API reference](https://platform.openai.com/docs/api-reference/chat), or the Cookbook guide [How to call functions with chat models](How_to_call_functions_with_chat_models.ipynb).\n",
"\n",
"Typically, a conversation will start with a system message that tells the assistant how to behave, followed by alternating user and assistant messages, but you are not required to follow this format.\n",
"\n",
@ -68,44 +91,13 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<OpenAIObject chat.completion id=chatcmpl-7UkgnSDzlevZxiy0YjZcLYdUMz5yZ at 0x118e394f0> JSON: {\n",
" \"id\": \"chatcmpl-7UkgnSDzlevZxiy0YjZcLYdUMz5yZ\",\n",
" \"object\": \"chat.completion\",\n",
" \"created\": 1687563669,\n",
" \"model\": \"gpt-3.5-turbo-0301\",\n",
" \"choices\": [\n",
" {\n",
" \"index\": 0,\n",
" \"message\": {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"Orange who?\"\n",
" },\n",
" \"finish_reason\": \"stop\"\n",
" }\n",
" ],\n",
" \"usage\": {\n",
" \"prompt_tokens\": 39,\n",
" \"completion_tokens\": 3,\n",
" \"total_tokens\": 42\n",
" }\n",
"}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Example OpenAI Python library request\n",
"MODEL = \"gpt-3.5-turbo\"\n",
"response = openai.ChatCompletion.create(\n",
"response = client.chat.completions.create(\n",
" model=MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
@ -114,9 +106,48 @@
" {\"role\": \"user\", \"content\": \"Orange.\"},\n",
" ],\n",
" temperature=0,\n",
")\n",
"\n",
"response\n"
")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"id\": \"chatcmpl-8dee9DuEFcg2QILtT2a6EBXZnpirM\",\n",
" \"choices\": [\n",
" {\n",
" \"finish_reason\": \"stop\",\n",
" \"index\": 0,\n",
" \"logprobs\": null,\n",
" \"message\": {\n",
" \"content\": \"Orange who?\",\n",
" \"role\": \"assistant\",\n",
" \"function_call\": null,\n",
" \"tool_calls\": null\n",
" }\n",
" }\n",
" ],\n",
" \"created\": 1704461729,\n",
" \"model\": \"gpt-3.5-turbo-0613\",\n",
" \"object\": \"chat.completion\",\n",
" \"system_fingerprint\": null,\n",
" \"usage\": {\n",
" \"completion_tokens\": 3,\n",
" \"prompt_tokens\": 35,\n",
" \"total_tokens\": 38\n",
" }\n",
"}\n"
]
}
],
"source": [
"print(json.dumps(json.loads(response.model_dump_json()), indent=4))"
]
},
{
@ -126,14 +157,19 @@
"source": [
"As you can see, the response object has a few fields:\n",
"- `id`: the ID of the request\n",
"- `object`: the type of object returned (e.g., `chat.completion`)\n",
"- `created`: the timestamp of the request\n",
"- `model`: the full name of the model used to generate the response\n",
"- `usage`: the number of tokens used to generate the replies, counting prompt, completion, and total\n",
"- `choices`: a list of completion objects (only one, unless you set `n` greater than 1)\n",
" - `message`: the message object generated by the model, with `role` and `content`\n",
" - `finish_reason`: the reason the model stopped generating text (either `stop`, or `length` if `max_tokens` limit was reached)\n",
" - `index`: the index of the completion in the list of choices"
" - `index`: The index of the choice in the list of choices.\n",
" - `logprobs`: Log probability information for the choice.\n",
" - `message`: the message object generated by the model\n",
" - `content`: content of message\n",
" - `role`: The role of the author of this message.\n",
" - `tool_calls`: The tool calls generated by the model, such as function calls. if the tools is given\n",
"- `created`: the timestamp of the request\n",
"- `model`: the full name of the model used to generate the response\n",
"- `object`: the type of object returned (e.g., `chat.completion`)\n",
"- `system_fingerprint`: This fingerprint represents the backend configuration that the model runs with.\n",
"- `usage`: the number of tokens used to generate the replies, counting prompt, completion, and total"
]
},
{
@ -146,7 +182,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@ -155,13 +191,13 @@
"'Orange who?'"
]
},
"execution_count": 4,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response['choices'][0]['message']['content']\n"
"response.choices[0].message.content\n"
]
},
{
@ -176,20 +212,36 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ahoy matey! Asynchronous programming be like havin' a crew o' pirates workin' on different tasks at the same time. Ye see, instead o' waitin' for one task to be completed before startin' the next, ye can assign tasks to yer crew and let 'em work on 'em simultaneously. This way, ye can get more done in less time and keep yer ship sailin' smoothly. It be like havin' a bunch o' pirates rowin' the ship at different speeds, but still gettin' us to our destination. Arrr!\n"
"Arr, me matey! Let me tell ye a tale of asynchronous programming, in the style of the fearsome pirate Blackbeard!\n",
"\n",
"Picture this, me hearties. In the vast ocean of programming, there be times when ye need to perform multiple tasks at once. But fear not, for asynchronous programming be here to save the day!\n",
"\n",
"Ye see, in traditional programming, ye be waitin' for one task to be done before movin' on to the next. But with asynchronous programming, ye can be takin' care of multiple tasks at the same time, just like a pirate multitaskin' on the high seas!\n",
"\n",
"Instead of waitin' for a task to be completed, ye can be sendin' it off on its own journey, while ye move on to the next task. It be like havin' a crew of trusty sailors, each takin' care of their own duties, without waitin' for the others.\n",
"\n",
"Now, ye may be wonderin', how does this sorcery work? Well, me matey, it be all about callbacks and promises. When ye be sendin' off a task, ye be attachin' a callback function to it. This be like leavin' a message in a bottle, tellin' the task what to do when it be finished.\n",
"\n",
"While the task be sailin' on its own, ye can be movin' on to the next task, without wastin' any precious time. And when the first task be done, it be sendin' a signal back to ye, lettin' ye know it be finished. Then ye can be takin' care of the callback function, like openin' the bottle and readin' the message inside.\n",
"\n",
"But wait, there be more! With promises, ye can be makin' even fancier arrangements. Instead of callbacks, ye be makin' a promise that the task will be completed. It be like a contract between ye and the task, swearin' that it will be done.\n",
"\n",
"Ye can be attachin' multiple promises to a task, promisin' different outcomes. And when the task be finished, it be fulfillin' the promises, lettin' ye know it be done. Then ye can be handlin' the fulfillments, like collectin' the rewards of yer pirate adventures!\n",
"\n",
"So, me hearties, that be the tale of asynchronous programming, told in the style of the fearsome pirate Blackbeard! With callbacks and promises, ye can be takin' care of multiple tasks at once, just like a pirate conquerin' the seven seas!\n"
]
}
],
"source": [
"# example with a system message\n",
"response = openai.ChatCompletion.create(\n",
"response = client.chat.completions.create(\n",
" model=MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
@ -198,31 +250,43 @@
" temperature=0,\n",
")\n",
"\n",
"print(response['choices'][0]['message']['content'])\n"
"print(response.choices[0].message.content)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ahoy mateys! Let me tell ye about asynchronous programming, arrr! It be like havin' a crew of sailors workin' on different tasks at the same time, without waitin' for each other to finish. Ye see, in traditional programming, ye have to wait for one task to be completed before movin' on to the next. But with asynchronous programming, ye can start multiple tasks at once and let them run in the background while ye focus on other things.\n",
"Arr, me hearties! Gather 'round and listen up, for I be tellin' ye about the mysterious art of asynchronous programming, in the style of the fearsome pirate Blackbeard!\n",
"\n",
"It be like havin' a lookout keepin' watch for enemy ships while the rest of the crew be busy with their own tasks. They don't have to stop what they're doin' to keep an eye out, because the lookout be doin' it for them. And when the lookout spots an enemy ship, they can alert the crew and everyone can work together to defend the ship.\n",
"Now, ye see, in the world of programming, there be times when we need to perform tasks that take a mighty long time to complete. These tasks might involve fetchin' data from the depths of the internet, or performin' complex calculations that would make even Davy Jones scratch his head.\n",
"\n",
"In the same way, asynchronous programming allows different parts of yer code to work together without gettin' in each other's way. It be especially useful for tasks that take a long time to complete, like loadin' large files or connectin' to a server. Instead of makin' yer program wait for these tasks to finish, ye can let them run in the background while yer program continues to do other things.\n",
"In the olden days, we pirates used to wait patiently for each task to finish afore movin' on to the next one. But that be a waste of precious time, me hearties! We be pirates, always lookin' for ways to be more efficient and plunder more booty!\n",
"\n",
"So there ye have it, me hearties! Asynchronous programming be like havin' a crew of sailors workin' together without gettin' in each other's way. It be a powerful tool for any programmer, and one that can help ye sail the seas of code with ease!\n"
"That be where asynchronous programming comes in, me mateys. It be a way to tackle multiple tasks at once, without waitin' for each one to finish afore movin' on. It be like havin' a crew of scallywags workin' on different tasks simultaneously, while ye be overseein' the whole operation.\n",
"\n",
"Ye see, in asynchronous programming, we be breakin' down our tasks into smaller chunks called \"coroutines.\" Each coroutine be like a separate pirate, workin' on its own task. When a coroutine be startin' its work, it don't wait for the task to finish afore movin' on to the next one. Instead, it be movin' on to the next task, lettin' the first one continue in the background.\n",
"\n",
"Now, ye might be wonderin', \"But Blackbeard, how be we know when a task be finished if we don't wait for it?\" Ah, me hearties, that be where the magic of callbacks and promises come in!\n",
"\n",
"When a coroutine be startin' its work, it be attachin' a callback or a promise to it. This be like leavin' a message in a bottle, tellin' the coroutine what to do when it be finished. So, while the coroutine be workin' away, the rest of the crew be movin' on to other tasks, plunderin' more booty along the way.\n",
"\n",
"When a coroutine be finished with its task, it be sendin' a signal to the callback or fulfillin' the promise, lettin' the rest of the crew know that it be done. Then, the crew can gather 'round and handle the results of the completed task, celebratin' their victory and countin' their plunder.\n",
"\n",
"So, me hearties, asynchronous programming be like havin' a crew of pirates workin' on different tasks at once, without waitin' for each one to finish afore movin' on. It be a way to be more efficient, plunder more booty, and conquer the vast seas of programming!\n",
"\n",
"Now, set sail, me mateys, and embrace the power of asynchronous programming like true pirates of the digital realm! Arr!\n"
]
}
],
"source": [
"# example without a system message\n",
"response = openai.ChatCompletion.create(\n",
"response = client.chat.completions.create(\n",
" model=MODEL,\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"Explain asynchronous programming in the style of the pirate Blackbeard.\"},\n",
@ -230,7 +294,7 @@
" temperature=0,\n",
")\n",
"\n",
"print(response['choices'][0]['message']['content'])\n"
"print(response.choices[0].message.content)\n"
]
},
{
@ -257,28 +321,28 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sure! Fractions are a way of representing a part of a whole. The top number of a fraction is called the numerator, and it represents how many parts of the whole we are talking about. The bottom number is called the denominator, and it represents how many equal parts the whole is divided into.\n",
"Of course! Fractions are a way to represent parts of a whole. They are made up of two numbers: a numerator and a denominator. The numerator tells you how many parts you have, and the denominator tells you how many equal parts make up the whole.\n",
"\n",
"For example, if we have a pizza that is divided into 8 equal slices, and we take 3 slices, we can represent that as the fraction 3/8. The numerator is 3 because we took 3 slices, and the denominator is 8 because the pizza was divided into 8 slices.\n",
"Let's take an example to understand this better. Imagine you have a pizza that is divided into 8 equal slices. If you eat 3 slices, you can represent that as the fraction 3/8. Here, the numerator is 3 because you ate 3 slices, and the denominator is 8 because the whole pizza is divided into 8 slices.\n",
"\n",
"To add or subtract fractions, we need to have a common denominator. This means that the denominators of the fractions need to be the same. To do this, we can find the least common multiple (LCM) of the denominators and then convert each fraction to an equivalent fraction with the LCM as the denominator.\n",
"Fractions can also be used to represent numbers less than 1. For example, if you eat half of a pizza, you can write it as 1/2. Here, the numerator is 1 because you ate one slice, and the denominator is 2 because the whole pizza is divided into 2 equal parts.\n",
"\n",
"To multiply fractions, we simply multiply the numerators together and the denominators together. To divide fractions, we multiply the first fraction by the reciprocal of the second fraction (flip the second fraction upside down).\n",
"Now, let's talk about equivalent fractions. Equivalent fractions are different fractions that represent the same amount. For example, 1/2 and 2/4 are equivalent fractions because they both represent half of something. To find equivalent fractions, you can multiply or divide both the numerator and denominator by the same number.\n",
"\n",
"Now, here's a question to check for understanding: If we have a pizza that is divided into 12 equal slices, and we take 4 slices, what is the fraction that represents how much of the pizza we took?\n"
"Here's a question to check your understanding: If you have a cake divided into 12 equal slices and you eat 4 slices, what fraction of the cake did you eat?\n"
]
}
],
"source": [
"# An example of a system message that primes the assistant to explain concepts in great depth\n",
"response = openai.ChatCompletion.create(\n",
"response = client.chat.completions.create(\n",
" model=MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a friendly and helpful teaching assistant. You explain concepts in great depth using simple terms, and you give examples to help people learn. At the end of each explanation, you ask a question to check for understanding\"},\n",
@ -287,25 +351,25 @@
" temperature=0,\n",
")\n",
"\n",
"print(response[\"choices\"][0][\"message\"][\"content\"])\n"
"print(response.choices[0].message.content)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fractions represent a part of a whole. They consist of a numerator (top number) and a denominator (bottom number) separated by a line. The numerator represents how many parts of the whole are being considered, while the denominator represents the total number of equal parts that make up the whole.\n"
"Fractions represent parts of a whole. They have a numerator (top number) and a denominator (bottom number).\n"
]
}
],
"source": [
"# An example of a system message that primes the assistant to give brief, to-the-point answers\n",
"response = openai.ChatCompletion.create(\n",
"response = client.chat.completions.create(\n",
" model=MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a laconic assistant. You reply with brief, to-the-point answers with no elaboration.\"},\n",
@ -314,7 +378,7 @@
" temperature=0,\n",
")\n",
"\n",
"print(response[\"choices\"][0][\"message\"][\"content\"])\n"
"print(response.choices[0].message.content)\n"
]
},
{
@ -333,20 +397,20 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"We don't have enough time to complete the entire project perfectly.\n"
"This sudden change in direction means we don't have enough time to complete the entire project for the client.\n"
]
}
],
"source": [
"# An example of a faked few-shot conversation to prime the model into translating business jargon to simpler speech\n",
"response = openai.ChatCompletion.create(\n",
"response = client.chat.completions.create(\n",
" model=MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful, pattern-following assistant.\"},\n",
@ -361,7 +425,7 @@
" temperature=0,\n",
")\n",
"\n",
"print(response[\"choices\"][0][\"message\"][\"content\"])\n"
"print(response.choices[0].message.content)\n"
]
},
{
@ -376,20 +440,20 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"This sudden change in plans means we don't have enough time to do everything for the client's project.\n"
"This sudden change in direction means we don't have enough time to complete the entire project for the client.\n"
]
}
],
"source": [
"# The business jargon translation example, but with example names for the example messages\n",
"response = openai.ChatCompletion.create(\n",
"response = client.chat.completions.create(\n",
" model=MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful, pattern-following assistant that translates corporate jargon into plain English.\"},\n",
@ -402,7 +466,7 @@
" temperature=0,\n",
")\n",
"\n",
"print(response[\"choices\"][0][\"message\"][\"content\"])\n"
"print(response.choices[0].message.content)\n"
]
},
{
@ -444,7 +508,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@ -494,35 +558,29 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpt-3.5-turbo-0301\n",
"127 prompt tokens counted by num_tokens_from_messages().\n",
"127 prompt tokens counted by the OpenAI API.\n",
"\n",
"gpt-3.5-turbo-0613\n",
"gpt-3.5-turbo-1106\n",
"Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.\n",
"129 prompt tokens counted by num_tokens_from_messages().\n",
"129 prompt tokens counted by the OpenAI API.\n",
"\n",
"gpt-3.5-turbo\n",
"Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.\n",
"129 prompt tokens counted by num_tokens_from_messages().\n",
"127 prompt tokens counted by the OpenAI API.\n",
"\n",
"gpt-4-0314\n",
"129 prompt tokens counted by num_tokens_from_messages().\n",
"129 prompt tokens counted by the OpenAI API.\n",
"\n",
"gpt-4-0613\n",
"gpt-4\n",
"Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.\n",
"129 prompt tokens counted by num_tokens_from_messages().\n",
"129 prompt tokens counted by the OpenAI API.\n",
"\n",
"gpt-4\n",
"gpt-4-1106-preview\n",
"Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.\n",
"129 prompt tokens counted by num_tokens_from_messages().\n",
"129 prompt tokens counted by the OpenAI API.\n",
@ -532,9 +590,6 @@
],
"source": [
"# let's verify the function above matches the OpenAI API response\n",
"\n",
"import openai\n",
"\n",
"example_messages = [\n",
" {\n",
" \"role\": \"system\",\n",
@ -567,26 +622,33 @@
"]\n",
"\n",
"for model in [\n",
" \"gpt-3.5-turbo-0301\",\n",
" \"gpt-3.5-turbo-0613\",\n",
" # \"gpt-3.5-turbo-0301\",\n",
" # \"gpt-4-0314\",\n",
" # \"gpt-4-0613\",\n",
" \"gpt-3.5-turbo-1106\",\n",
" \"gpt-3.5-turbo\",\n",
" \"gpt-4-0314\",\n",
" \"gpt-4-0613\",\n",
" \"gpt-4\",\n",
" \"gpt-4-1106-preview\",\n",
" ]:\n",
" print(model)\n",
" # example token count from the function defined above\n",
" print(f\"{num_tokens_from_messages(example_messages, model)} prompt tokens counted by num_tokens_from_messages().\")\n",
" # example token count from the OpenAI API\n",
" response = openai.ChatCompletion.create(\n",
" model=model,\n",
" messages=example_messages,\n",
" temperature=0,\n",
" max_tokens=1, # we're only counting input tokens here, so let's not waste tokens on the output\n",
" )\n",
" print(f'{response[\"usage\"][\"prompt_tokens\"]} prompt tokens counted by the OpenAI API.')\n",
" response = client.chat.completions.create(model=model,\n",
" messages=example_messages,\n",
" temperature=0,\n",
" max_tokens=1)\n",
" token = response.usage.prompt_tokens\n",
" print(f'{token} prompt tokens counted by the OpenAI API.')\n",
" print()\n"
]
}
],
"metadata": {
@ -605,7 +667,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {

@ -48,6 +48,18 @@
"- [OpenAI Rate Limit Increase Request form](https://forms.gle/56ZrwXXoxAN1yt6i9)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import os\n",
"\n",
"client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -69,9 +81,6 @@
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI # for making OpenAI API requests\n",
"client = OpenAI()\n",
"\n",
"# request a bunch of completions in a loop\n",
"for _ in range(100):\n",
" client.chat.completions.create(\n",
@ -132,8 +141,6 @@
}
],
"source": [
"from openai import OpenAI # for OpenAI API calls\n",
"client = OpenAI()\n",
"from tenacity import (\n",
" retry,\n",
" stop_after_attempt,\n",
@ -177,10 +184,6 @@
],
"source": [
"import backoff # for exponential backoff\n",
"from openai import OpenAI # for OpenAI API calls\n",
"import openai\n",
"client = OpenAI() \n",
"\n",
"\n",
"@backoff.on_exception(backoff.expo, openai.RateLimitError)\n",
"def completions_with_backoff(**kwargs):\n",
@ -220,10 +223,6 @@
"import random\n",
"import time\n",
"\n",
"import openai\n",
"from openai import OpenAI\n",
"client = OpenAI()\n",
"\n",
"# define a retry decorator\n",
"def retry_with_exponential_backoff(\n",
" func,\n",
@ -316,8 +315,6 @@
"source": [
"# imports\n",
"import time\n",
"from openai import OpenAI\n",
"client = OpenAI()\n",
"\n",
"# Define a function that adds a delay to a Completion API call\n",
"def delayed_completion(delay_in_seconds: float = 1, **kwargs):\n",
@ -383,10 +380,6 @@
}
],
"source": [
"from openai import OpenAI # for making OpenAI API requests\n",
"client = OpenAI() \n",
"\n",
"\n",
"num_stories = 10\n",
"content = \"Once upon a time,\"\n",
"\n",
@ -432,15 +425,11 @@
}
],
"source": [
"from openai import OpenAI # for making OpenAI API requests\n",
"client = OpenAI() \n",
"\n",
"\n",
"num_stories = 10\n",
"prompts = [\"Once upon a time,\"] * num_stories\n",
"\n",
"# batched example, with 10 stories completions per request\n",
"response = client.completions.create(\n",
"response = client.chat.completions.create(\n",
" model=\"curie\",\n",
" prompt=prompts,\n",
" max_tokens=20,\n",

@ -26,8 +26,7 @@
"Below, this notebook shows:\n",
"1. What a typical chat completion response looks like\n",
"2. What a streaming chat completion response looks like\n",
"3. How much time is saved by streaming a chat completion\n",
"4. How to stream non-chat completions (used by older models like `text-davinci-003`)"
"3. How much time is saved by streaming a chat completion"
]
},
{
@ -36,19 +35,20 @@
"metadata": {},
"outputs": [],
"source": [
"# !pip install openai==1.3.7"
"# !pip install openai"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"import time # for measuring time duration of API calls\n",
"from openai import OpenAI\n",
"client = OpenAI() # for OpenAI API calls"
"import os\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -63,28 +63,28 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Full response received 7.90 seconds after request\n",
"Full response received 5.27 seconds after request\n",
"Full response received:\n",
"ChatCompletion(id='chatcmpl-8SjX2G5nLRWrfnnPeFn9HAaJvSVZf', choices=[Choice(finish_reason='stop', index=0, message=ChatCompletionMessage(content='1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100.', role='assistant', function_call=None, tool_calls=None))], created=1701858900, model='gpt-3.5-turbo-0613', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=299, prompt_tokens=36, total_tokens=335))\n"
"ChatCompletion(id='chatcmpl-8ZB8ywkV5DuuJO7xktqUcNYfG8j6I', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100.', role='assistant', function_call=None, tool_calls=None))], created=1703395008, model='gpt-3.5-turbo-0613', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=299, prompt_tokens=36, total_tokens=335))\n"
]
}
],
"source": [
"# Example of an OpenAI ChatCompletion request\n",
"# https://platform.openai.com/docs/guides/chat\n",
"# https://platform.openai.com/docs/guides/text-generation/chat-completions-api\n",
"\n",
"# record the time before the request is sent\n",
"start_time = time.time()\n",
"\n",
"# send a ChatCompletion request to count to 100\n",
"completion = client.chat.completions.create(\n",
"response = client.chat.completions.create(\n",
" model='gpt-3.5-turbo',\n",
" messages=[\n",
" {'role': 'user', 'content': 'Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ...'}\n",
@ -96,7 +96,7 @@
"\n",
"# print the time delay and text received\n",
"print(f\"Full response received {response_time:.2f} seconds after request\")\n",
"print(f\"Full response received:\\n{completion}\")\n"
"print(f\"Full response received:\\n{response}\")\n"
]
},
{
@ -104,14 +104,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The reply can be extracted with `response['choices'][0]['message']`.\n",
"The reply can be extracted with `response.choices[0].message`.\n",
"\n",
"The content of the reply can be extracted with `response['choices'][0]['message']['content']`."
"The content of the reply can be extracted with `response.choices[0].message.content`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -126,10 +126,10 @@
}
],
"source": [
"reply = completion.choices[0].message\n",
"reply = response.choices[0].message\n",
"print(f\"Extracted reply: \\n{reply}\")\n",
"\n",
"reply_content = completion.choices[0].message.content\n",
"reply_content = response.choices[0].message.content\n",
"print(f\"Extracted content: \\n{reply_content}\")\n"
]
},
@ -147,20 +147,20 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ChatCompletionChunk(id='chatcmpl-8SjXATR0aAejfQHMDHkNYYRTiUdca', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, role='assistant', tool_calls=None), finish_reason=None, index=0)], created=1701858908, model='gpt-3.5-turbo-0613', object='chat.completion.chunk', system_fingerprint=None)\n",
"ChatCompletionChunk(id='chatcmpl-8ZB9m2Ubv8FJs3CIb84WvYwqZCHST', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, role='assistant', tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1703395058, model='gpt-3.5-turbo-0613', object='chat.completion.chunk', system_fingerprint=None)\n",
"\n",
"****************\n",
"ChatCompletionChunk(id='chatcmpl-8SjXATR0aAejfQHMDHkNYYRTiUdca', choices=[Choice(delta=ChoiceDelta(content='2', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0)], created=1701858908, model='gpt-3.5-turbo-0613', object='chat.completion.chunk', system_fingerprint=None)\n",
"ChatCompletionChunk(id='chatcmpl-8ZB9m2Ubv8FJs3CIb84WvYwqZCHST', choices=[Choice(delta=ChoiceDelta(content='2', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1703395058, model='gpt-3.5-turbo-0613', object='chat.completion.chunk', system_fingerprint=None)\n",
"2\n",
"****************\n",
"ChatCompletionChunk(id='chatcmpl-8SjXATR0aAejfQHMDHkNYYRTiUdca', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=None), finish_reason='stop', index=0)], created=1701858908, model='gpt-3.5-turbo-0613', object='chat.completion.chunk', system_fingerprint=None)\n",
"ChatCompletionChunk(id='chatcmpl-8ZB9m2Ubv8FJs3CIb84WvYwqZCHST', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=None), finish_reason='stop', index=0, logprobs=None)], created=1703395058, model='gpt-3.5-turbo-0613', object='chat.completion.chunk', system_fingerprint=None)\n",
"None\n",
"****************\n"
]
@ -168,9 +168,10 @@
],
"source": [
"# Example of an OpenAI ChatCompletion request with stream=True\n",
"# https://platform.openai.com/docs/guides/chat\n",
"# https://platform.openai.com/docs/api-reference/streaming#chat/create-stream\n",
"\n",
"completion = client.chat.completions.create(\n",
"# a ChatCompletion request\n",
"response = client.chat.completions.create(\n",
" model='gpt-3.5-turbo',\n",
" messages=[\n",
" {'role': 'user', 'content': \"What's 1+1? Answer in one word.\"}\n",
@ -179,7 +180,7 @@
" stream=True # this time, we set stream=True\n",
")\n",
"\n",
"for chunk in completion:\n",
"for chunk in response:\n",
" print(chunk)\n",
" print(chunk.choices[0].delta.content)\n",
" print(\"****************\")"
@ -208,328 +209,328 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Message received 0.99 seconds after request: \n",
"Message received 1.06 seconds after request: 1\n",
"Message received 1.06 seconds after request: ,\n",
"Message received 1.06 seconds after request: \n",
"Message received 1.14 seconds after request: 2\n",
"Message received 1.14 seconds after request: ,\n",
"Message received 1.14 seconds after request: \n",
"Message received 1.20 seconds after request: 3\n",
"Message received 1.20 seconds after request: ,\n",
"Message received 1.20 seconds after request: \n",
"Message received 1.27 seconds after request: 4\n",
"Message received 0.31 seconds after request: \n",
"Message received 0.31 seconds after request: 1\n",
"Message received 0.34 seconds after request: ,\n",
"Message received 0.34 seconds after request: \n",
"Message received 0.34 seconds after request: 2\n",
"Message received 0.39 seconds after request: ,\n",
"Message received 0.39 seconds after request: \n",
"Message received 0.39 seconds after request: 3\n",
"Message received 0.42 seconds after request: ,\n",
"Message received 0.42 seconds after request: \n",
"Message received 0.42 seconds after request: 4\n",
"Message received 0.47 seconds after request: ,\n",
"Message received 0.47 seconds after request: \n",
"Message received 0.47 seconds after request: 5\n",
"Message received 0.51 seconds after request: ,\n",
"Message received 0.51 seconds after request: \n",
"Message received 0.51 seconds after request: 6\n",
"Message received 0.55 seconds after request: ,\n",
"Message received 0.55 seconds after request: \n",
"Message received 0.55 seconds after request: 7\n",
"Message received 0.59 seconds after request: ,\n",
"Message received 0.59 seconds after request: \n",
"Message received 0.59 seconds after request: 8\n",
"Message received 0.63 seconds after request: ,\n",
"Message received 0.63 seconds after request: \n",
"Message received 0.63 seconds after request: 9\n",
"Message received 0.67 seconds after request: ,\n",
"Message received 0.67 seconds after request: \n",
"Message received 0.67 seconds after request: 10\n",
"Message received 0.71 seconds after request: ,\n",
"Message received 0.71 seconds after request: \n",
"Message received 0.71 seconds after request: 11\n",
"Message received 0.75 seconds after request: ,\n",
"Message received 0.75 seconds after request: \n",
"Message received 0.75 seconds after request: 12\n",
"Message received 0.98 seconds after request: ,\n",
"Message received 0.98 seconds after request: \n",
"Message received 0.98 seconds after request: 13\n",
"Message received 1.02 seconds after request: ,\n",
"Message received 1.02 seconds after request: \n",
"Message received 1.02 seconds after request: 14\n",
"Message received 1.04 seconds after request: ,\n",
"Message received 1.04 seconds after request: \n",
"Message received 1.04 seconds after request: 15\n",
"Message received 1.08 seconds after request: ,\n",
"Message received 1.08 seconds after request: \n",
"Message received 1.08 seconds after request: 16\n",
"Message received 1.12 seconds after request: ,\n",
"Message received 1.12 seconds after request: \n",
"Message received 1.12 seconds after request: 17\n",
"Message received 1.16 seconds after request: ,\n",
"Message received 1.16 seconds after request: \n",
"Message received 1.16 seconds after request: 18\n",
"Message received 1.19 seconds after request: ,\n",
"Message received 1.19 seconds after request: \n",
"Message received 1.19 seconds after request: 19\n",
"Message received 1.23 seconds after request: ,\n",
"Message received 1.23 seconds after request: \n",
"Message received 1.23 seconds after request: 20\n",
"Message received 1.27 seconds after request: ,\n",
"Message received 1.27 seconds after request: \n",
"Message received 1.34 seconds after request: 5\n",
"Message received 1.34 seconds after request: ,\n",
"Message received 1.34 seconds after request: \n",
"Message received 1.40 seconds after request: 6\n",
"Message received 1.40 seconds after request: ,\n",
"Message received 1.40 seconds after request: \n",
"Message received 1.47 seconds after request: 7\n",
"Message received 1.48 seconds after request: ,\n",
"Message received 1.48 seconds after request: \n",
"Message received 1.56 seconds after request: 8\n",
"Message received 1.56 seconds after request: ,\n",
"Message received 1.56 seconds after request: \n",
"Message received 1.63 seconds after request: 9\n",
"Message received 1.63 seconds after request: ,\n",
"Message received 1.63 seconds after request: \n",
"Message received 1.71 seconds after request: 10\n",
"Message received 1.71 seconds after request: ,\n",
"Message received 1.71 seconds after request: \n",
"Message received 1.78 seconds after request: 11\n",
"Message received 1.78 seconds after request: ,\n",
"Message received 1.78 seconds after request: \n",
"Message received 1.86 seconds after request: 12\n",
"Message received 1.27 seconds after request: 21\n",
"Message received 1.31 seconds after request: ,\n",
"Message received 1.31 seconds after request: \n",
"Message received 1.31 seconds after request: 22\n",
"Message received 1.35 seconds after request: ,\n",
"Message received 1.35 seconds after request: \n",
"Message received 1.35 seconds after request: 23\n",
"Message received 1.39 seconds after request: ,\n",
"Message received 1.39 seconds after request: \n",
"Message received 1.39 seconds after request: 24\n",
"Message received 1.43 seconds after request: ,\n",
"Message received 1.43 seconds after request: \n",
"Message received 1.43 seconds after request: 25\n",
"Message received 1.47 seconds after request: ,\n",
"Message received 1.47 seconds after request: \n",
"Message received 1.47 seconds after request: 26\n",
"Message received 1.51 seconds after request: ,\n",
"Message received 1.51 seconds after request: \n",
"Message received 1.51 seconds after request: 27\n",
"Message received 1.55 seconds after request: ,\n",
"Message received 1.55 seconds after request: \n",
"Message received 1.55 seconds after request: 28\n",
"Message received 1.59 seconds after request: ,\n",
"Message received 1.59 seconds after request: \n",
"Message received 1.59 seconds after request: 29\n",
"Message received 1.59 seconds after request: ,\n",
"Message received 1.59 seconds after request: \n",
"Message received 1.59 seconds after request: 30\n",
"Message received 1.59 seconds after request: ,\n",
"Message received 1.59 seconds after request: \n",
"Message received 1.59 seconds after request: 31\n",
"Message received 1.59 seconds after request: ,\n",
"Message received 1.59 seconds after request: \n",
"Message received 1.60 seconds after request: 32\n",
"Message received 1.60 seconds after request: ,\n",
"Message received 1.60 seconds after request: \n",
"Message received 1.60 seconds after request: 33\n",
"Message received 1.60 seconds after request: ,\n",
"Message received 1.60 seconds after request: \n",
"Message received 1.67 seconds after request: 34\n",
"Message received 1.67 seconds after request: ,\n",
"Message received 1.67 seconds after request: \n",
"Message received 1.68 seconds after request: 35\n",
"Message received 1.68 seconds after request: ,\n",
"Message received 1.68 seconds after request: \n",
"Message received 1.86 seconds after request: 36\n",
"Message received 1.86 seconds after request: ,\n",
"Message received 1.86 seconds after request: \n",
"Message received 2.20 seconds after request: 13\n",
"Message received 2.20 seconds after request: ,\n",
"Message received 2.20 seconds after request: \n",
"Message received 2.31 seconds after request: 14\n",
"Message received 1.90 seconds after request: 37\n",
"Message received 1.90 seconds after request: ,\n",
"Message received 1.90 seconds after request: \n",
"Message received 1.94 seconds after request: 38\n",
"Message received 1.94 seconds after request: ,\n",
"Message received 1.94 seconds after request: \n",
"Message received 1.98 seconds after request: 39\n",
"Message received 1.98 seconds after request: ,\n",
"Message received 1.98 seconds after request: \n",
"Message received 2.05 seconds after request: 40\n",
"Message received 2.05 seconds after request: ,\n",
"Message received 2.05 seconds after request: \n",
"Message received 2.09 seconds after request: 41\n",
"Message received 2.09 seconds after request: ,\n",
"Message received 2.09 seconds after request: \n",
"Message received 2.14 seconds after request: 42\n",
"Message received 2.14 seconds after request: ,\n",
"Message received 2.14 seconds after request: \n",
"Message received 2.14 seconds after request: 43\n",
"Message received 2.14 seconds after request: ,\n",
"Message received 2.14 seconds after request: \n",
"Message received 2.14 seconds after request: 44\n",
"Message received 2.14 seconds after request: ,\n",
"Message received 2.14 seconds after request: \n",
"Message received 2.14 seconds after request: 45\n",
"Message received 2.14 seconds after request: ,\n",
"Message received 2.14 seconds after request: \n",
"Message received 2.15 seconds after request: 46\n",
"Message received 2.15 seconds after request: ,\n",
"Message received 2.15 seconds after request: \n",
"Message received 2.30 seconds after request: 47\n",
"Message received 2.30 seconds after request: ,\n",
"Message received 2.30 seconds after request: \n",
"Message received 2.30 seconds after request: 48\n",
"Message received 2.30 seconds after request: ,\n",
"Message received 2.30 seconds after request: \n",
"Message received 2.30 seconds after request: 49\n",
"Message received 2.30 seconds after request: ,\n",
"Message received 2.30 seconds after request: \n",
"Message received 2.31 seconds after request: 50\n",
"Message received 2.31 seconds after request: ,\n",
"Message received 2.31 seconds after request: \n",
"Message received 2.38 seconds after request: 15\n",
"Message received 2.38 seconds after request: ,\n",
"Message received 2.39 seconds after request: 51\n",
"Message received 2.39 seconds after request: ,\n",
"Message received 2.39 seconds after request: \n",
"Message received 2.46 seconds after request: 16\n",
"Message received 2.46 seconds after request: ,\n",
"Message received 2.46 seconds after request: \n",
"Message received 2.53 seconds after request: 17\n",
"Message received 2.53 seconds after request: ,\n",
"Message received 2.53 seconds after request: \n",
"Message received 2.61 seconds after request: 18\n",
"Message received 2.61 seconds after request: ,\n",
"Message received 2.61 seconds after request: \n",
"Message received 2.67 seconds after request: 19\n",
"Message received 2.67 seconds after request: ,\n",
"Message received 2.67 seconds after request: \n",
"Message received 2.74 seconds after request: 20\n",
"Message received 2.74 seconds after request: ,\n",
"Message received 2.74 seconds after request: \n",
"Message received 2.83 seconds after request: 21\n",
"Message received 2.83 seconds after request: ,\n",
"Message received 2.83 seconds after request: \n",
"Message received 2.89 seconds after request: 22\n",
"Message received 2.40 seconds after request: 52\n",
"Message received 2.40 seconds after request: ,\n",
"Message received 2.40 seconds after request: \n",
"Message received 2.48 seconds after request: 53\n",
"Message received 2.48 seconds after request: ,\n",
"Message received 2.48 seconds after request: \n",
"Message received 2.49 seconds after request: 54\n",
"Message received 2.49 seconds after request: ,\n",
"Message received 2.49 seconds after request: \n",
"Message received 2.68 seconds after request: 55\n",
"Message received 2.68 seconds after request: ,\n",
"Message received 2.68 seconds after request: \n",
"Message received 2.72 seconds after request: 56\n",
"Message received 2.72 seconds after request: ,\n",
"Message received 2.72 seconds after request: \n",
"Message received 2.77 seconds after request: 57\n",
"Message received 2.77 seconds after request: ,\n",
"Message received 2.77 seconds after request: \n",
"Message received 2.80 seconds after request: 58\n",
"Message received 2.80 seconds after request: ,\n",
"Message received 2.80 seconds after request: \n",
"Message received 2.85 seconds after request: 59\n",
"Message received 2.85 seconds after request: ,\n",
"Message received 2.85 seconds after request: \n",
"Message received 2.88 seconds after request: 60\n",
"Message received 2.88 seconds after request: ,\n",
"Message received 2.88 seconds after request: \n",
"Message received 2.88 seconds after request: 61\n",
"Message received 2.88 seconds after request: ,\n",
"Message received 2.88 seconds after request: \n",
"Message received 2.89 seconds after request: 62\n",
"Message received 2.89 seconds after request: ,\n",
"Message received 2.89 seconds after request: \n",
"Message received 2.89 seconds after request: 63\n",
"Message received 2.89 seconds after request: ,\n",
"Message received 2.89 seconds after request: \n",
"Message received 3.02 seconds after request: 23\n",
"Message received 3.02 seconds after request: ,\n",
"Message received 3.03 seconds after request: \n",
"Message received 3.09 seconds after request: 24\n",
"Message received 3.09 seconds after request: ,\n",
"Message received 3.09 seconds after request: \n",
"Message received 3.15 seconds after request: 25\n",
"Message received 3.15 seconds after request: ,\n",
"Message received 3.16 seconds after request: \n",
"Message received 3.22 seconds after request: 26\n",
"Message received 3.22 seconds after request: ,\n",
"Message received 3.22 seconds after request: \n",
"Message received 3.30 seconds after request: 27\n",
"Message received 3.30 seconds after request: ,\n",
"Message received 3.30 seconds after request: \n",
"Message received 3.37 seconds after request: 28\n",
"Message received 2.92 seconds after request: 64\n",
"Message received 2.92 seconds after request: ,\n",
"Message received 2.92 seconds after request: \n",
"Message received 3.37 seconds after request: 65\n",
"Message received 3.37 seconds after request: ,\n",
"Message received 3.37 seconds after request: \n",
"Message received 3.55 seconds after request: 29\n",
"Message received 3.55 seconds after request: ,\n",
"Message received 3.55 seconds after request: \n",
"Message received 3.55 seconds after request: 30\n",
"Message received 3.55 seconds after request: ,\n",
"Message received 3.55 seconds after request: \n",
"Message received 3.55 seconds after request: 31\n",
"Message received 3.55 seconds after request: ,\n",
"Message received 3.56 seconds after request: \n",
"Message received 3.56 seconds after request: 32\n",
"Message received 3.56 seconds after request: ,\n",
"Message received 3.56 seconds after request: \n",
"Message received 3.56 seconds after request: 33\n",
"Message received 3.38 seconds after request: 66\n",
"Message received 3.38 seconds after request: ,\n",
"Message received 3.38 seconds after request: \n",
"Message received 3.38 seconds after request: 67\n",
"Message received 3.38 seconds after request: ,\n",
"Message received 3.38 seconds after request: \n",
"Message received 3.38 seconds after request: 68\n",
"Message received 3.38 seconds after request: ,\n",
"Message received 3.38 seconds after request: \n",
"Message received 3.42 seconds after request: 69\n",
"Message received 3.42 seconds after request: ,\n",
"Message received 3.42 seconds after request: \n",
"Message received 3.43 seconds after request: 70\n",
"Message received 3.43 seconds after request: ,\n",
"Message received 3.43 seconds after request: \n",
"Message received 3.46 seconds after request: 71\n",
"Message received 3.46 seconds after request: ,\n",
"Message received 3.46 seconds after request: \n",
"Message received 3.47 seconds after request: 72\n",
"Message received 3.47 seconds after request: ,\n",
"Message received 3.47 seconds after request: \n",
"Message received 3.50 seconds after request: 73\n",
"Message received 3.50 seconds after request: ,\n",
"Message received 3.50 seconds after request: \n",
"Message received 3.51 seconds after request: 74\n",
"Message received 3.51 seconds after request: ,\n",
"Message received 3.51 seconds after request: \n",
"Message received 3.52 seconds after request: 75\n",
"Message received 3.52 seconds after request: ,\n",
"Message received 3.52 seconds after request: \n",
"Message received 3.54 seconds after request: 76\n",
"Message received 3.54 seconds after request: ,\n",
"Message received 3.54 seconds after request: \n",
"Message received 3.56 seconds after request: 77\n",
"Message received 3.56 seconds after request: ,\n",
"Message received 3.56 seconds after request: \n",
"Message received 3.60 seconds after request: 34\n",
"Message received 3.60 seconds after request: ,\n",
"Message received 3.60 seconds after request: \n",
"Message received 3.63 seconds after request: 35\n",
"Message received 3.63 seconds after request: ,\n",
"Message received 3.63 seconds after request: \n",
"Message received 3.95 seconds after request: 36\n",
"Message received 3.59 seconds after request: 78\n",
"Message received 3.59 seconds after request: ,\n",
"Message received 3.59 seconds after request: \n",
"Message received 3.59 seconds after request: 79\n",
"Message received 3.59 seconds after request: ,\n",
"Message received 3.59 seconds after request: \n",
"Message received 3.59 seconds after request: 80\n",
"Message received 3.59 seconds after request: ,\n",
"Message received 3.59 seconds after request: \n",
"Message received 3.61 seconds after request: 81\n",
"Message received 3.61 seconds after request: ,\n",
"Message received 3.61 seconds after request: \n",
"Message received 3.65 seconds after request: 82\n",
"Message received 3.65 seconds after request: ,\n",
"Message received 3.65 seconds after request: \n",
"Message received 3.85 seconds after request: 83\n",
"Message received 3.85 seconds after request: ,\n",
"Message received 3.85 seconds after request: \n",
"Message received 3.90 seconds after request: 84\n",
"Message received 3.90 seconds after request: ,\n",
"Message received 3.90 seconds after request: \n",
"Message received 3.95 seconds after request: 85\n",
"Message received 3.95 seconds after request: ,\n",
"Message received 3.95 seconds after request: \n",
"Message received 4.02 seconds after request: 37\n",
"Message received 4.02 seconds after request: ,\n",
"Message received 4.02 seconds after request: \n",
"Message received 4.08 seconds after request: 38\n",
"Message received 4.00 seconds after request: 86\n",
"Message received 4.00 seconds after request: ,\n",
"Message received 4.00 seconds after request: \n",
"Message received 4.04 seconds after request: 87\n",
"Message received 4.04 seconds after request: ,\n",
"Message received 4.04 seconds after request: \n",
"Message received 4.08 seconds after request: 88\n",
"Message received 4.08 seconds after request: ,\n",
"Message received 4.08 seconds after request: \n",
"Message received 4.17 seconds after request: 39\n",
"Message received 4.17 seconds after request: ,\n",
"Message received 4.17 seconds after request: \n",
"Message received 4.22 seconds after request: 40\n",
"Message received 4.22 seconds after request: ,\n",
"Message received 4.22 seconds after request: \n",
"Message received 4.29 seconds after request: 41\n",
"Message received 4.29 seconds after request: ,\n",
"Message received 4.29 seconds after request: \n",
"Message received 4.37 seconds after request: 42\n",
"Message received 4.37 seconds after request: ,\n",
"Message received 4.37 seconds after request: \n",
"Message received 4.37 seconds after request: 43\n",
"Message received 4.37 seconds after request: ,\n",
"Message received 4.37 seconds after request: \n",
"Message received 4.37 seconds after request: 44\n",
"Message received 4.37 seconds after request: ,\n",
"Message received 4.37 seconds after request: \n",
"Message received 4.38 seconds after request: 45\n",
"Message received 4.38 seconds after request: ,\n",
"Message received 4.38 seconds after request: \n",
"Message received 4.39 seconds after request: 46\n",
"Message received 4.12 seconds after request: 89\n",
"Message received 4.12 seconds after request: ,\n",
"Message received 4.12 seconds after request: \n",
"Message received 4.18 seconds after request: 90\n",
"Message received 4.18 seconds after request: ,\n",
"Message received 4.18 seconds after request: \n",
"Message received 4.18 seconds after request: 91\n",
"Message received 4.18 seconds after request: ,\n",
"Message received 4.18 seconds after request: \n",
"Message received 4.18 seconds after request: 92\n",
"Message received 4.18 seconds after request: ,\n",
"Message received 4.18 seconds after request: \n",
"Message received 4.19 seconds after request: 93\n",
"Message received 4.19 seconds after request: ,\n",
"Message received 4.19 seconds after request: \n",
"Message received 4.20 seconds after request: 94\n",
"Message received 4.20 seconds after request: ,\n",
"Message received 4.20 seconds after request: \n",
"Message received 4.23 seconds after request: 95\n",
"Message received 4.23 seconds after request: ,\n",
"Message received 4.23 seconds after request: \n",
"Message received 4.27 seconds after request: 96\n",
"Message received 4.27 seconds after request: ,\n",
"Message received 4.27 seconds after request: \n",
"Message received 4.39 seconds after request: 97\n",
"Message received 4.39 seconds after request: ,\n",
"Message received 4.39 seconds after request: \n",
"Message received 4.39 seconds after request: 98\n",
"Message received 4.39 seconds after request: ,\n",
"Message received 4.39 seconds after request: \n",
"Message received 4.66 seconds after request: 47\n",
"Message received 4.66 seconds after request: ,\n",
"Message received 4.66 seconds after request: \n",
"Message received 4.66 seconds after request: 48\n",
"Message received 4.66 seconds after request: ,\n",
"Message received 4.66 seconds after request: \n",
"Message received 4.66 seconds after request: 49\n",
"Message received 4.66 seconds after request: ,\n",
"Message received 4.66 seconds after request: \n",
"Message received 4.69 seconds after request: 50\n",
"Message received 4.69 seconds after request: ,\n",
"Message received 4.69 seconds after request: \n",
"Message received 4.80 seconds after request: 51\n",
"Message received 4.80 seconds after request: ,\n",
"Message received 4.80 seconds after request: \n",
"Message received 4.83 seconds after request: 52\n",
"Message received 4.83 seconds after request: ,\n",
"Message received 4.83 seconds after request: \n",
"Message received 4.95 seconds after request: 53\n",
"Message received 4.95 seconds after request: ,\n",
"Message received 4.95 seconds after request: \n",
"Message received 4.97 seconds after request: 54\n",
"Message received 4.97 seconds after request: ,\n",
"Message received 4.97 seconds after request: \n",
"Message received 5.31 seconds after request: 55\n",
"Message received 5.31 seconds after request: ,\n",
"Message received 5.31 seconds after request: \n",
"Message received 5.37 seconds after request: 56\n",
"Message received 5.37 seconds after request: ,\n",
"Message received 5.37 seconds after request: \n",
"Message received 5.47 seconds after request: 57\n",
"Message received 5.47 seconds after request: ,\n",
"Message received 5.47 seconds after request: \n",
"Message received 5.54 seconds after request: 58\n",
"Message received 5.54 seconds after request: ,\n",
"Message received 5.54 seconds after request: \n",
"Message received 5.60 seconds after request: 59\n",
"Message received 5.60 seconds after request: ,\n",
"Message received 5.60 seconds after request: \n",
"Message received 5.66 seconds after request: 60\n",
"Message received 5.66 seconds after request: ,\n",
"Message received 5.66 seconds after request: \n",
"Message received 5.67 seconds after request: 61\n",
"Message received 5.67 seconds after request: ,\n",
"Message received 5.67 seconds after request: \n",
"Message received 5.67 seconds after request: 62\n",
"Message received 5.67 seconds after request: ,\n",
"Message received 5.67 seconds after request: \n",
"Message received 5.67 seconds after request: 63\n",
"Message received 5.67 seconds after request: ,\n",
"Message received 5.67 seconds after request: \n",
"Message received 5.69 seconds after request: 64\n",
"Message received 5.69 seconds after request: ,\n",
"Message received 5.69 seconds after request: \n",
"Message received 5.95 seconds after request: 65\n",
"Message received 5.95 seconds after request: ,\n",
"Message received 5.95 seconds after request: \n",
"Message received 5.95 seconds after request: 66\n",
"Message received 5.95 seconds after request: ,\n",
"Message received 5.95 seconds after request: \n",
"Message received 5.95 seconds after request: 67\n",
"Message received 5.95 seconds after request: ,\n",
"Message received 5.95 seconds after request: \n",
"Message received 5.97 seconds after request: 68\n",
"Message received 5.97 seconds after request: ,\n",
"Message received 5.97 seconds after request: \n",
"Message received 6.08 seconds after request: 69\n",
"Message received 6.08 seconds after request: ,\n",
"Message received 6.08 seconds after request: \n",
"Message received 6.10 seconds after request: 70\n",
"Message received 6.10 seconds after request: ,\n",
"Message received 6.10 seconds after request: \n",
"Message received 6.17 seconds after request: 71\n",
"Message received 6.17 seconds after request: ,\n",
"Message received 6.17 seconds after request: \n",
"Message received 6.24 seconds after request: 72\n",
"Message received 6.24 seconds after request: ,\n",
"Message received 6.24 seconds after request: \n",
"Message received 6.35 seconds after request: 73\n",
"Message received 6.35 seconds after request: ,\n",
"Message received 6.35 seconds after request: \n",
"Message received 6.37 seconds after request: 74\n",
"Message received 6.37 seconds after request: ,\n",
"Message received 6.37 seconds after request: \n",
"Message received 6.45 seconds after request: 75\n",
"Message received 6.45 seconds after request: ,\n",
"Message received 6.45 seconds after request: \n",
"Message received 6.52 seconds after request: 76\n",
"Message received 6.52 seconds after request: ,\n",
"Message received 6.52 seconds after request: \n",
"Message received 6.59 seconds after request: 77\n",
"Message received 6.59 seconds after request: ,\n",
"Message received 6.59 seconds after request: \n",
"Message received 6.83 seconds after request: 78\n",
"Message received 6.83 seconds after request: ,\n",
"Message received 6.83 seconds after request: \n",
"Message received 6.83 seconds after request: 79\n",
"Message received 6.84 seconds after request: ,\n",
"Message received 6.84 seconds after request: \n",
"Message received 6.84 seconds after request: 80\n",
"Message received 6.84 seconds after request: ,\n",
"Message received 6.84 seconds after request: \n",
"Message received 6.85 seconds after request: 81\n",
"Message received 6.85 seconds after request: ,\n",
"Message received 6.85 seconds after request: \n",
"Message received 6.92 seconds after request: 82\n",
"Message received 6.92 seconds after request: ,\n",
"Message received 6.92 seconds after request: \n",
"Message received 7.23 seconds after request: 83\n",
"Message received 7.23 seconds after request: ,\n",
"Message received 7.23 seconds after request: \n",
"Message received 7.30 seconds after request: 84\n",
"Message received 7.30 seconds after request: ,\n",
"Message received 7.30 seconds after request: \n",
"Message received 7.37 seconds after request: 85\n",
"Message received 7.37 seconds after request: ,\n",
"Message received 7.37 seconds after request: \n",
"Message received 7.43 seconds after request: 86\n",
"Message received 7.43 seconds after request: ,\n",
"Message received 7.43 seconds after request: \n",
"Message received 7.70 seconds after request: 87\n",
"Message received 7.70 seconds after request: ,\n",
"Message received 7.70 seconds after request: \n",
"Message received 7.71 seconds after request: 88\n",
"Message received 7.71 seconds after request: ,\n",
"Message received 7.71 seconds after request: \n",
"Message received 7.73 seconds after request: 89\n",
"Message received 7.73 seconds after request: ,\n",
"Message received 7.73 seconds after request: \n",
"Message received 7.80 seconds after request: 90\n",
"Message received 7.80 seconds after request: ,\n",
"Message received 7.80 seconds after request: \n",
"Message received 7.80 seconds after request: 91\n",
"Message received 7.80 seconds after request: ,\n",
"Message received 7.80 seconds after request: \n",
"Message received 7.80 seconds after request: 92\n",
"Message received 7.80 seconds after request: ,\n",
"Message received 7.80 seconds after request: \n",
"Message received 7.80 seconds after request: 93\n",
"Message received 7.80 seconds after request: ,\n",
"Message received 7.80 seconds after request: \n",
"Message received 7.82 seconds after request: 94\n",
"Message received 7.82 seconds after request: ,\n",
"Message received 7.82 seconds after request: \n",
"Message received 7.90 seconds after request: 95\n",
"Message received 7.90 seconds after request: ,\n",
"Message received 7.90 seconds after request: \n",
"Message received 7.96 seconds after request: 96\n",
"Message received 7.96 seconds after request: ,\n",
"Message received 7.97 seconds after request: \n",
"Message received 8.13 seconds after request: 97\n",
"Message received 8.14 seconds after request: ,\n",
"Message received 8.14 seconds after request: \n",
"Message received 8.14 seconds after request: 98\n",
"Message received 8.14 seconds after request: ,\n",
"Message received 8.14 seconds after request: \n",
"Message received 8.16 seconds after request: 99\n",
"Message received 8.16 seconds after request: ,\n",
"Message received 8.16 seconds after request: \n",
"Message received 8.16 seconds after request: 100\n",
"Message received 8.16 seconds after request: .\n",
"Message received 8.16 seconds after request: None\n",
"Full response received 8.16 seconds after request\n",
"Message received 4.41 seconds after request: 99\n",
"Message received 4.41 seconds after request: ,\n",
"Message received 4.41 seconds after request: \n",
"Message received 4.41 seconds after request: 100\n",
"Message received 4.41 seconds after request: .\n",
"Message received 4.41 seconds after request: None\n",
"Full response received 4.41 seconds after request\n",
"Full conversation received: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100.\n"
]
}
],
"source": [
"# Example of an OpenAI ChatCompletion request with stream=True\n",
"# https://platform.openai.com/docs/guides/chat\n",
"# https://platform.openai.com/docs/api-reference/streaming#chat/create-stream\n",
"\n",
"# record the time before the request is sent\n",
"start_time = time.time()\n",
"\n",
"# send a ChatCompletion request to count to 100\n",
"completion = client.chat.completions.create(\n",
"response = client.chat.completions.create(\n",
" model='gpt-3.5-turbo',\n",
" messages=[\n",
" {'role': 'user', 'content': 'Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ...'}\n",
@ -541,7 +542,7 @@
"collected_chunks = []\n",
"collected_messages = []\n",
"# iterate through the stream of events\n",
"for chunk in completion:\n",
"for chunk in response:\n",
" chunk_time = time.time() - start_time # calculate the time delay of the chunk\n",
" collected_chunks.append(chunk) # save the event response\n",
" chunk_message = chunk.choices[0].delta.content # extract the message\n",
@ -563,10 +564,15 @@
"source": [
"#### Time comparison\n",
"\n",
"In the example above, both requests took about 3 seconds to fully complete. Request times will vary depending on load and other stochastic factors.\n",
"In the example above, both requests took about 4 to 5 seconds to fully complete. Request times will vary depending on load and other stochastic factors.\n",
"\n",
"However, with the streaming request, we received the first token after 0.1 seconds, and subsequent tokens every ~0.01-0.02 seconds."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
@ -585,7 +591,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.3"
},
"orig_nbformat": 4,
"vscode": {

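The streaming changes above follow one pattern: iterate over the chunks and accumulate `delta.content`. As a minimal, self-contained sketch (assuming the v1 Python SDK, an `OPENAI_API_KEY` environment variable, and an illustrative model name), this is how one might also measure time to first token:

```python
import os
import time

from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

start_time = time.time()
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "Count to 10, with a comma between each number."}],
    stream=True,
)

first_token_time = None
text = ""
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:  # the final chunk's delta.content is None
        if first_token_time is None:
            first_token_time = time.time() - start_time
        text += delta

print(f"First token after {first_token_time:.2f} seconds; full text: {text!r}")
```

Streaming does not shorten total generation time; it lets the caller render partial output immediately, which is why the first token arrives in a fraction of the full response time.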
@ -46,8 +46,9 @@
"import json\n",
"import os\n",
"\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"COMPLETIONS_MODEL = \"text-davinci-002\"\n"
"COMPLETIONS_MODEL = \"gpt-4\"\n",
"\n",
"client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if you didn't set as an env var>\"))"
]
},
{
@ -193,15 +194,14 @@
"source": [
"def request_completion(prompt):\n",
"\n",
" completion_response = openai.completions.create(\n",
" completion_response = openai.chat.completions.create(\n",
" prompt=prompt,\n",
" temperature=0,\n",
" max_tokens=5,\n",
" top_p=1,\n",
" frequency_penalty=0,\n",
" presence_penalty=0,\n",
" model=COMPLETIONS_MODEL\n",
" )\n",
" model=COMPLETIONS_MODEL)\n",
"\n",
" return completion_response\n",
"\n",
@ -211,7 +211,7 @@
" prompt = prompt.replace('DESCRIPTION_TEXT',transaction['Description'])\n",
" prompt = prompt.replace('TRANSACTION_VALUE',str(transaction['Transaction value (£)']))\n",
"\n",
" classification = request_completion(prompt).choices[0].text.replace('\\n','')\n",
" classification = request_completion(prompt).choices[0].message.content.replace('\\n','')\n",
"\n",
" return classification\n",
"\n",
@ -351,7 +351,7 @@
" Building Improvement 14\n",
" Could not classify 5\n",
" Literature & Archive 3\n",
" Software/IT 2\n",
" Software/IT 2\n",
" Utility Bills 1\n",
"Name: Classification, dtype: int64"
]
@ -916,8 +916,8 @@
"source": [
"from utils.embeddings_utils import get_embedding\n",
"\n",
"df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, model='text-similarity-babbage-001'))\n",
"df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x, model='text-search-babbage-doc-001'))\n",
"df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, model='gpt-4'))\n",
"df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x, model='gpt-4'))\n",
"df.to_csv(embedding_path)\n"
]
},
@ -1813,7 +1813,7 @@
"metadata": {},
"outputs": [],
"source": [
"test_set['predicted_class'] = test_set.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['prompt'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
"test_set['predicted_class'] = test_set.apply(lambda x: openai.chat.completions.create(model=fine_tuned_model, prompt=x['prompt'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
"test_set['pred'] = test_set.apply(lambda x : x['predicted_class']['choices'][0]['text'],axis=1)\n"
]
},
@ -1963,7 +1963,7 @@
"outputs": [],
"source": [
"holdout_df['combined'] = \"Supplier: \" + holdout_df['Supplier'].str.strip() + \"; Description: \" + holdout_df['Description'].str.strip() + '\\n\\n###\\n\\n' # + \"; Value: \" + str(df['Transaction value (£)']).strip()\n",
"holdout_df['prediction_result'] = holdout_df.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['combined'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
"holdout_df['prediction_result'] = holdout_df.apply(lambda x: openai.chat.completions.create(model=fine_tuned_model, prompt=x['combined'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
"holdout_df['pred'] = holdout_df.apply(lambda x : x['prediction_result']['choices'][0]['text'],axis=1)\n"
]
},

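The migration in the notebook above swaps the legacy completions endpoint for chat completions; the key difference is that chat models take a `messages` list rather than a `prompt` string. A hedged sketch of the updated call shape (the prompt template and category names here are illustrative stand-ins for the notebook's own):

```python
import os

from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

def classify_transaction(description: str, value: str) -> str:
    # illustrative template; the notebook substitutes into its own prompt string
    prompt = (
        "Classify the transaction below into one category, e.g. 'Building Improvement', "
        "'Software/IT', or 'Utility Bills'. Reply with the category only.\n"
        f"Description: {description}\nValue (GBP): {value}"
    )
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": prompt}],  # chat endpoint takes messages, not prompt
        temperature=0,
        max_tokens=5,
    )
    return response.choices[0].message.content.strip()

print(classify_transaction("Roof repairs to archive building", "12,000"))
```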
@ -102,7 +102,8 @@
"logging.basicConfig(level=logging.INFO, format=' %(asctime)s - %(levelname)s - %(message)s')\n",
"\n",
"OPENAI_MODEL = 'gpt-3.5-turbo-0613'\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")"
"\n",
"client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -266,14 +267,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In an OpenAI API call, we can describe `functions` to `gpt-3.5-turbo-0613` and `gpt-4-0613` and have the model intelligently choose to output a `JSON` object containing arguments to call those `functions`. It's important to note that the [chat completions API](https://platform.openai.com/docs/guides/gpt/chat-completions-api) doesn't actually execute the `function`. Instead, it provides the `JSON` output, which can then be used to call the `function` in our code. For more details, refer to the [OpenAI Function Calling Guide](https://platform.openai.com/docs/guides/gpt/function-calling)."
"In an OpenAI API call, we can describe `functions` to `gpt-3.5-turbo-0613` and `gpt-4-0613` and have the model intelligently choose to output a `JSON` object containing arguments to call those `functions`. It's important to note that the [chat completions API](https://platform.openai.com/docs/guides/gpt/chat-completions-api) doesn't actually execute the `function`. Instead, it provides the `JSON` output, which can then be used to call the `function` in our code. For more details, refer to the [OpenAI Function Calling Guide](https://platform.openai.com/docs/guides/function-calling)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our function, `enrich_entities(text, label_entities` gets a block of text and a dictionary containing identified labels and entities as parameters. It then associates the recognized entities with their corresponding links to the Wikipedia articles."
"Our function, `enrich_entities(text, label_entities)` gets a block of text and a dictionary containing identified labels and entities as parameters. It then associates the recognized entities with their corresponding links to the Wikipedia articles."
]
},
{
@ -349,7 +350,7 @@
"As previously highlighted, `gpt-3.5-turbo-0613` and `gpt-4-0613` have been fine-tuned to detect when a `function` should to be called. Moreover, they can produce a `JSON` response that conforms to the `function` signature. Here's the sequence we follow:\n",
"\n",
"1. Define our `function` and its associated `JSON` Schema.\n",
"2. Invoke the model using the `messages`, `functions` and `function_call` parameters.\n",
"2. Invoke the model using the `messages`, `tools` and `tool_choice` parameters.\n",
"3. Convert the output into a `JSON` object, and then call the `function` with the `arguments` provided by the model.\n",
"\n",
"In practice, one might want to re-invoke the model again by appending the `function` response as a new message, and let the model summarize the results back to the user. Nevertheless, for our purposes, this step is not needed.\n",
@ -381,7 +382,7 @@
" \"time\": [\"workday\"] \n",
"} \n",
"```\n",
"we need to define the corresponding `JSON` schema to be passed to the `functions` parameter: "
"we need to define the corresponding `JSON` schema to be passed to the `tools` parameter: "
]
},
{
@ -392,22 +393,25 @@
"source": [
"def generate_functions(labels: dict) -> list:\n",
" return [\n",
" {\n",
" \"name\": \"enrich_entities\",\n",
" \"description\": \"Enrich Text with Knowledge Base Links\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"r'^(?:' + '|'.join({labels}) + ')$'\": \n",
" {\n",
" \"type\": \"array\",\n",
" \"items\": {\n",
" \"type\": \"string\"\n",
" { \n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"enrich_entities\",\n",
" \"description\": \"Enrich Text with Knowledge Base Links\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"r'^(?:' + '|'.join({labels}) + ')$'\": \n",
" {\n",
" \"type\": \"array\",\n",
" \"items\": {\n",
" \"type\": \"string\"\n",
" }\n",
" }\n",
" }\n",
" },\n",
" \"additionalProperties\": False\n",
" },\n",
" },\n",
" \"additionalProperties\": False\n",
" },\n",
" }\n",
" }\n",
" ]"
]
@ -423,7 +427,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we invoke the model. It's important to note that we direct the API to use a specific function by setting the `function_call` parameter to `{\"name\": \"enrich_entities\"}`."
"Now, we invoke the model. It's important to note that we direct the API to use a specific function by setting the `tool_choice` parameter to `{\"type\": \"function\", \"function\" : {\"name\": \"enrich_entities\"}}`."
]
},
{
@ -440,25 +444,27 @@
" {\"role\": \"user\", \"content\": user_message(text=text)}\n",
" ]\n",
"\n",
" response = openai.ChatCompletion.create(\n",
" # TODO: functions and function_call are deprecated, need to be updated\n",
" # See: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools\n",
" response = openai.chat.completions.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n",
" messages=messages,\n",
" functions=generate_functions(labels),\n",
" function_call={\"name\": \"enrich_entities\"}, \n",
" tools=generate_functions(labels),\n",
" tool_choice={\"type\": \"function\", \"function\" : {\"name\": \"enrich_entities\"}}, \n",
" temperature=0,\n",
" frequency_penalty=0,\n",
" presence_penalty=0,\n",
" )\n",
"\n",
" response_message = response[\"choices\"][0][\"message\"]\n",
" response_message = response.choices[0].message\n",
" \n",
" available_functions = {\"enrich_entities\": enrich_entities} \n",
" function_name = response_message[\"function_call\"][\"name\"]\n",
" function_name = response_message.tool_calls[0].function.name\n",
" \n",
" function_to_call = available_functions[function_name]\n",
" logging.info(f\"function_to_call: {function_to_call}\")\n",
"\n",
" function_args = json.loads(response_message[\"function_call\"][\"arguments\"])\n",
" function_args = json.loads(response_message.tool_calls[0].function.arguments)\n",
" logging.info(f\"function_args: {function_args}\")\n",
"\n",
" function_response = function_to_call(text, function_args)\n",
@ -564,8 +570,8 @@
],
"source": [
"# estimate inference cost assuming gpt-3.5-turbo (4K context)\n",
"i_tokens = result[\"model_response\"][\"usage\"][\"prompt_tokens\"] \n",
"o_tokens = result[\"model_response\"][\"usage\"][\"completion_tokens\"] \n",
"i_tokens = result[\"model_response\"].usage.prompt_tokens \n",
"o_tokens = result[\"model_response\"].usage.completion_tokens \n",
"\n",
"i_cost = (i_tokens / 1000) * 0.0015\n",
"o_cost = (o_tokens / 1000) * 0.002\n",

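The `tools` / `tool_choice` flow above can be distilled into a short, self-contained sketch (the schema here is a simplified stand-in for the notebook's `generate_functions(labels)`; model name is illustrative):

```python
import json
import os

from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

tools = [{
    "type": "function",
    "function": {
        "name": "enrich_entities",
        "description": "Enrich Text with Knowledge Base Links",
        "parameters": {
            "type": "object",
            "properties": {
                "person": {"type": "array", "items": {"type": "string"}},
            },
        },
    },
}]

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Label the people: Ada Lovelace met Charles Babbage."}],
    tools=tools,
    # force the model to call this specific tool rather than answer in prose
    tool_choice={"type": "function", "function": {"name": "enrich_entities"}},
)

tool_call = response.choices[0].message.tool_calls[0]
arguments = json.loads(tool_call.function.arguments)  # arguments arrive as a JSON string
print(tool_call.function.name, arguments)
```

Note that the API never executes the function; the caller dispatches on `tool_call.function.name` and runs its own code, as the notebook does with `available_functions`.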
@ -63,11 +63,13 @@
"from IPython import display # for pretty printing\n",
"import json # for parsing the JSON api responses and model outputs\n",
"from numpy import dot # for cosine similarity\n",
"import openai # for using GPT and getting embeddings\n",
"from openai import OpenAI\n",
"import os # for loading environment variables\n",
"import requests # for making the API requests\n",
"from tqdm.notebook import tqdm # for printing progress bars\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"# Load environment variables\n",
"news_api_key = os.getenv(\"NEWS_API_KEY\")\n",
"\n",
@ -76,14 +78,12 @@
"\n",
"# Helper functions\n",
"def json_gpt(input: str):\n",
" completion = openai.ChatCompletion.create(\n",
" model=GPT_MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"Output only valid JSON\"},\n",
" {\"role\": \"user\", \"content\": input},\n",
" ],\n",
" temperature=0.5,\n",
" )\n",
" completion = client.chat.completions.create(model=GPT_MODEL,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"Output only valid JSON\"},\n",
" {\"role\": \"user\", \"content\": input},\n",
" ],\n",
" temperature=0.5)\n",
"\n",
" text = completion.choices[0].message.content\n",
" parsed = json.loads(text)\n",
@ -92,7 +92,7 @@
"\n",
"\n",
"def embeddings(input: list[str]) -> list[list[str]]:\n",
" response = openai.Embedding.create(model=\"text-embedding-ada-002\", input=input)\n",
" response = client.embeddings.create(model=\"text-embedding-3-small\", input=input)\n",
" return [data.embedding for data in response.data]"
]
},
@ -489,7 +489,7 @@
"Include as much information as possible in the answer. Reference the relevant search result urls as markdown links.\n",
"\"\"\"\n",
"\n",
"completion = openai.ChatCompletion.create(\n",
"completion = client.chat.completions.create(\n",
" model=GPT_MODEL,\n",
" messages=[{\"role\": \"user\", \"content\": ANSWER_INPUT}],\n",
" temperature=0.5,\n",
@ -498,7 +498,7 @@
"\n",
"text = \"\"\n",
"for chunk in completion:\n",
" text += chunk.choices[0].delta.get(\"content\", \"\")\n",
" text += chunk.choices[0].delta.content\n",
" display.clear_output(wait=True)\n",
" display.display(display.Markdown(text))"
]

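The `embeddings` helper above returns one vector per input string; ranking candidate texts against a query is then a cosine-similarity sort. A small sketch under those assumptions (model name as in the updated notebook, documents illustrative):

```python
import os

from numpy import dot
from numpy.linalg import norm
from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

def embeddings(texts: list[str]) -> list[list[float]]:
    response = client.embeddings.create(model="text-embedding-3-small", input=texts)
    return [d.embedding for d in response.data]

def cosine_similarity(a, b) -> float:
    return dot(a, b) / (norm(a) * norm(b))

docs = ["NASA launches a new space telescope", "Local bakery wins pie contest"]
query_vec, *doc_vecs = embeddings(["space exploration news"] + docs)

# sort documents from most to least related to the query
ranked = sorted(zip(docs, doc_vecs), key=lambda pair: cosine_similarity(query_vec, pair[1]), reverse=True)
print([doc for doc, _ in ranked])
```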
@ -49,6 +49,8 @@
"| `gpt-4` | 8,192 tokens (~10 pages) |\n",
"| `gpt-4-32k` | 32,768 tokens (~40 pages) |\n",
"\n",
"(New model is available with longer contexts, gpt-4-1106-preview have 128K context window)\n",
"\n",
"Continuing the analogy, you can think of the model like a student who can only look at a few pages of notes at a time, despite potentially having shelves of textbooks to draw upon.\n",
"\n",
"Therefore, to build a system capable of drawing upon large quantities of text to answer questions, we recommend using a Search-Ask approach.\n"
@ -136,7 +138,9 @@
"\n",
"# models\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"GPT_MODEL = \"gpt-3.5-turbo\""
"GPT_MODEL = \"gpt-3.5-turbo\"\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n"
]
},
{
@ -194,10 +198,7 @@
"# an example question about the 2022 Olympics\n",
"query = 'Which athletes won the gold medal in curling at the 2022 Winter Olympics?'\n",
"\n",
"openai = OpenAI(\n",
" api_key = os.environ.get(\"OPENAI_API_KEY\"), # this is also the default, it can be omitted\n",
")\n",
"response = openai.chat.completions.create(\n",
"response = client.chat.completions.create(\n",
" messages=[\n",
" {'role': 'system', 'content': 'You answer questions about the 2022 Winter Olympics.'},\n",
" {'role': 'user', 'content': query},\n",
@ -557,7 +558,7 @@
"\n",
"Question: Which athletes won the gold medal in curling at the 2022 Winter Olympics?\"\"\"\n",
"\n",
"response = openai.chat.completions.create(\n",
"response = client.chat.completions.create(\n",
" messages=[\n",
" {'role': 'system', 'content': 'You answer questions about the 2022 Winter Olympics.'},\n",
" {'role': 'user', 'content': query},\n",
@ -786,7 +787,7 @@
" top_n: int = 100\n",
") -> tuple[list[str], list[float]]:\n",
" \"\"\"Returns a list of strings and relatednesses, sorted from most related to least.\"\"\"\n",
" query_embedding_response = openai.embeddings.create(\n",
" query_embedding_response = client.embeddings.create(\n",
" model=EMBEDDING_MODEL,\n",
" input=query,\n",
" )\n",
@ -964,7 +965,7 @@
" {\"role\": \"system\", \"content\": \"You answer questions about the 2022 Winter Olympics.\"},\n",
" {\"role\": \"user\", \"content\": message},\n",
" ]\n",
" response = openai.chat.completions.create(\n",
" response = client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" temperature=0\n",

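The Search-Ask pattern shown above reduces to: retrieve the most relevant passages, then place them in the prompt ahead of the question. A hedged, self-contained sketch of the Ask step (system message as in the notebook; model name and helper are illustrative):

```python
import os

from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

def ask(query: str, relevant_passages: list[str]) -> str:
    # in practice, trim the passages to fit the model's context window
    context = "\n\n".join(relevant_passages)
    message = (
        "Use the articles below to answer the question. "
        'If the answer cannot be found, write "I could not find an answer."\n\n'
        f"{context}\n\nQuestion: {query}"
    )
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You answer questions about the 2022 Winter Olympics."},
            {"role": "user", "content": message},
        ],
        temperature=0,
    )
    return response.choices[0].message.content
```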
@ -403,7 +403,7 @@
"source": [
"from langchain.vectorstores.neo4j_vector import Neo4jVector\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"embeddings_model = \"text-embedding-ada-002\""
"embeddings_model = \"text-embedding-3-small\""
]
},
{
@ -626,16 +626,16 @@
"outputs": [],
"source": [
"from openai import OpenAI\n",
"client = OpenAI()\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"# Define the entities to look for\n",
"def define_query(prompt, model=\"gpt-4-1106-preview\"):\n",
" completion = client.chat.completions.create(\n",
" model=model,\n",
" temperature=0,\n",
" response_format= {\n",
" \"type\": \"json_object\"\n",
" },\n",
" model=model,\n",
" temperature=0,\n",
" response_format= {\n",
" \"type\": \"json_object\"\n",
" },\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",

File diff suppressed because it is too large

@ -14,14 +14,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ada-002 embedding performance on 1k Amazon reviews: mse=0.62, mae=0.53\n"
"text-embedding-3-small performance on 1k Amazon reviews: mse=0.65, mae=0.52\n"
]
}
],
@ -48,12 +48,12 @@
"mse = mean_squared_error(y_test, preds)\n",
"mae = mean_absolute_error(y_test, preds)\n",
"\n",
"print(f\"ada-002 embedding performance on 1k Amazon reviews: mse={mse:.2f}, mae={mae:.2f}\")\n"
"print(f\"text-embedding-3-small performance on 1k Amazon reviews: mse={mse:.2f}, mae={mae:.2f}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -105,7 +105,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {

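The reported `mse` / `mae` numbers come from fitting a standard regressor on the embedding vectors. A hedged sketch of that evaluation loop (the CSV path and column names are assumptions about the notebook's dataset):

```python
from ast import literal_eval

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split

# assumed path and columns: embeddings stored as stringified lists, scores 1-5
df = pd.read_csv("data/fine_food_reviews_with_embeddings_1k.csv")
X = np.array(df.embedding.apply(literal_eval).tolist())
y = df.Score

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = RandomForestRegressor(n_estimators=100)
model.fit(X_train, y_train)
preds = model.predict(X_test)

print(f"mse={mean_squared_error(y_test, preds):.2f}, mae={mean_absolute_error(y_test, preds):.2f}")
```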
@ -54,9 +54,14 @@
"import arxiv\n",
"from math import exp\n",
"import openai\n",
"import os\n",
"import pandas as pd\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"import tiktoken"
"import tiktoken\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"OPENAI_MODEL = \"gpt-4\""
]
},
{
@ -202,7 +207,7 @@
],
"source": [
"tokens = [\" Yes\", \" No\"]\n",
"tokenizer = tiktoken.encoding_for_model(\"text-davinci-003\")\n",
"tokenizer = tiktoken.encoding_for_model(OPENAI_MODEL)\n",
"ids = [tokenizer.encode(token) for token in tokens]\n",
"ids[0], ids[1]"
]
@ -245,19 +250,19 @@
"\n",
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def document_relevance(query, document):\n",
" response = openai.Completion.create(\n",
" response = openai.chat.completions.create(\n",
" model=\"text-davinci-003\",\n",
" prompt=prompt.format(query=query, document=content),\n",
" message=prompt.format(query=query, document=document),\n",
" temperature=0,\n",
" logprobs=1,\n",
" logprobs=True,\n",
" logit_bias={3363: 1, 1400: 1},\n",
" )\n",
"\n",
" return (\n",
" query,\n",
" document,\n",
" response[\"choices\"][0][\"text\"],\n",
" response[\"choices\"][0][\"logprobs\"][\"token_logprobs\"][0],\n",
" response.choices[0].message.content,\n",
" response.choices[0].logprobs.token_logprobs[0],\n",
" )"
]
},
@ -271,8 +276,8 @@
"content = result_list[0][\"title\"] + \": \" + result_list[0][\"summary\"]\n",
"\n",
"# Set logprobs to 1 so our response will include the most probable token the model identified\n",
"response = openai.Completion.create(\n",
" model=\"text-davinci-003\",\n",
"response = openai.chat.completions.create(\n",
" model=OPENAI_MODEL,\n",
" prompt=prompt.format(query=query, document=content),\n",
" temperature=0,\n",
" logprobs=1,\n",
@ -317,9 +322,9 @@
}
],
"source": [
"result = response[\"choices\"][0]\n",
"print(f\"Result was {result['text']}\")\n",
"print(f\"Logprobs was {result['logprobs']['token_logprobs'][0]}\")\n",
"result = response.choices[0]\n",
"print(f\"Result was {result.message.content}\")\n",
"print(f\"Logprobs was {result.logprobs.token_logprobs[0]}\")\n",
"print(\"\\nBelow is the full logprobs object\\n\\n\")\n",
"print(result[\"logprobs\"])"
]

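With the chat API, `logprobs=True` attaches a log probability to every generated token; for a one-token Yes/No answer, the first entry is effectively the model's confidence in its verdict. A minimal sketch of that relevance check (model name and document are illustrative):

```python
import os

from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{
        "role": "user",
        "content": 'Answer strictly "Yes" or "No": is this document about quantum computing?\n\n'
                   "Document: A survey of qubit error correction techniques.",
    }],
    temperature=0,
    logprobs=True,
    max_tokens=1,  # force a single Yes/No token
)

choice = response.choices[0]
token_logprob = choice.logprobs.content[0]  # log probability of the single generated token
print(choice.message.content, token_logprob.logprob)
```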
@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -36,18 +36,18 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Good Buy: I liked the beans. They were vacuum sealed, plump and moist. Would recommend them for any use. I personally split and stuck them in some vodka to make vanilla extract. Yum!\n",
"Delicious!: I enjoy this white beans seasoning, it gives a rich flavor to the beans I just love it, my mother in law didn't know about this Zatarain's brand and now she is traying different seasoning\n",
"\n",
"Jamaican Blue beans: Excellent coffee bean for roasting. Our family just purchased another 5 pounds for more roasting. Plenty of flavor and mild on acidity when roasted to a dark brown bean and befor\n",
"Fantastic Instant Refried beans: Fantastic Instant Refried Beans have been a staple for my family now for nearly 20 years. All 7 of us love it and my grown kids are passing on the tradition.\n",
"\n",
"Delicious!: I enjoy this white beans seasoning, it gives a rich flavor to the beans I just love it, my mother in law didn't know about this Zatarain's brand and now she is traying different seasoning\n",
"Delicious: While there may be better coffee beans available, this is my first purchase and my first time grinding my own beans. I read several reviews before purchasing this brand, and am extremely \n",
"\n"
]
}
@ -59,7 +59,7 @@
"def search_reviews(df, product_description, n=3, pprint=True):\n",
" product_embedding = get_embedding(\n",
" product_description,\n",
" model=\"text-embedding-ada-002\"\n",
" model=\"text-embedding-3-small\"\n",
" )\n",
" df[\"similarity\"] = df.embedding.apply(lambda x: cosine_similarity(x, product_embedding))\n",
"\n",
@ -81,7 +81,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -92,7 +92,7 @@
"\n",
"sooo good: tastes so good. Worth the money. My boyfriend hates wheat pasta and LOVES this. cooks fast tastes great.I love this brand and started buying more of their pastas. Bulk is best.\n",
"\n",
"Handy: Love the idea of ready in a minute pasta and for that alone this product gets praise. The pasta is whole grain so that's a big plus and it actually comes out al dente. The vegetable marinara\n",
"Bland and vaguely gamy tasting, skip this one: As far as prepared dinner kits go, \"Barilla Whole Grain Mezze Penne with Tomato and Basil Sauce\" just did not do it for me...and this is coming from a p\n",
"\n"
]
}
@ -111,7 +111,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -137,14 +137,14 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extremely dissapointed: Hi,<br />I am very disappointed with the past shipment I received of the ONE coconut water. 3 of the boxes were leaking and the coconut water was spoiled.<br /><br />Thanks.<b\n",
"Disappointed: The metal cover has severely disformed. And most of the cookies inside have been crushed into small pieces. Shopping experience is awful. I'll never buy it online again.\n",
"\n"
]
}
@ -155,16 +155,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Good food: The only dry food my queen cat will eat. Helps prevent hair balls. Good packaging. Arrives promptly. Recommended by a friend who sells pet food.\n",
"Great food!: I wanted a food for a a dog with skin problems. His skin greatly improved with the switch, though he still itches some. He loves the food. No recalls, American made with American ingred\n",
"\n",
"The cats like it: My 7 cats like this food but it is a little yucky for the human. Pieces of mackerel swimming in a dark broth. It is billed as a \"complete\" food and contains carrots, peas and pasta.\n",
"Great food!: I wanted a food for a a dog with skin problems. His skin greatly improved with the switch, though he still itches some. He loves the food. No recalls, American made with American ingred\n",
"\n"
]
}
@ -190,7 +190,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {

@ -26,13 +26,16 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# imports needed to run the code in this notebook\n",
"import ast # used for detecting whether generated Python code is valid\n",
"import openai # used for calling the OpenAI API\n",
"import os\n",
"from openai import OpenAI\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"color_prefix_by_role = {\n",
" \"system\": \"\\033[0m\", # gray\n",
@ -96,19 +99,17 @@
" if print_text:\n",
" print_messages(explain_messages)\n",
"\n",
" explanation_response = openai.ChatCompletion.create(\n",
" model=explain_model,\n",
" messages=explain_messages,\n",
" temperature=temperature,\n",
" stream=True,\n",
" )\n",
" explanation_response = client.chat.completions.create(model=explain_model,\n",
" messages=explain_messages,\n",
" temperature=temperature,\n",
" stream=True)\n",
" explanation = \"\"\n",
" for chunk in explanation_response:\n",
" delta = chunk[\"choices\"][0][\"delta\"]\n",
" delta = chunk.choices[0].delta\n",
" if print_text:\n",
" print_message_delta(delta)\n",
" if \"content\" in delta:\n",
" explanation += delta[\"content\"]\n",
" explanation += delta.content\n",
" explain_assistant_message = {\"role\": \"assistant\", \"content\": explanation}\n",
"\n",
" # Step 2: Generate a plan to write a unit test\n",
@ -133,19 +134,17 @@
" ]\n",
" if print_text:\n",
" print_messages([plan_user_message])\n",
" plan_response = openai.ChatCompletion.create(\n",
" model=plan_model,\n",
" messages=plan_messages,\n",
" temperature=temperature,\n",
" stream=True,\n",
" )\n",
" plan_response = client.chat.completions.create(model=plan_model,\n",
" messages=plan_messages,\n",
" temperature=temperature,\n",
" stream=True)\n",
" plan = \"\"\n",
" for chunk in plan_response:\n",
" delta = chunk[\"choices\"][0][\"delta\"]\n",
" delta = chunk.choices[0].delta\n",
" if print_text:\n",
" print_message_delta(delta)\n",
" if \"content\" in delta:\n",
" plan += delta[\"content\"]\n",
" explanation += delta.content\n",
" plan_assistant_message = {\"role\": \"assistant\", \"content\": plan}\n",
"\n",
" # Step 2b: If the plan is short, ask GPT to elaborate further\n",
@ -167,19 +166,17 @@
" ]\n",
" if print_text:\n",
" print_messages([elaboration_user_message])\n",
" elaboration_response = openai.ChatCompletion.create(\n",
" model=plan_model,\n",
" messages=elaboration_messages,\n",
" temperature=temperature,\n",
" stream=True,\n",
" )\n",
" elaboration_response = client.chat.completions.create(model=plan_model,\n",
" messages=elaboration_messages,\n",
" temperature=temperature,\n",
" stream=True)\n",
" elaboration = \"\"\n",
" for chunk in elaboration_response:\n",
" delta = chunk[\"choices\"][0][\"delta\"]\n",
" if print_text:\n",
" print_message_delta(delta)\n",
" if \"content\" in delta:\n",
" elaboration += delta[\"content\"]\n",
" delta = chunk.choices[0].delta\n",
" if print_text:\n",
" print_message_delta(delta)\n",
" if \"content\" in delta:\n",
" explanation += delta.content\n",
" elaboration_assistant_message = {\"role\": \"assistant\", \"content\": elaboration}\n",
"\n",
" # Step 3: Generate the unit test\n",
@ -222,19 +219,17 @@
" if print_text:\n",
" print_messages([execute_system_message, execute_user_message])\n",
"\n",
" execute_response = openai.ChatCompletion.create(\n",
" model=execute_model,\n",
" execute_response = client.chat.completions.create(model=execute_model,\n",
" messages=execute_messages,\n",
" temperature=temperature,\n",
" stream=True,\n",
" )\n",
" stream=True)\n",
" execution = \"\"\n",
" for chunk in execute_response:\n",
" delta = chunk[\"choices\"][0][\"delta\"]\n",
" delta = chunk.choices[0].delta\n",
" if print_text:\n",
" print_message_delta(delta)\n",
" if \"content\" in delta:\n",
" execution += delta[\"content\"]\n",
" if delta.content:\n",
" execution += delta.content\n",
"\n",
" # check the output for errors\n",
" code = execution.split(\"```python\")[1].split(\"```\")[0].strip()\n",
@ -263,7 +258,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@ -297,18 +292,7 @@
" return ' '.join(translated_words)\n",
"\n",
"```\n",
"\u001b[92m\n",
"[assistant]\n",
"The `pig_latin` function takes a string of text and returns the text translated into pig latin. Here's how it works:\n",
"\n",
"* The function defines a nested function called `translate` that takes a single word as input and returns the word translated into pig latin.\n",
"* The `translate` function first defines a string of vowels.\n",
"* If the first letter of the input word is a vowel, the function adds \"way\" to the end of the word and returns the result.\n",
"* If the first letter of the input word is a consonant, the function loops through the word's letters until it finds a vowel.\n",
"* The function then takes the consonants at the beginning of the word and moves them to the end of the word, adding \"ay\" to the end of the word.\n",
"* The `pig_latin` function lowercases the input text and splits it into a list of words.\n",
"* The function then applies the `translate` function to each word in the list using a list comprehension.\n",
"* Finally, the function joins the translated words back together into a single string with spaces between each word and returns the result.\u001b[0m\n",
"\u001b[0m\n",
"[user]\n",
"A good unit test suite should aim to:\n",
"- Test the function's behavior for a wide range of possible inputs\n",
@ -318,50 +302,10 @@
"- Be deterministic, so that the tests always pass or fail in the same way\n",
"\n",
"To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets).\n",
"\u001b[92m\n",
"[assistant]\n",
"Here are some scenarios that the `pig_latin` function should be able to handle, along with examples:\n",
"\n",
"* Words that start with a vowel:\n",
" * \"apple\" -> \"appleway\"\n",
" * \"elephant\" -> \"elephantway\"\n",
"* Words that start with a single consonant:\n",
" * \"pig\" -> \"igpay\"\n",
" * \"latin\" -> \"atinlay\"\n",
"* Words that start with multiple consonants:\n",
" * \"string\" -> \"ingstray\"\n",
" * \"glove\" -> \"oveglay\"\n",
"* Words that contain numbers or special characters:\n",
" * \"hello!\" -> \"ellohay!\"\n",
" * \"world123\" -> \"orldway123\"\n",
"* Sentences with multiple words:\n",
" * \"hello world\" -> \"ellohay orldway\"\n",
" * \"the quick brown fox\" -> \"hetay ickquay ownbray oxfay\"\n",
"* Sentences with punctuation:\n",
" * \"Hello, world!\" -> \"Ellohay, orldway!\"\n",
" * \"The quick brown fox...\" -> \"Hetay ickquay ownbray oxfay...\" \n",
"* Empty strings:\n",
" * \"\" -> \"\"\u001b[0m\n",
"\u001b[0m\n",
"[user]\n",
"In addition to those scenarios above, list a few rare or unexpected edge cases (and as before, under each edge case, include a few examples as sub-bullets).\n",
"\u001b[92m\n",
"[assistant]\n",
"Here are some rare or unexpected edge cases that the `pig_latin` function should be able to handle, along with examples:\n",
"\n",
"* Words that consist entirely of consonants:\n",
" * \"xyz\" -> \"xyzay\"\n",
" * \"rhythm\" -> \"ythmrhay\"\n",
"* Words that consist entirely of vowels:\n",
" * \"aeiou\" -> \"aeiouway\"\n",
" * \"ouiea\" -> \"ouieaway\"\n",
"* Words that contain mixed case:\n",
" * \"PyThOn\" -> \"ythonpay\"\n",
" * \"eLePhAnT\" -> \"elephantway\"\n",
"* Strings that contain only whitespace:\n",
" * \" \" -> \"\"\n",
"* Non-string inputs:\n",
" * None -> TypeError\n",
" * 42 -> AttributeError\u001b[0m\n",
"\u001b[0m\n",
"[system]\n",
"You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You write careful, accurate unit tests. When asked to reply only with code, you write all of your code in a single block.\n",
"\u001b[0m\n",
@ -397,74 +341,7 @@
"# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator\n",
"{insert unit test code here}\n",
"```\n",
"\u001b[92m\n",
"[assistant]\n",
"```python\n",
"# imports\n",
"import pytest\n",
"\n",
"# function to test\n",
"def pig_latin(text):\n",
" def translate(word):\n",
" vowels = 'aeiou'\n",
" if word[0] in vowels:\n",
" return word + 'way'\n",
" else:\n",
" consonants = ''\n",
" for letter in word:\n",
" if letter not in vowels:\n",
" consonants += letter\n",
" else:\n",
" break\n",
" return word[len(consonants):] + consonants + 'ay'\n",
"\n",
" words = text.lower().split()\n",
" translated_words = [translate(word) for word in words]\n",
" return ' '.join(translated_words)\n",
"\n",
"\n",
"# unit tests\n",
"# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator\n",
"\n",
"# Tests for normal cases\n",
"@pytest.mark.parametrize('text, expected', [\n",
" ('apple', 'appleway'),\n",
" ('elephant', 'elephantway'),\n",
" ('pig', 'igpay'),\n",
" ('latin', 'atinlay'),\n",
" ('string', 'ingstray'),\n",
" ('glove', 'oveglay'),\n",
" ('hello world', 'ellohay orldway'),\n",
" ('the quick brown fox', 'hetay ickquay ownbray oxfay'),\n",
" ('Hello, world!', 'Ellohay, orldway!'),\n",
" ('The quick brown fox...', 'Hetay ickquay ownbray oxfay...'),\n",
" ('', ''),\n",
"])\n",
"\n",
"def test_pig_latin_normal_cases(text, expected):\n",
" assert pig_latin(text) == expected\n",
"\n",
"\n",
"# Tests for edge cases\n",
"@pytest.mark.parametrize('text, expected', [\n",
" ('xyz', 'xyzay'),\n",
" ('rhythm', 'ythmrhay'),\n",
" ('aeiou', 'aeiouway'),\n",
" ('ouiea', 'ouieaway'),\n",
" ('PyThOn', 'ythonpay'),\n",
" ('eLePhAnT', 'elephantway'),\n",
" (' ', ''),\n",
" (None, TypeError),\n",
" (42, AttributeError)\n",
"])\n",
"\n",
"def test_pig_latin_edge_cases(text, expected):\n",
" if type(expected) == type:\n",
" with pytest.raises(expected):\n",
" pig_latin(text)\n",
" else:\n",
" assert pig_latin(text) == expected\n",
"```"
"execute messages: [{'role': 'system', 'content': 'You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You write careful, accurate unit tests. When asked to reply only with code, you write all of your code in a single block.'}, {'role': 'user', 'content': \"Please explain the following Python function. Review what each element of the function is doing precisely and what the author's intentions may have been. Organize your explanation as a markdown-formatted, bulleted list.\\n\\n```python\\ndef pig_latin(text):\\n def translate(word):\\n vowels = 'aeiou'\\n if word[0] in vowels:\\n return word + 'way'\\n else:\\n consonants = ''\\n for letter in word:\\n if letter not in vowels:\\n consonants += letter\\n else:\\n break\\n return word[len(consonants):] + consonants + 'ay'\\n\\n words = text.lower().split()\\n translated_words = [translate(word) for word in words]\\n return ' '.join(translated_words)\\n\\n```\"}, {'role': 'assistant', 'content': ''}, {'role': 'user', 'content': \"A good unit test suite should aim to:\\n- Test the function's behavior for a wide range of possible inputs\\n- Test edge cases that the author may not have foreseen\\n- Take advantage of the features of `pytest` to make the tests easy to write and maintain\\n- Be easy to read and understand, with clean code and descriptive names\\n- Be deterministic, so that the tests always pass or fail in the same way\\n\\nTo help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets).\"}, {'role': 'assistant', 'content': ''}, {'role': 'user', 'content': 'In addition to those scenarios above, list a few rare or unexpected edge cases (and as before, under each edge case, include a few examples as sub-bullets).'}, {'role': 'assistant', 'content': ''}, {'role': 'user', 'content': \"Using Python and the `pytest` package, write a suite of unit tests for the function, following the cases above. Include helpful comments to explain each line. Reply only with code, formatted as follows:\\n\\n```python\\n# imports\\nimport pytest # used for our unit tests\\n{insert other imports as needed}\\n\\n# function to test\\ndef pig_latin(text):\\n def translate(word):\\n vowels = 'aeiou'\\n if word[0] in vowels:\\n return word + 'way'\\n else:\\n consonants = ''\\n for letter in word:\\n if letter not in vowels:\\n consonants += letter\\n else:\\n break\\n return word[len(consonants):] + consonants + 'ay'\\n\\n words = text.lower().split()\\n translated_words = [translate(word) for word in words]\\n return ' '.join(translated_words)\\n\\n\\n# unit tests\\n# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator\\n{insert unit test code here}\\n```\"}]\n"
]
}
],
@ -497,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@ -528,46 +405,19 @@
"\n",
"\n",
"# unit tests\n",
"# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator\n",
"\n",
"# Tests for normal cases\n",
"@pytest.mark.parametrize('text, expected', [\n",
" ('apple', 'appleway'),\n",
" ('elephant', 'elephantway'),\n",
" ('pig', 'igpay'),\n",
" ('latin', 'atinlay'),\n",
" ('string', 'ingstray'),\n",
" ('glove', 'oveglay'),\n",
" ('hello world', 'ellohay orldway'),\n",
" ('the quick brown fox', 'hetay ickquay ownbray oxfay'),\n",
" ('Hello, world!', 'Ellohay, orldway!'),\n",
" ('The quick brown fox...', 'Hetay ickquay ownbray oxfay...'),\n",
" ('', ''),\n",
" ('hello world', 'ellohay orldway'), # basic test case\n",
" ('Python is awesome', 'ythonPay isway awesomeway'), # test case with multiple words\n",
" ('apple', 'appleway'), # test case with a word starting with a vowel\n",
" ('', ''), # test case with an empty string\n",
" ('123', '123'), # test case with non-alphabetic characters\n",
" ('Hello World!', 'elloHay orldWay!'), # test case with punctuation\n",
" ('The quick brown fox', 'ethay ickquay ownbray oxfay'), # test case with mixed case words\n",
" ('a e i o u', 'away eway iway oway uway'), # test case with all vowels\n",
" ('bcd fgh jkl mnp', 'bcday fghay jklway mnpay'), # test case with all consonants\n",
"])\n",
"\n",
"def test_pig_latin_normal_cases(text, expected):\n",
" assert pig_latin(text) == expected\n",
"\n",
"\n",
"# Tests for edge cases\n",
"@pytest.mark.parametrize('text, expected', [\n",
" ('xyz', 'xyzay'),\n",
" ('rhythm', 'ythmrhay'),\n",
" ('aeiou', 'aeiouway'),\n",
" ('ouiea', 'ouieaway'),\n",
" ('PyThOn', 'ythonpay'),\n",
" ('eLePhAnT', 'elephantway'),\n",
" (' ', ''),\n",
" (None, TypeError),\n",
" (42, AttributeError)\n",
"])\n",
"\n",
"def test_pig_latin_edge_cases(text, expected):\n",
" if type(expected) == type:\n",
" with pytest.raises(expected):\n",
" pig_latin(text)\n",
" else:\n",
" assert pig_latin(text) == expected\n"
"def test_pig_latin(text, expected):\n",
" assert pig_latin(text) == expected\n"
]
}
],
@ -600,7 +450,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.11.3"
},
"orig_nbformat": 4,
"vscode": {

@ -18,7 +18,7 @@
"The code example illustrates a few optional embellishments on the chained, multi-step prompt:\n",
"\n",
"- Conditional branching (e.g., only asking for elaboration if the first plan is too short)\n",
"- Different models for different steps (e.g., `text-davinci-002` for the text planning steps and `code-davinci-002` for the code writing step)\n",
"- Different models for different steps (e.g., `gpt-3.5-turbo-instruct` for the text planning steps and `gpt-4` for the code writing step)\n",
"- A check that re-runs the function if the output is unsatisfactory (e.g., if the output code cannot be parsed by Python's `ast` module)\n",
"- Streaming output so that you can start reading the output before it's fully generated (useful for long, multi-step outputs)\n",
"\n",
@ -67,9 +67,8 @@
"metadata": {},
"outputs": [],
"source": [
"# imports needed to run the code in this notebook\n",
"import ast # used for detecting whether generated Python code is valid\n",
"import openai # used for calling the OpenAI API\n",
"import openai\n",
"\n",
"# example of a function that uses a multi-step prompt to write unit tests\n",
"def unit_test_from_function(\n",
@ -77,8 +76,8 @@
" unit_test_package: str = \"pytest\", # unit testing package; use the name as it appears in the import statement\n",
" approx_min_cases_to_cover: int = 7, # minimum number of test case categories to cover (approximate)\n",
" print_text: bool = False, # optionally prints text; helpful for understanding the function & debugging\n",
" text_model: str = \"text-davinci-002\", # model used to generate text plans in steps 1, 2, and 2b\n",
" code_model: str = \"code-davinci-002\", # if you don't have access to code models, you can use text models here instead\n",
" text_model: str = \"gpt-3.5-turbo-instruct\", # model used to generate text plans in steps 1, 2, and 2b\n",
" code_model: str = \"gpt-3.5-turbo-instruct\", # if you don't have access to code models, you can use text models here instead\n",
" max_tokens: int = 1000, # can set this high, as generations should be stopped earlier by stop sequences\n",
" temperature: float = 0.4, # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4\n",
" reruns_if_fail: int = 1, # if the output code cannot be parsed, this will re-run the function up to N times\n",

File diff suppressed because one or more lines are too long

@ -7,7 +7,7 @@
"source": [
"## Using embeddings\n",
"\n",
"This notebook contains some helpful snippets you can use to embed text with the 'text-embedding-ada-002' model via the OpenAI API."
"This notebook contains some helpful snippets you can use to embed text with the `text-embedding-3-small` model via the OpenAI API."
]
},
{
@ -30,7 +30,7 @@
"import openai\n",
"\n",
"embedding = openai.Embedding.create(\n",
" input=\"Your text goes here\", model=\"text-embedding-ada-002\"\n",
" input=\"Your text goes here\", model=\"text-embedding-3-small\"\n",
")[\"data\"][0][\"embedding\"]\n",
"len(embedding)\n"
]
@ -55,7 +55,7 @@
"num_embeddings = 10000 # Some large number\n",
"for i in range(num_embeddings):\n",
" embedding = openai.Embedding.create(\n",
" input=\"Your text goes here\", model=\"text-embedding-ada-002\"\n",
" input=\"Your text goes here\", model=\"text-embedding-3-small\"\n",
" )[\"data\"][0][\"embedding\"]\n",
" print(len(embedding))"
]
@ -80,10 +80,10 @@
"\n",
"# Retry up to 6 times with exponential backoff, starting at 1 second and maxing out at 20 seconds delay\n",
"@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))\n",
"def get_embedding(text: str, model=\"text-embedding-ada-002\") -> list[float]:\n",
"def get_embedding(text: str, model=\"text-embedding-3-small\") -> list[float]:\n",
" return openai.Embedding.create(input=[text], model=model)[\"data\"][0][\"embedding\"]\n",
"\n",
"embedding = get_embedding(\"Your text goes here\", model=\"text-embedding-ada-002\")\n",
"embedding = get_embedding(\"Your text goes here\", model=\"text-embedding-3-small\")\n",
"print(len(embedding))"
]
}
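Note that these snippets pair the updated `text-embedding-3-small` model name with the pre-1.0 `openai.Embedding.create` call. With `openai>=1.0` the same request goes through a client object; a minimal sketch, assuming `OPENAI_API_KEY` is set in the environment:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment by default

# v1-style equivalent of the openai.Embedding.create snippets above
response = client.embeddings.create(
    input="Your text goes here", model="text-embedding-3-small"
)
embedding = response.data[0].embedding
print(len(embedding))
```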

@ -50,9 +50,9 @@
"from math import exp\n",
"import numpy as np\n",
"from IPython.display import display, HTML\n",
"import os\n",
"\n",
"\n",
"client = OpenAI()"
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{

File diff suppressed because one or more lines are too long

@ -138,7 +138,7 @@
"source": [
"from utils.embeddings_utils import get_embeddings\n",
"# NOTE: The following code will send a query of batch size 200 to /embeddings\n",
"matrix = get_embeddings(samples[\"text\"].to_list(), model=\"text-embedding-ada-002\")\n"
"matrix = get_embeddings(samples[\"text\"].to_list(), model=\"text-embedding-3-small\")\n"
]
},
{

@ -32,8 +32,11 @@
"outputs": [],
"source": [
"# imports\n",
"import openai # for making OpenAI API calls\n",
"import urllib # for downloading example audio files\n"
"from openai import OpenAI # for making OpenAI API calls\n",
"import urllib # for downloading example audio files\n",
"import os # for accessing environment variables\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -90,7 +93,7 @@
"# define a wrapper function for seeing how prompts affect transcriptions\n",
"def transcribe(prompt: str, audio_filepath) -> str:\n",
" \"\"\"Given a prompt, transcribe the audio file.\"\"\"\n",
" transcript = openai.audio.transcriptions.create(\n",
" transcript = client.audio.transcriptions.create(\n",
" file=open(audio_filepath, \"rb\"),\n",
" model=\"whisper-1\",\n",
" prompt=prompt,\n",
@ -206,7 +209,7 @@
"source": [
"# define a wrapper function for seeing how prompts affect transcriptions\n",
"def transcribe_with_spellcheck(system_message, audio_filepath):\n",
" completion = openai.chat.completions.create(\n",
" completion = client.chat.completions.create(\n",
" model=\"gpt-4\",\n",
" temperature=0,\n",
" messages=[\n",

@ -48,8 +48,7 @@
"metadata": {},
"outputs": [],
"source": [
"client = OpenAI()\n",
"client.api_key = os.getenv(\"OPENAI_API_KEY\")"
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -492,7 +491,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.3"
},
"orig_nbformat": 4
},

@ -48,8 +48,11 @@
"outputs": [],
"source": [
"# imports\n",
"import openai # for making OpenAI API calls\n",
"import urllib # for downloading example audio files"
"from openai import OpenAI # for making OpenAI API calls\n",
"import urllib # for downloading example audio files\n",
"import os\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))"
]
},
{
@ -106,7 +109,7 @@
"# define a wrapper function for seeing how prompts affect transcriptions\n",
"def transcribe(audio_filepath, prompt: str) -> str:\n",
" \"\"\"Given a prompt, transcribe the audio file.\"\"\"\n",
" transcript = openai.audio.transcriptions.create(\n",
" transcript = client.audio.transcriptions.create(\n",
" file=open(audio_filepath, \"rb\"),\n",
" model=\"whisper-1\",\n",
" prompt=prompt,\n",
@ -457,7 +460,7 @@
"# define a function for GPT to generate fictitious prompts\n",
"def fictitious_prompt_from_instruction(instruction: str) -> str:\n",
" \"\"\"Given an instruction, generate a fictitious prompt.\"\"\"\n",
" response = openai.chat.completions.create(\n",
" response = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n",
" temperature=0,\n",
" messages=[\n",

File diff suppressed because one or more lines are too long

@ -32,14 +32,14 @@ Inputs:
- requests_filepath : str
- path to the file containing the requests to be processed
- file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field
- e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}}
- e.g., {"model": "text-embedding-3-small", "input": "embed me", "metadata": {"row_id": 1}}
- as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically)
- an example file is provided at examples/data/example_requests_to_parallel_process.jsonl
- the code to generate the example file is appended to the bottom of this script
- save_filepath : str, optional
- path to the file where the results will be saved
- file will be a jsonl file, where each line is an array with the original request plus the API response
- e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}]
- e.g., [{"model": "text-embedding-3-small", "input": "embed me"}, {...}]
- if omitted, results will be saved to {requests_filename}_results.jsonl
- request_url : str, optional
- URL of the API endpoint to call
@ -58,7 +58,7 @@ Inputs:
- if omitted, will default to 125,000
- token_encoding_name : str, optional
- name of the token encoding used, as defined in the `tiktoken` package
- if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`)
- if omitted, will default to "cl100k_base" (used by `text-embedding-3-small`)
- max_attempts : int, optional
- number of times to retry a failed request before giving up
- if omitted, will default to 5
@ -133,7 +133,7 @@ async def process_api_requests_from_file(
api_endpoint = api_endpoint_from_url(request_url)
request_header = {"Authorization": f"Bearer {api_key}"}
# use api-key header for Azure deployments
if '/deployments' in request_url:
if "/deployments" in request_url:
request_header = {"api-key": f"{api_key}"}
# initialize trackers
@ -371,7 +371,9 @@ def api_endpoint_from_url(request_url):
match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url)
if match is None:
# for Azure OpenAI deployment urls
match = re.search(r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url)
match = re.search(
r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url
)
return match[1]
@ -488,7 +490,7 @@ if __name__ == "__main__":
"""
APPENDIX
The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002.
The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-3-small.
It was generated with the following code:
@ -497,7 +499,7 @@ import json
filename = "data/example_requests_to_parallel_process.jsonl"
n_requests = 10_000
jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)]
jobs = [{"model": "text-embedding-3-small", "input": str(x) + "\n"} for x in range(n_requests)]
with open(filename, "w") as f:
for job in jobs:
json_string = json.dumps(job)

@ -174,7 +174,7 @@
"metadata": {},
"source": [
"## Deployments\n",
"In this section we are going to create a deployment using the `text-davinci-002` model that we can then use to create completions."
"In this section we are going to create a deployment using the `gpt-3.5-turbo-instruct` model that we can then use to create completions."
]
},
{
@ -183,7 +183,7 @@
"metadata": {},
"source": [
"### Deployments: Create manually\n",
"Create a new deployment by going to your Resource in your portal under \"Resource Management\" -> \"Model deployments\". Select `text-davinci-002` as the model."
"Create a new deployment by going to your Resource in your portal under \"Resource Management\" -> \"Model deployments\". Select `gpt-3.5-turbo-instruct` as the model."
]
},
{

@ -9,7 +9,7 @@
"\n",
"Azure OpenAI on your data enables you to run supported chat models such as GPT-3.5-Turbo and GPT-4 on your data without needing to train or fine-tune models. Running models on your data enables you to chat on top of, and analyze your data with greater accuracy and speed. One of the key benefits of Azure OpenAI on your data is its ability to tailor the content of conversational AI. Because the model has access to, and can reference specific sources to support its responses, answers are not only based on its pretrained knowledge but also on the latest information available in the designated data source. This grounding data also helps the model avoid generating responses based on outdated or incorrect information.\n",
"\n",
"Azure OpenAI on your own data with Azure Cognitive Search provides a customizable, pre-built solution for knowledge retrieval, from which a conversational AI application can be built. To see alternative methods for knowledge retrieval and semantic search, check out the cookbook examples for [vector databases](https://github.com/openai/openai-cookbook/tree/main/examples/vector_databases)."
"Azure OpenAI on your own data with Azure AI Search (f.k.a. Azure Cognitive Search) provides a customizable, pre-built solution for knowledge retrieval, from which a conversational AI application can be built. To see alternative methods for knowledge retrieval and semantic search, check out the cookbook examples for [vector databases](https://github.com/openai/openai-cookbook/tree/main/examples/vector_databases)."
]
},
{
@ -18,7 +18,7 @@
"source": [
"## How it works\n",
"\n",
"[Azure OpenAI on your own data](https://learn.microsoft.com/azure/ai-services/openai/concepts/use-your-data) connects the model with your data, giving it the ability to retrieve and utilize data in a way that enhances the model's output. Together with Azure Cognitive Search, data is retrieved from designated data sources based on the user input and provided conversation history. The data is then augmented and resubmitted as a prompt to the model, giving the model contextual information it can use to generate a response.\n",
"[Azure OpenAI on your own data](https://learn.microsoft.com/azure/ai-services/openai/concepts/use-your-data) connects the model with your data, giving it the ability to retrieve and utilize data in a way that enhances the model's output. Together with Azure AI Search, data is retrieved from designated data sources based on the user input and provided conversation history. The data is then augmented and resubmitted as a prompt to the model, giving the model contextual information it can use to generate a response.\n",
"\n",
"See the [Data, privacy, and security for Azure OpenAI Service](https://learn.microsoft.com/legal/cognitive-services/openai/data-privacy?context=%2Fazure%2Fai-services%2Fopenai%2Fcontext%2Fcontext) for more information."
]
@ -35,7 +35,7 @@
"To use your own data with Azure OpenAI models, you will need:\n",
"\n",
"1. Azure OpenAI access and a resource with a chat model deployed (for example, GPT-3 or GPT-4)\n",
"2. Azure Cognitive Search resource\n",
"2. Azure AI Search (f.k.a. Azure Cognitive Search) resource\n",
"3. Azure Blob Storage resource\n",
"4. Your documents to be used as data (See [data source options](https://learn.microsoft.com/azure/ai-services/openai/concepts/use-your-data#data-source-options))\n",
"\n",
@ -70,8 +70,8 @@
"\n",
"* `AZURE_OPENAI_ENDPOINT` - the Azure OpenAI endpoint. This can be found under \"Keys and Endpoints\" for your Azure OpenAI resource in the Azure Portal.\n",
"* `AZURE_OPENAI_API_KEY` - the Azure OpenAI API key. This can be found under \"Keys and Endpoints\" for your Azure OpenAI resource in the Azure Portal. Omit if using Azure Active Directory authentication (see below `Authentication using Microsoft Active Directory`)\n",
"* `SEARCH_ENDPOINT` - the Cognitive Search endpoint. This URL be found on the \"Overview\" of your Search resource on the Azure Portal.\n",
"* `SEARCH_KEY` - the Cognitive Search API key. Found under \"Keys\" for your Search resource in the Azure Portal.\n",
"* `SEARCH_ENDPOINT` - the AI Search endpoint. This URL be found on the \"Overview\" of your Search resource on the Azure Portal.\n",
"* `SEARCH_KEY` - the AI Search API key. Found under \"Keys\" for your Search resource in the Azure Portal.\n",
"* `SEARCH_INDEX_NAME` - the name of the index you created with your own data."
]
},
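A minimal sketch of reading the environment variables listed above before building the clients; the values are assumed to have been exported in the shell:

```python
import os

# Values are assumed to be exported in the shell before running the notebook
endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
api_key = os.environ.get("AZURE_OPENAI_API_KEY")  # omit when using AAD auth
search_endpoint = os.environ["SEARCH_ENDPOINT"]
search_key = os.environ["SEARCH_KEY"]
search_index = os.environ["SEARCH_INDEX_NAME"]
```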

@ -1,320 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Azure completions example\n",
"\n",
"This example will cover completions using the Azure OpenAI service. It also includes information on content filtering."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, we install the necessary dependencies and import the libraries we will be using."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install \"openai>=1.0.0,<2.0.0\"\n",
"! pip install python-dotenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import openai\n",
"import dotenv\n",
"\n",
"dotenv.load_dotenv()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Authentication\n",
"\n",
"The Azure OpenAI service supports multiple authentication mechanisms that include API keys and Azure Active Directory token credentials."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"use_azure_active_directory = False # Set this flag to True if you are using Azure Active Directory"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Authentication using API key\n",
"\n",
"To set up the OpenAI SDK to use an *Azure API Key*, we need to set `api_key` to a key associated with your endpoint (you can find this key in *\"Keys and Endpoints\"* under *\"Resource Management\"* in the [Azure Portal](https://portal.azure.com)). You'll also find the endpoint for your resource here."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"if not use_azure_active_directory:\n",
" endpoint = os.environ[\"AZURE_OPENAI_ENDPOINT\"]\n",
" api_key = os.environ[\"AZURE_OPENAI_API_KEY\"]\n",
"\n",
" client = openai.AzureOpenAI(\n",
" azure_endpoint=endpoint,\n",
" api_key=api_key,\n",
" api_version=\"2023-09-01-preview\"\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Authentication using Azure Active Directory\n",
"Let's now see how we can autheticate via Azure Active Directory. We'll start by installing the `azure-identity` library. This library will provide the token credentials we need to authenticate and help us build a token credential provider through the `get_bearer_token_provider` helper function. It's recommended to use `get_bearer_token_provider` over providing a static token to `AzureOpenAI` because this API will automatically cache and refresh tokens for you. \n",
"\n",
"For more information on how to set up Azure Active Directory authentication with Azure OpenAI, see the [documentation](https://learn.microsoft.com/azure/ai-services/openai/how-to/managed-identity)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install \"azure-identity>=1.15.0\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
"\n",
"if use_azure_active_directory:\n",
" endpoint = os.environ[\"AZURE_OPENAI_ENDPOINT\"]\n",
" api_key = os.environ[\"AZURE_OPENAI_API_KEY\"]\n",
"\n",
" client = openai.AzureOpenAI(\n",
" azure_endpoint=endpoint,\n",
" azure_ad_token_provider=get_bearer_token_provider(DefaultAzureCredential(), \"https://cognitiveservices.azure.com/.default\"),\n",
" api_version=\"2023-09-01-preview\"\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Note: the AzureOpenAI infers the following arguments from their corresponding environment variables if they are not provided:\n",
"\n",
"- `api_key` from `AZURE_OPENAI_API_KEY`\n",
"- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`\n",
"- `api_version` from `OPENAI_API_VERSION`\n",
"- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deployments\n",
"\n",
"In this section we are going to create a deployment of a model that we can use to create completions."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Deployments: Create in the Azure OpenAI Studio\n",
"Let's deploy a model to use with completions. Go to https://portal.azure.com, find your Azure OpenAI resource, and then navigate to the Azure OpenAI Studio. Click on the \"Deployments\" tab and then create a deployment for the model you want to use for completions. The deployment name that you give the model will be used in the code below."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"deployment = \"\" # Fill in the deployment name from the portal here"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Completions\n",
"\n",
"Now let's create a completion using the client we built."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"The food was delicious and the waiter\"\n",
"completion = client.completions.create(\n",
" model=deployment,\n",
" prompt=prompt,\n",
" stop=\".\",\n",
" temperature=0\n",
")\n",
" \n",
"print(f\"{prompt}{completion.choices[0].text}.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a streaming completion\n",
"\n",
"We can also stream the response."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"The food was delicious and the waiter\"\n",
"response = client.completions.create(\n",
" model=deployment,\n",
" prompt=prompt,\n",
" stream=True,\n",
")\n",
"for completion in response:\n",
" if len(completion.choices) > 0:\n",
" print(f\"{completion.choices[0].text}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Content filtering\n",
"\n",
"Azure OpenAI service includes content filtering of prompts and completion responses. You can learn more about content filtering and how to configure it [here](https://learn.microsoft.com/azure/ai-services/openai/concepts/content-filter).\n",
"\n",
"If the prompt is flagged by the content filter, the library will raise a `BadRequestError` exception with a `content_filter` error code. Otherwise, you can access the `prompt_filter_results` and `content_filter_results` on the response to see the results of the content filtering and what categories were flagged."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Prompt flagged by content filter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"try:\n",
" completion = client.completions.create(\n",
" prompt=\"<text violating the content policy>\",\n",
" model=deployment,\n",
" )\n",
"except openai.BadRequestError as e:\n",
" err = json.loads(e.response.text)\n",
" if err[\"error\"][\"code\"] == \"content_filter\":\n",
" print(\"Content filter triggered!\")\n",
" content_filter_result = err[\"error\"][\"innererror\"][\"content_filter_result\"]\n",
" for category, details in content_filter_result.items():\n",
" print(f\"{category}:\\n filtered={details['filtered']}\\n severity={details['severity']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Checking the result of the content filter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"completion = client.completions.create(\n",
" prompt=\"What's the biggest city in Washington?\",\n",
" model=deployment,\n",
")\n",
"\n",
"print(f\"Answer: {completion.choices[0].text}\")\n",
"\n",
"# prompt content filter result in \"model_extra\" for azure\n",
"prompt_filter_result = completion.model_extra[\"prompt_filter_results\"][0][\"content_filter_results\"]\n",
"print(\"\\nPrompt content filter results:\")\n",
"for category, details in prompt_filter_result.items():\n",
" print(f\"{category}:\\n filtered={details['filtered']}\\n severity={details['severity']}\")\n",
"\n",
"# completion content filter result\n",
"print(\"\\nCompletion content filter results:\")\n",
"completion_filter_result = completion.choices[0].model_extra[\"content_filter_results\"]\n",
"for category, details in completion_filter_result.items():\n",
" print(f\"{category}:\\n filtered={details['filtered']}\\n severity={details['severity']}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"vscode": {
"interpreter": {
"hash": "3a5103089ab7e7c666b279eeded403fcec76de49a40685dbdfe9f9c78ad97c17"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -228,7 +228,7 @@
"source": [
"#### 2. Pass function definition(s) into chat completions API\n",
"\n",
"Now we can pass the function into the chat completions API. If the model determines it should call the function, a `finish_reason` of \"function_call\" will be populated on the choice and the details of which function to call and its arguments will be present in the `message`. Optionally, you can set the `function_call` keyword argument to force the model to call a particular function (e.g. `function_call={\"name\": get_current_weather}`). By default, this is set to `auto`, allowing the model to choose whether to call the function or not. "
"Now we can pass the function into the chat completions API. If the model determines it should call the function, a `finish_reason` of \"tool_calls\" will be populated on the choice and the details of which function to call and its arguments will be present in the `message`. Optionally, you can set the `tool_choice` keyword argument to force the model to call a particular function (e.g. `{\"type\": \"function\", \"function\": {\"name\": get_current_weather}}`). By default, this is set to `auto`, allowing the model to choose whether to call the function or not. "
]
},
{
@ -245,7 +245,7 @@
"chat_completion = client.chat.completions.create(\n",
" model=deployment,\n",
" messages=messages,\n",
" functions=functions,\n",
" tools=functions,\n",
")\n",
"print(chat_completion)"
]
@ -278,7 +278,7 @@
" unit = request.get(\"unit\")\n",
" return {\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}\n",
"\n",
"function_call = chat_completion.choices[0].message.function_call\n",
"function_call = chat_completion.choices[0].message.tool_calls[0].function\n",
"print(function_call.name)\n",
"print(function_call.arguments)\n",
"\n",
@ -313,7 +313,7 @@
"function_completion = client.chat.completions.create(\n",
" model=deployment,\n",
" messages=messages,\n",
" functions=functions,\n",
" tools=functions,\n",
")\n",
"\n",
"print(function_completion.choices[0].message.content.strip())"

@ -35,9 +35,12 @@
}
],
"source": [
"import openai\n",
"from openai import OpenAI\n",
"import os\n",
"from transformers import GPT2Tokenizer\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if you didn't set as an env var>\"))\n",
"\n",
"# OpenAI GPT-2 tokenizer is the same as GPT-3 tokenizer\n",
"# we use it to count the number of tokens in the text\n",
"tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
@ -212,7 +215,7 @@
}
],
"source": [
"def translate_chunk(chunk, engine='text-davinci-002',\n",
"def translate_chunk(chunk, model='gpt-3.5-turbo',\n",
" dest_language='English',\n",
" sample_translation=(\"\\poglavje{Osnove Geometrije} \\label{osn9Geom}\", \"\\poglavje{The basics of Geometry} \\label{osn9Geom}\")\n",
" ):\n",
@ -224,17 +227,17 @@
"\n",
"{sample_translation[1]}\n",
"'''\n",
" response = openai.Completion.create(\n",
" prompt=prompt,\n",
" engine=engine,\n",
" response = client.chat.completions.create(\n",
" messages=[{\"role\": \"user\", \"content\":prompt}],\n",
" model=model,\n",
" temperature=0,\n",
" top_p=1,\n",
" max_tokens=1500,\n",
" )\n",
" result = response['choices'][0]['text'].strip()\n",
" result = response.choices[0].message.content.strip()\n",
" result = result.replace('\"\"\"', '') # remove the double quotes, as we used them to surround the text\n",
" return result\n",
"print(translate_chunk(chunks[800], engine='text-davinci-002', dest_language='English'))"
"print(translate_chunk(chunks[800], model='gpt-3.5-turbo', dest_language='English'))"
]
},
{
@ -1134,7 +1137,7 @@
"for i, chunk in enumerate(chunks):\n",
" print(str(i+1) + \" / \" + str(len(chunks)))\n",
" # translate each chunk\n",
" translated_chunks.append(translate_chunk(chunk, engine='text-davinci-002', dest_language=dest_language))\n",
" translated_chunks.append(translate_chunk(chunk, model='gpt-3.5-turbo', dest_language=dest_language))\n",
"\n",
"# join the chunks together\n",
"result = '\\n\\n'.join(translated_chunks)\n",

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -1,4 +1,5 @@
"""
TODO: This example is deprecated.
Note: To answer questions based on text documents, we recommend the procedure in
[Question Answering using Embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).
Some of the code below may rely on [deprecated API endpoints](https://github.com/openai/openai-cookbook/tree/main/transition_guides_for_deprecated_API_endpoints).
@ -6,7 +7,10 @@ Some of the code below may rely on [deprecated API endpoints](https://github.com
import argparse
import openai
from openai import OpenAI
import os
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))
def create_context(
@ -21,7 +25,8 @@ def create_context(
:param max_rerank: The maximum number of reranking
:return: The context
"""
results = openai.Engine(search_model).search(
# TODO: openai.Engine(search_model) is deprecated
results = client.Engine(search_model).search(
search_model=search_model,
query=question,
max_rerank=max_rerank,
@ -80,16 +85,14 @@ def answer_question(
and fine_tuned_qa_model.split(":")[1].startswith("ft")
else {"engine": fine_tuned_qa_model}
)
response = openai.Completion.create(
prompt=f"Answer the question based on the context below\n\nText: {context}\n\n---\n\nQuestion: {question}\nAnswer:",
temperature=0,
max_tokens=max_tokens,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
stop=stop_sequence,
**model_param,
)
        response = client.completions.create(
            prompt=f"Answer the question based on the context below\n\nText: {context}\n\n---\n\nQuestion: {question}\nAnswer:",
            temperature=0,
            max_tokens=max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=stop_sequence,
            **model_param,
        )
        return response.choices[0].text
except Exception as e:
print(e)

@ -92,7 +92,7 @@
"import time\n",
"\n",
"import pandas as pd\n",
"import openai\n",
"from openai import OpenAI\n",
"import tiktoken\n",
"import seaborn as sns\n",
"from tenacity import retry, wait_exponential\n",
@ -106,7 +106,9 @@
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"tqdm.pandas()"
"tqdm.pandas()\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n"
]
},
{
@ -123,7 +125,6 @@
"metadata": {},
"outputs": [],
"source": [
"openai.api_key = \"sk-xxx\"\n",
"os.environ[\"QDRANT_URL\"] = \"https://xxx.cloud.qdrant.io:6333\"\n",
"os.environ[\"QDRANT_API_KEY\"] = \"xxx\""
]
@ -265,7 +266,7 @@
"# Function with tenacity for retries\n",
"@retry(wait=wait_exponential(multiplier=1, min=2, max=6))\n",
"def api_call(messages, model):\n",
" return openai.ChatCompletion.create(\n",
" return client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" stop=[\"\\n\\n\"],\n",
@ -275,10 +276,10 @@
"\n",
"\n",
"# Main function to answer question\n",
"def answer_question(row, prompt_func=get_prompt, model=\"gpt-3.5-turbo-0613\"):\n",
"def answer_question(row, prompt_func=get_prompt, model=\"gpt-3.5-turbo\"):\n",
" messages = prompt_func(row)\n",
" response = api_call(messages, model)\n",
" return response[\"choices\"][0][\"message\"][\"content\"]"
" return response.choices[0].message.content"
]
},
{
@ -557,7 +558,7 @@
" self.model_id = None\n",
"\n",
" def create_openai_file(self):\n",
" self.file_object = openai.File.create(\n",
" self.file_object = client.files.create(\n",
" file=open(self.training_file_path, \"r\"),\n",
" purpose=\"fine-tune\",\n",
" )\n",
@ -569,7 +570,7 @@
" print(\"File Status: \", self.file_object.status)\n",
"\n",
" def create_fine_tuning_job(self):\n",
" self.fine_tuning_job = openai.FineTuningJob.create(\n",
" self.fine_tuning_job = client.fine_tuning.jobs.create(\n",
" training_file=self.file_object[\"id\"],\n",
" model=self.model_name,\n",
" suffix=self.suffix,\n",
@ -582,7 +583,7 @@
" print(\"Job Status: \", self.fine_tuning_job.status)\n",
"\n",
" def retrieve_fine_tuned_model(self):\n",
" self.model_id = openai.FineTuningJob.retrieve(self.fine_tuning_job[\"id\"]).fine_tuned_model\n",
" self.model_id = client.fine_tuning.jobs.retrieve(self.fine_tuning_job[\"id\"]).fine_tuned_model\n",
" return self.model_id\n",
"\n",
" def fine_tune_model(self):\n",
@ -631,7 +632,7 @@
"metadata": {},
"outputs": [],
"source": [
"completion = openai.ChatCompletion.create(\n",
"completion = client.chat.completions.create(\n",
" model=model_id,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
@ -1246,7 +1247,7 @@
"outputs": [],
"source": [
"# Let's try this out\n",
"completion = openai.ChatCompletion.create(\n",
"completion = client.chat.completions.create(\n",
" model=model_id,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",

@ -170,21 +170,21 @@
}
],
"source": [
"import openai\n",
"from openai import OpenAI\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"def get_questions(context):\n",
" try:\n",
" response = openai.Completion.create(\n",
" engine=\"davinci-instruct-beta-v3\",\n",
" prompt=f\"Write questions based on the text below\\n\\nText: {context}\\n\\nQuestions:\\n1.\",\n",
" temperature=0,\n",
" max_tokens=257,\n",
" top_p=1,\n",
" frequency_penalty=0,\n",
" presence_penalty=0,\n",
" stop=[\"\\n\\n\"]\n",
" )\n",
" return response['choices'][0]['text']\n",
" response = client.chat.completions.create(model=\"davinci-instruct-beta-v3\",\n",
" prompt=f\"Write questions based on the text below\\n\\nText: {context}\\n\\nQuestions:\\n1.\",\n",
" temperature=0,\n",
" max_tokens=257,\n",
" top_p=1,\n",
" frequency_penalty=0,\n",
" presence_penalty=0,\n",
" stop=[\"\\n\\n\"])\n",
" return response.choices[0].text\n",
" except:\n",
" return \"\"\n",
"\n",
@ -254,7 +254,7 @@
"source": [
"def get_answers(row):\n",
" try:\n",
" response = openai.Completion.create(\n",
" response = client.chat.completions.create(\n",
" engine=\"davinci-instruct-beta-v3\",\n",
" prompt=f\"Write answer based on the text below\\n\\nText: {row.context}\\n\\nQuestions:\\n{row.questions}\\n\\nAnswers:\\n1.\",\n",
" temperature=0,\n",
@ -263,7 +263,7 @@
" frequency_penalty=0,\n",
" presence_penalty=0\n",
" )\n",
" return response['choices'][0]['text']\n",
" return response.choices[0].text\n",
" except Exception as e:\n",
" print (e)\n",
" return \"\"\n",
@ -320,7 +320,7 @@
"df = df[df.tokens<2000]\n",
"df[['context', 'tokens']].rename(columns={'context':'text','tokens':'metadata'}).to_json('olympics-data/olympics_search.jsonl', orient='records', lines=True)\n",
"\n",
"search_file = openai.File.create(\n",
"search_file = client.files.create(\n",
" file=open(\"olympics-data/olympics_search.jsonl\"),\n",
" purpose='search'\n",
")\n",
@ -473,6 +473,7 @@
" \"\"\"\n",
" \n",
" try:\n",
" # TODO: openai.Engine(search_model) is deprecated\n",
" results = openai.Engine(search_model).search(\n",
" search_model=search_model, \n",
" query=question, \n",

@ -244,6 +244,7 @@
" Find similar contexts to the given context using the search file\n",
" \"\"\"\n",
" try:\n",
" # TODO: openai.Engine(search_model) is deprecated\n",
" results = openai.Engine(search_model).search(\n",
" search_model=search_model, \n",
" query=question, \n",
@ -438,7 +439,7 @@
" Apply the fine tuned discriminator to a question, to assess whether it can be answered from the context.\n",
" \"\"\"\n",
" prompt = f\"{context}\\nQuestion: {question}\\n Related:\"\n",
" result = openai.Completion.create(model=discriminator_model, prompt=prompt, max_tokens=1, temperature=0, top_p=1, n=1, logprobs=2)\n",
" result = openai.chat.completions.create(model=discriminator_model, prompt=prompt, max_tokens=1, temperature=0, top_p=1, n=1, logprobs=2)\n",
" return result['choices'][0]['logprobs']['top_logprobs']\n",
"\n",
"apply_ft_discriminator('The first human-made object in space was the Soviet Union satellite Sputnik 1 on 4 October 1957.', \n",
@ -474,7 +475,7 @@
" Apply the fine tuned discriminator to a question\n",
" \"\"\"\n",
" prompt = f\"{context}\\nQuestion: {question}\\nAnswer:\"\n",
" result = openai.Completion.create(model=answering_model, prompt=prompt, max_tokens=30, temperature=0, top_p=1, n=1, stop=['.','\\n'])\n",
" result = openai.chat.completions.create(model=answering_model, prompt=prompt, max_tokens=30, temperature=0, top_p=1, n=1, stop=['.','\\n'])\n",
" return result['choices'][0]['text']\n",
"\n",
"apply_ft_qa_answer('The first human-made object in space was the Soviet Union satellite Sputnik 1 on 4 October 1957.', \n",

@ -41,12 +41,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "9617e95e",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import openai\n",
"from openai import OpenAI\n",
"import json\n",
"import boto3\n",
"import os\n",
@ -69,18 +80,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "6d5b1991",
"metadata": {},
"outputs": [],
"source": [
"openai.api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
"OpenAI.api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
"GPT_MODEL = \"gpt-3.5-turbo\""
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "a571b8d7",
"metadata": {},
"outputs": [],
@ -90,7 +101,10 @@
"# os.environ['AWS_SECRET_ACCESS_KEY'] = ''\n",
"\n",
"# Create S3 client\n",
"s3_client = boto3.client('s3')"
"s3_client = boto3.client('s3')\n",
"\n",
"# Create openai client\n",
"client = OpenAI()"
]
},
{
@ -113,73 +127,88 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"id": "da4a804b",
"metadata": {},
"outputs": [],
"source": [
"# Functions dict to pass S3 operations details for the GPT model\n",
"functions = [\n",
" {\n",
" \"name\": \"list_buckets\",\n",
" \"description\": \"List all available S3 buckets\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {}\n",
" { \n",
" \"type\": \"function\",\n",
" \"function\":{\n",
" \"name\": \"list_buckets\",\n",
" \"description\": \"List all available S3 buckets\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {}\n",
" }\n",
" }\n",
" },\n",
" {\n",
" \"name\": \"list_objects\",\n",
" \"description\": \"List the objects or files inside a given S3 bucket\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"bucket\": {\"type\": \"string\", \"description\": \"The name of the S3 bucket\"},\n",
" \"prefix\": {\"type\": \"string\", \"description\": \"The folder path in the S3 bucket\"},\n",
" \"type\": \"function\",\n",
" \"function\":{\n",
" \"name\": \"list_objects\",\n",
" \"description\": \"List the objects or files inside a given S3 bucket\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"bucket\": {\"type\": \"string\", \"description\": \"The name of the S3 bucket\"},\n",
" \"prefix\": {\"type\": \"string\", \"description\": \"The folder path in the S3 bucket\"},\n",
" },\n",
" \"required\": [\"bucket\"],\n",
" },\n",
" \"required\": [\"bucket\"],\n",
" },\n",
" }\n",
" },\n",
" {\n",
" \"name\": \"download_file\",\n",
" \"description\": \"Download a specific file from an S3 bucket to a local distribution folder.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"bucket\": {\"type\": \"string\", \"description\": \"The name of the S3 bucket\"},\n",
" \"key\": {\"type\": \"string\", \"description\": \"The path to the file inside the bucket\"},\n",
" \"directory\": {\"type\": \"string\", \"description\": \"The local destination directory to download the file, should be specificed by the user.\"},\n",
" },\n",
" \"required\": [\"bucket\", \"key\", \"directory\"],\n",
" { \n",
" \"type\": \"function\",\n",
" \"function\":{\n",
" \"name\": \"download_file\",\n",
" \"description\": \"Download a specific file from an S3 bucket to a local distribution folder.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"bucket\": {\"type\": \"string\", \"description\": \"The name of the S3 bucket\"},\n",
" \"key\": {\"type\": \"string\", \"description\": \"The path to the file inside the bucket\"},\n",
" \"directory\": {\"type\": \"string\", \"description\": \"The local destination directory to download the file, should be specificed by the user.\"},\n",
" },\n",
" \"required\": [\"bucket\", \"key\", \"directory\"],\n",
" }\n",
" }\n",
" },\n",
" {\n",
" \"name\": \"upload_file\",\n",
" \"description\": \"Upload a file to an S3 bucket\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"source\": {\"type\": \"string\", \"description\": \"The local source path or remote URL\"},\n",
" \"bucket\": {\"type\": \"string\", \"description\": \"The name of the S3 bucket\"},\n",
" \"key\": {\"type\": \"string\", \"description\": \"The path to the file inside the bucket\"},\n",
" \"is_remote_url\": {\"type\": \"boolean\", \"description\": \"Is the provided source a URL (True) or local path (False)\"},\n",
" },\n",
" \"required\": [\"source\", \"bucket\", \"key\", \"is_remote_url\"],\n",
" \"type\": \"function\",\n",
" \"function\":{\n",
" \"name\": \"upload_file\",\n",
" \"description\": \"Upload a file to an S3 bucket\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"source\": {\"type\": \"string\", \"description\": \"The local source path or remote URL\"},\n",
" \"bucket\": {\"type\": \"string\", \"description\": \"The name of the S3 bucket\"},\n",
" \"key\": {\"type\": \"string\", \"description\": \"The path to the file inside the bucket\"},\n",
" \"is_remote_url\": {\"type\": \"boolean\", \"description\": \"Is the provided source a URL (True) or local path (False)\"},\n",
" },\n",
" \"required\": [\"source\", \"bucket\", \"key\", \"is_remote_url\"],\n",
" }\n",
" }\n",
" },\n",
" {\n",
" \"name\": \"search_s3_objects\",\n",
" \"description\": \"Search for a specific file name inside an S3 bucket\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"search_name\": {\"type\": \"string\", \"description\": \"The name of the file you want to search for\"},\n",
" \"bucket\": {\"type\": \"string\", \"description\": \"The name of the S3 bucket\"},\n",
" \"prefix\": {\"type\": \"string\", \"description\": \"The folder path in the S3 bucket\"},\n",
" \"exact_match\": {\"type\": \"boolean\", \"description\": \"Set exact_match to True if the search should match the exact file name. Set exact_match to False to compare part of the file name string (the file contains)\"}\n",
" \"type\": \"function\",\n",
" \"function\":{\n",
" \"name\": \"search_s3_objects\",\n",
" \"description\": \"Search for a specific file name inside an S3 bucket\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"search_name\": {\"type\": \"string\", \"description\": \"The name of the file you want to search for\"},\n",
" \"bucket\": {\"type\": \"string\", \"description\": \"The name of the S3 bucket\"},\n",
" \"prefix\": {\"type\": \"string\", \"description\": \"The folder path in the S3 bucket\"},\n",
" \"exact_match\": {\"type\": \"boolean\", \"description\": \"Set exact_match to True if the search should match the exact file name. Set exact_match to False to compare part of the file name string (the file contains)\"}\n",
" },\n",
" \"required\": [\"search_name\"],\n",
" },\n",
" \"required\": [\"search_name\"],\n",
" },\n",
" }\n",
" }\n",
"]"
]
@ -195,7 +224,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"id": "cf30f14e",
"metadata": {},
"outputs": [],
@ -208,7 +237,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"id": "37736b74",
"metadata": {},
"outputs": [],
@ -275,7 +304,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"id": "03c3d555",
"metadata": {},
"outputs": [],
@ -300,7 +329,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"id": "796acfdd",
"metadata": {},
"outputs": [],
@ -309,13 +338,13 @@
" model_name=GPT_MODEL):\n",
" \n",
" if functions is not None:\n",
" return openai.ChatCompletion.create(\n",
" return client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" functions=functions,\n",
" function_call=function_call)\n",
" tools=functions,\n",
" tool_choice=function_call)\n",
" else:\n",
" return openai.ChatCompletion.create(\n",
" return client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages)"
]
@ -340,7 +369,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 17,
"id": "3e2e9192",
"metadata": {},
"outputs": [],
@ -354,15 +383,15 @@
" \n",
" # Call the model to get a response\n",
" response = chat_completion_request(messages, functions=functions)\n",
" response_message = response['choices'][0]['message']\n",
" response_message = response.choices[0].message\n",
" \n",
" if is_log:\n",
" print(response['choices'])\n",
" print(response.choices)\n",
" \n",
" # check if GPT wanted to call a function\n",
" if response_message.get(\"function_call\"):\n",
" function_name = response_message['function_call']['name']\n",
" function_args = json.loads(response_message['function_call']['arguments'])\n",
" if response_message.tool_calls:\n",
" function_name = response_message.tool_calls[0].function.name\n",
" function_args = json.loads(response_message.tool_calls[0].function.arguments)\n",
" \n",
" # Call the function\n",
" function_response = available_functions[function_name](**function_args)\n",
@ -370,16 +399,16 @@
" # Add the response to the conversation\n",
" messages.append(response_message)\n",
" messages.append({\n",
" \"role\": \"function\",\n",
" \"name\": function_name,\n",
" \"role\": \"tool\",\n",
" \"content\": function_response,\n",
" \"tool_call_id\": response_message.tool_calls[0].id,\n",
" })\n",
" \n",
" # Call the model again to summarize the results\n",
" second_response = chat_completion_request(messages)\n",
" final_message = second_response['choices'][0]['message']['content']\n",
" final_message = second_response.choices[0].message.content\n",
" else:\n",
" final_message = response_message['content']\n",
" final_message = response_message.content\n",
"\n",
" return final_message"
]

@ -134,7 +134,7 @@
},
"source": [
"Before we start, we can configure the LLM provider and model that will power our RAG system. \n",
"Here, we pick *text-davinci-003* from OpenAI, allow unlimited output tokens. "
"Here, we pick `gpt-3.5-turbo-instruct` from OpenAI. "
]
},
{
@ -146,7 +146,7 @@
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0, model_name=\"text-davinci-003\", max_tokens=-1)"
"llm = OpenAI(temperature=0, model_name=\"gpt-3.5-turbo-instruct\", max_tokens=-1)"
]
},
{

@ -7,60 +7,57 @@ from scipy import spatial
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import average_precision_score, precision_recall_curve
from tenacity import retry, stop_after_attempt, wait_random_exponential
import openai
from openai import OpenAI
import numpy as np
import pandas as pd
client = OpenAI(max_retries=5)
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(text: str, model="text-similarity-davinci-001", **kwargs) -> List[float]:
def get_embedding(text: str, model="text-embedding-3-small", **kwargs) -> List[float]:
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
response = openai.embeddings.create(input=[text], model=model, **kwargs)
response = client.embeddings.create(input=[text], model=model, **kwargs)
return response.data[0].embedding
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(
text: str, model="text-similarity-davinci-001", **kwargs
text: str, model="text-embedding-3-small", **kwargs
) -> List[float]:
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (await openai.embeddings.create(input=[text], model=model, **kwargs))["data"][0][
"embedding"
]
return (await client.embeddings.create(input=[text], model=model, **kwargs))[
"data"
][0]["embedding"]
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
list_of_text: List[str], model="text-similarity-babbage-001", **kwargs
list_of_text: List[str], model="text-embedding-3-small", **kwargs
) -> List[List[float]]:
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.embeddings.create(input=list_of_text, model=model, **kwargs).data
data = client.embeddings.create(input=list_of_text, model=model, **kwargs).data
return [d.embedding for d in data]
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
list_of_text: List[str], model="text-similarity-babbage-001", **kwargs
list_of_text: List[str], model="text-embedding-3-small", **kwargs
) -> List[List[float]]:
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (await openai.embeddings.create(input=list_of_text, model=model, **kwargs)).data
data = (
await client.embeddings.create(input=list_of_text, model=model, **kwargs)
).data
return [d.embedding for d in data]
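For reference, the updated helpers can be exercised like this; both calls assume `OPENAI_API_KEY` is set so the module-level client can authenticate:

```python
# Usage sketch for the updated embedding helpers; the model defaults to
# text-embedding-3-small as set above.
single = get_embedding("hello world")
batch = get_embeddings(["hello", "world"])  # batched variant, at most 2048 inputs
print(len(single), len(batch), len(batch[0]))
```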

@ -371,7 +371,7 @@
"source": [
"## Search data\n",
"\n",
"Once the data is put into Qdrant we will start querying the collection for the closest vectors. We may provide an additional parameter `vector_name` to switch from title to content based search. Since the precomputed embeddings were created with `text-embedding-ada-002` OpenAI model we also have to use it during search."
"Once the data is put into Qdrant we will start querying the collection for the closest vectors. We may provide an additional parameter `vector_name` to switch from title to content based search. Since the precomputed embeddings were created with `text-embedding-3-small` OpenAI model we also have to use it during search."
]
},
{
@ -385,7 +385,7 @@
" # Creates embedding vector from user query\n",
" embedded_query = openai.Embedding.create(\n",
" input=query,\n",
" model=\"text-embedding-ada-002\",\n",
" model=\"text-embedding-3-small\",\n",
" )[\"data\"][0][\"embedding\"]\n",
"\n",
" # Convert the embedded_query to PostgreSQL compatible format\n",

@ -9,7 +9,7 @@ Each provider has their own named directory, with a standard notebook to introdu
## Guides & deep dives
- [AnalyticDB](https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/latest/get-started-with-analyticdb-for-postgresql)
- [Cassandra/Astra DB](https://docs.datastax.com/en/astra-serverless/docs/vector-search/qandasimsearch-quickstart.html)
- [AzureSearch](https://learn.microsoft.com/azure/search/search-get-started-vector)
- [Azure AI Search](https://learn.microsoft.com/azure/search/search-get-started-vector)
- [Chroma](https://docs.trychroma.com/getting-started)
- [Elasticsearch](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html)
- [Hologres](https://www.alibabacloud.com/help/en/hologres/latest/procedure-to-use-hologres)

@ -46,8 +46,7 @@
"source": [
"import openai\n",
"\n",
"# models\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"GPT_MODEL = \"gpt-3.5-turbo\"\n"
]
},

@ -429,7 +429,7 @@
"source": [
"## Search data\n",
"\n",
"Once the data is put into Qdrant we will start querying the collection for the closest vectors. We may provide an additional parameter `vector_name` to switch from title to content based search. Since the precomputed embeddings were created with `text-embedding-ada-002` OpenAI model we also have to use it during search.\n"
"Once the data is put into Qdrant we will start querying the collection for the closest vectors. We may provide an additional parameter `vector_name` to switch from title to content based search. Since the precomputed embeddings were created with `text-embedding-3-small` OpenAI model we also have to use it during search.\n"
]
},
{
@ -448,7 +448,7 @@
" # Creates embedding vector from user query\n",
" embedded_query = openai.Embedding.create(\n",
" input=query,\n",
" model=\"text-embedding-ada-002\",\n",
" model=\"text-embedding-3-small\",\n",
" )[\"data\"][0][\"embedding\"]\n",
"\n",
" # Convert the embedded_query to PostgreSQL compatible format\n",

@ -0,0 +1,738 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Azure AI Search as a vector database for OpenAI embeddings"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook provides step by step instuctions on using Azure AI Search (f.k.a Azure Cognitive Search) as a vector database with OpenAI embeddings. Azure AI Search is a cloud search service that gives developers infrastructure, APIs, and tools for building a rich search experience over private, heterogeneous content in web, mobile, and enterprise applications.\n",
"\n",
"## Prerequistites:\n",
"For the purposes of this exercise you must have the following:\n",
"- [Azure AI Search Service](https://learn.microsoft.com/azure/search/)\n",
"- [OpenAI Key](https://platform.openai.com/account/api-keys) or [Azure OpenAI credentials](https://learn.microsoft.com/azure/cognitive-services/openai/)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install wget\n",
"! pip install azure-search-documents \n",
"! pip install azure-identity\n",
"! pip install openai"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import required libraries"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import json \n",
"import wget\n",
"import pandas as pd\n",
"import zipfile\n",
"from openai import AzureOpenAI\n",
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
"from azure.core.credentials import AzureKeyCredential \n",
"from azure.search.documents import SearchClient, SearchIndexingBufferedSender \n",
"from azure.search.documents.indexes import SearchIndexClient \n",
"from azure.search.documents.models import (\n",
" QueryAnswerType,\n",
" QueryCaptionType,\n",
" QueryType,\n",
" VectorizedQuery,\n",
")\n",
"from azure.search.documents.indexes.models import (\n",
" HnswAlgorithmConfiguration,\n",
" HnswParameters,\n",
" SearchField,\n",
" SearchableField,\n",
" SearchFieldDataType,\n",
" SearchIndex,\n",
" SemanticConfiguration,\n",
" SemanticField,\n",
" SemanticPrioritizedFields,\n",
" SemanticSearch,\n",
" SimpleField,\n",
" VectorSearch,\n",
" VectorSearchAlgorithmKind,\n",
" VectorSearchAlgorithmMetric,\n",
" VectorSearchProfile,\n",
")\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configure OpenAI settings\n",
"\n",
"This section guides you through setting up authentication for Azure OpenAI, allowing you to securely interact with the service using either Azure Active Directory (AAD) or an API key. Before proceeding, ensure you have your Azure OpenAI endpoint and credentials ready. For detailed instructions on setting up AAD with Azure OpenAI, refer to the [official documentation](https://learn.microsoft.com/azure/ai-services/openai/how-to/managed-identity).\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"endpoint: str = \"YOUR_AZURE_OPENAI_ENDPOINT\"\n",
"api_key: str = \"YOUR_AZURE_OPENAI_KEY\"\n",
"api_version: str = \"2023-05-15\"\n",
"deployment = \"YOUR_AZURE_OPENAI_DEPLOYMENT_NAME\"\n",
"credential = DefaultAzureCredential()\n",
"token_provider = get_bearer_token_provider(\n",
" credential, \"https://cognitiveservices.azure.com/.default\"\n",
")\n",
"\n",
"# Set this flag to True if you are using Azure Active Directory\n",
"use_aad_for_aoai = True \n",
"\n",
"if use_aad_for_aoai:\n",
" # Use Azure Active Directory (AAD) authentication\n",
" client = AzureOpenAI(\n",
" azure_endpoint=endpoint,\n",
" api_version=api_version,\n",
" azure_ad_token_provider=token_provider,\n",
" )\n",
"else:\n",
" # Use API key authentication\n",
" client = AzureOpenAI(\n",
" api_key=api_key,\n",
" api_version=api_version,\n",
" azure_endpoint=endpoint,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configure Azure AI Search Vector Store settings\n",
"This section explains how to set up the Azure AI Search client for integrating with the Vector Store feature. You can locate your Azure AI Search service details in the Azure Portal or programmatically via the [Search Management SDK](https://learn.microsoft.com/rest/api/searchmanagement/).\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Configuration\n",
"search_service_endpoint: str = \"YOUR_AZURE_SEARCH_ENDPOINT\"\n",
"search_service_api_key: str = \"YOUR_AZURE_SEARCH_ADMIN_KEY\"\n",
"index_name: str = \"azure-ai-search-openai-cookbook-demo\"\n",
"\n",
"# Set this flag to True if you are using Azure Active Directory\n",
"use_aad_for_search = True \n",
"\n",
"if use_aad_for_search:\n",
" # Use Azure Active Directory (AAD) authentication\n",
" credential = DefaultAzureCredential()\n",
"else:\n",
" # Use API key authentication\n",
" credential = AzureKeyCredential(search_service_api_key)\n",
"\n",
"# Initialize the SearchClient with the selected authentication method\n",
"search_client = SearchClient(\n",
" endpoint=search_service_endpoint, index_name=index_name, credential=credential\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'vector_database_wikipedia_articles_embedded.zip'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings_url = \"https://cdn.openai.com/API/examples/data/vector_database_wikipedia_articles_embedded.zip\"\n",
"\n",
"# The file is ~700 MB so this will take some time\n",
"wget.download(embeddings_url)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"with zipfile.ZipFile(\"vector_database_wikipedia_articles_embedded.zip\", \"r\") as zip_ref:\n",
" zip_ref.extractall(\"../../data\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>url</th>\n",
" <th>title</th>\n",
" <th>text</th>\n",
" <th>title_vector</th>\n",
" <th>content_vector</th>\n",
" <th>vector_id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>https://simple.wikipedia.org/wiki/April</td>\n",
" <td>April</td>\n",
" <td>April is the fourth month of the year in the J...</td>\n",
" <td>[0.001009464613161981, -0.020700545981526375, ...</td>\n",
" <td>[-0.011253940872848034, -0.013491976074874401,...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>https://simple.wikipedia.org/wiki/August</td>\n",
" <td>August</td>\n",
" <td>August (Aug.) is the eighth month of the year ...</td>\n",
" <td>[0.0009286514250561595, 0.000820168002974242, ...</td>\n",
" <td>[0.0003609954728744924, 0.007262262050062418, ...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6</td>\n",
" <td>https://simple.wikipedia.org/wiki/Art</td>\n",
" <td>Art</td>\n",
" <td>Art is a creative activity that expresses imag...</td>\n",
" <td>[0.003393713850528002, 0.0061537534929811954, ...</td>\n",
" <td>[-0.004959689453244209, 0.015772193670272827, ...</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>8</td>\n",
" <td>https://simple.wikipedia.org/wiki/A</td>\n",
" <td>A</td>\n",
" <td>A or a is the first letter of the English alph...</td>\n",
" <td>[0.0153952119871974, -0.013759135268628597, 0....</td>\n",
" <td>[0.024894846603274345, -0.022186409682035446, ...</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>9</td>\n",
" <td>https://simple.wikipedia.org/wiki/Air</td>\n",
" <td>Air</td>\n",
" <td>Air refers to the Earth's atmosphere. Air is a...</td>\n",
" <td>[0.02224554680287838, -0.02044147066771984, -0...</td>\n",
" <td>[0.021524671465158463, 0.018522677943110466, -...</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id url title \\\n",
"0 1 https://simple.wikipedia.org/wiki/April April \n",
"1 2 https://simple.wikipedia.org/wiki/August August \n",
"2 6 https://simple.wikipedia.org/wiki/Art Art \n",
"3 8 https://simple.wikipedia.org/wiki/A A \n",
"4 9 https://simple.wikipedia.org/wiki/Air Air \n",
"\n",
" text \\\n",
"0 April is the fourth month of the year in the J... \n",
"1 August (Aug.) is the eighth month of the year ... \n",
"2 Art is a creative activity that expresses imag... \n",
"3 A or a is the first letter of the English alph... \n",
"4 Air refers to the Earth's atmosphere. Air is a... \n",
"\n",
" title_vector \\\n",
"0 [0.001009464613161981, -0.020700545981526375, ... \n",
"1 [0.0009286514250561595, 0.000820168002974242, ... \n",
"2 [0.003393713850528002, 0.0061537534929811954, ... \n",
"3 [0.0153952119871974, -0.013759135268628597, 0.... \n",
"4 [0.02224554680287838, -0.02044147066771984, -0... \n",
"\n",
" content_vector vector_id \n",
"0 [-0.011253940872848034, -0.013491976074874401,... 0 \n",
"1 [0.0003609954728744924, 0.007262262050062418, ... 1 \n",
"2 [-0.004959689453244209, 0.015772193670272827, ... 2 \n",
"3 [0.024894846603274345, -0.022186409682035446, ... 3 \n",
"4 [0.021524671465158463, 0.018522677943110466, -... 4 "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"article_df = pd.read_csv(\"../../data/vector_database_wikipedia_articles_embedded.csv\")\n",
"\n",
"# Read vectors from strings back into a list using json.loads\n",
"article_df[\"title_vector\"] = article_df.title_vector.apply(json.loads)\n",
"article_df[\"content_vector\"] = article_df.content_vector.apply(json.loads)\n",
"article_df[\"vector_id\"] = article_df[\"vector_id\"].apply(str)\n",
"article_df.head()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create an index\n",
"This code snippet demonstrates how to define and create a search index using the `SearchIndexClient` from the Azure AI Search Python SDK. The index incorporates both vector search and semantic ranker capabilities. For more details, visit our documentation on how to [Create a Vector Index](https://learn.microsoft.com/azure/search/vector-search-how-to-create-index?.tabs=config-2023-11-01%2Crest-2023-11-01%2Cpush%2Cportal-check-index)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"azure-ai-search-openai-cookbook-demo created\n"
]
}
],
"source": [
"# Initialize the SearchIndexClient\n",
"index_client = SearchIndexClient(\n",
" endpoint=search_service_endpoint, credential=credential\n",
")\n",
"\n",
"# Define the fields for the index\n",
"fields = [\n",
" SimpleField(name=\"id\", type=SearchFieldDataType.String),\n",
" SimpleField(name=\"vector_id\", type=SearchFieldDataType.String, key=True),\n",
" SimpleField(name=\"url\", type=SearchFieldDataType.String),\n",
" SearchableField(name=\"title\", type=SearchFieldDataType.String),\n",
" SearchableField(name=\"text\", type=SearchFieldDataType.String),\n",
" SearchField(\n",
" name=\"title_vector\",\n",
" type=SearchFieldDataType.Collection(SearchFieldDataType.Single),\n",
" vector_search_dimensions=1536,\n",
" vector_search_profile_name=\"my-vector-config\",\n",
" ),\n",
" SearchField(\n",
" name=\"content_vector\",\n",
" type=SearchFieldDataType.Collection(SearchFieldDataType.Single),\n",
" vector_search_dimensions=1536,\n",
" vector_search_profile_name=\"my-vector-config\",\n",
" ),\n",
"]\n",
"\n",
"# Configure the vector search configuration\n",
"vector_search = VectorSearch(\n",
" algorithms=[\n",
" HnswAlgorithmConfiguration(\n",
" name=\"my-hnsw\",\n",
" kind=VectorSearchAlgorithmKind.HNSW,\n",
" parameters=HnswParameters(\n",
" m=4,\n",
" ef_construction=400,\n",
" ef_search=500,\n",
" metric=VectorSearchAlgorithmMetric.COSINE,\n",
" ),\n",
" )\n",
" ],\n",
" profiles=[\n",
" VectorSearchProfile(\n",
" name=\"my-vector-config\",\n",
" algorithm_configuration_name=\"my-hnsw\",\n",
" )\n",
" ],\n",
")\n",
"\n",
"# Configure the semantic search configuration\n",
"semantic_search = SemanticSearch(\n",
" configurations=[\n",
" SemanticConfiguration(\n",
" name=\"my-semantic-config\",\n",
" prioritized_fields=SemanticPrioritizedFields(\n",
" title_field=SemanticField(field_name=\"title\"),\n",
" keywords_fields=[SemanticField(field_name=\"url\")],\n",
" content_fields=[SemanticField(field_name=\"text\")],\n",
" ),\n",
" )\n",
" ]\n",
")\n",
"\n",
"# Create the search index with the vector search and semantic search configurations\n",
"index = SearchIndex(\n",
" name=index_name,\n",
" fields=fields,\n",
" vector_search=vector_search,\n",
" semantic_search=semantic_search,\n",
")\n",
"\n",
"# Create or update the index\n",
"result = index_client.create_or_update_index(index)\n",
"print(f\"{result.name} created\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Uploading Data to Azure AI Search Index\n",
"\n",
"The following code snippet outlines the process of uploading a batch of documents—specifically, Wikipedia articles with pre-computed embeddings—from a pandas DataFrame to an Azure AI Search index. For a detailed guide on data import strategies and best practices, refer to [Data Import in Azure AI Search](https://learn.microsoft.com/azure/search/search-what-is-data-import).\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Uploaded 25000 documents in total\n"
]
}
],
"source": [
"from azure.core.exceptions import HttpResponseError\n",
"\n",
"# Convert the 'id' and 'vector_id' columns to string so one of them can serve as our key field\n",
"article_df[\"id\"] = article_df[\"id\"].astype(str)\n",
"article_df[\"vector_id\"] = article_df[\"vector_id\"].astype(str)\n",
"# Convert the DataFrame to a list of dictionaries\n",
"documents = article_df.to_dict(orient=\"records\")\n",
"\n",
"# Create a SearchIndexingBufferedSender\n",
"batch_client = SearchIndexingBufferedSender(\n",
" search_service_endpoint, index_name, credential\n",
")\n",
"\n",
"try:\n",
" # Add upload actions for all documents in a single call\n",
" batch_client.upload_documents(documents=documents)\n",
"\n",
" # Manually flush to send any remaining documents in the buffer\n",
" batch_client.flush()\n",
"except HttpResponseError as e:\n",
" print(f\"An error occurred: {e}\")\n",
"finally:\n",
" # Clean up resources\n",
" batch_client.close()\n",
"\n",
"print(f\"Uploaded {len(documents)} documents in total\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If your dataset didn't already contain pre-computed embeddings, you can create embeddings by using the below function using the `openai` python library. You'll also notice the same function and model are being used to generate query embeddings for performing vector searches."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Content: April is the fourth month of the year in the Julian and Gregorian calendars, and comes between March\n",
"Content vector generated\n"
]
}
],
"source": [
"# Example function to generate document embedding\n",
"def generate_embeddings(text, model):\n",
" # Generate embeddings for the provided text using the specified model\n",
" embeddings_response = client.embeddings.create(model=model, input=text)\n",
" # Extract the embedding data from the response\n",
" embedding = embeddings_response.data[0].embedding\n",
" return embedding\n",
"\n",
"\n",
"first_document_content = documents[0][\"text\"]\n",
"print(f\"Content: {first_document_content[:100]}\")\n",
"\n",
"content_vector = generate_embeddings(first_document_content, deployment)\n",
"print(\"Content vector generated\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform a vector similarity search"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Title: Documenta\n",
"Score: 0.8599451\n",
"URL: https://simple.wikipedia.org/wiki/Documenta\n",
"\n",
"Title: Museum of Modern Art\n",
"Score: 0.85260946\n",
"URL: https://simple.wikipedia.org/wiki/Museum%20of%20Modern%20Art\n",
"\n",
"Title: Expressionism\n",
"Score: 0.852354\n",
"URL: https://simple.wikipedia.org/wiki/Expressionism\n",
"\n"
]
}
],
"source": [
"# Pure Vector Search\n",
"query = \"modern art in Europe\"\n",
" \n",
"search_client = SearchClient(search_service_endpoint, index_name, credential) \n",
"vector_query = VectorizedQuery(vector=generate_embeddings(query, deployment), k_nearest_neighbors=3, fields=\"content_vector\")\n",
" \n",
"results = search_client.search( \n",
" search_text=None, \n",
" vector_queries= [vector_query], \n",
" select=[\"title\", \"text\", \"url\"] \n",
")\n",
" \n",
"for result in results: \n",
" print(f\"Title: {result['title']}\") \n",
" print(f\"Score: {result['@search.score']}\") \n",
" print(f\"URL: {result['url']}\\n\") "
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform a Hybrid Search\n",
"Hybrid search combines the capabilities of traditional keyword-based search with vector-based similarity search to provide more relevant and contextual results. This approach is particularly useful when dealing with complex queries that benefit from understanding the semantic meaning behind the text.\n",
"\n",
"The provided code snippet demonstrates how to execute a hybrid search query:"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Title: Wars of Scottish Independence\n",
"Score: 0.03306011110544205\n",
"URL: https://simple.wikipedia.org/wiki/Wars%20of%20Scottish%20Independence\n",
"\n",
"Title: Battle of Bannockburn\n",
"Score: 0.022253260016441345\n",
"URL: https://simple.wikipedia.org/wiki/Battle%20of%20Bannockburn\n",
"\n",
"Title: Scottish\n",
"Score: 0.016393441706895828\n",
"URL: https://simple.wikipedia.org/wiki/Scottish\n",
"\n"
]
}
],
"source": [
"# Hybrid Search\n",
"query = \"Famous battles in Scottish history\" \n",
" \n",
"search_client = SearchClient(search_service_endpoint, index_name, credential) \n",
"vector_query = VectorizedQuery(vector=generate_embeddings(query, deployment), k_nearest_neighbors=3, fields=\"content_vector\")\n",
" \n",
"results = search_client.search( \n",
" search_text=query, \n",
" vector_queries= [vector_query], \n",
" select=[\"title\", \"text\", \"url\"],\n",
" top=3\n",
")\n",
" \n",
"for result in results: \n",
" print(f\"Title: {result['title']}\") \n",
" print(f\"Score: {result['@search.score']}\") \n",
" print(f\"URL: {result['url']}\\n\") "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform a Hybrid Search with Reranking (powered by Bing)\n",
"[Semantic ranker](https://learn.microsoft.com/azure/search/semantic-search-overview) measurably improves search relevance by using language understanding to rerank search results. Additionally, you can get extractive captions, answers, and highlights. "
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Semantic Answer: Advancements During the industrial revolution, new technology brought many changes. For example:<em> Canals</em> were built to allow heavy goods to be moved easily where they were needed. The steam engine became the main source of power. It replaced horses and human labor. Cheap iron and steel became mass-produced.\n",
"Semantic Answer Score: 0.90478515625\n",
"\n",
"Title: Industrial Revolution\n",
"Reranker Score: 3.408700942993164\n",
"URL: https://simple.wikipedia.org/wiki/Industrial%20Revolution\n",
"Caption: Advancements During the industrial revolution, new technology brought many changes. For example: Canals were built to allow heavy goods to be moved easily where they were needed. The steam engine became the main source of power. It replaced horses and human labor. Cheap iron and steel became mass-produced.\n",
"\n",
"Title: Printing\n",
"Reranker Score: 1.603400707244873\n",
"URL: https://simple.wikipedia.org/wiki/Printing\n",
"Caption: Machines to speed printing, cheaper paper, automatic stitching and binding all arrived in the 19th century during the industrial revolution. What had once been done by a few men by hand was now done by limited companies on huge machines. The result was much lower prices, and a much wider readership.\n",
"\n",
"Title: Industrialisation\n",
"Reranker Score: 1.3238357305526733\n",
"URL: https://simple.wikipedia.org/wiki/Industrialisation\n",
"Caption: <em>Industrialisation</em> (or<em> industrialization)</em> is a process that happens in countries when they start to use machines to do work that was once done by people.<em> Industrialisation changes</em> the things people do.<em> Industrialisation</em> caused towns to grow larger. Many people left farming to take higher paid jobs in factories in towns.\n",
"\n"
]
}
],
"source": [
"# Semantic Hybrid Search\n",
"query = \"What were the key technological advancements during the Industrial Revolution?\"\n",
"\n",
"search_client = SearchClient(search_service_endpoint, index_name, credential)\n",
"vector_query = VectorizedQuery(\n",
" vector=generate_embeddings(query, deployment),\n",
" k_nearest_neighbors=3,\n",
" fields=\"content_vector\",\n",
")\n",
"\n",
"results = search_client.search(\n",
" search_text=query,\n",
" vector_queries=[vector_query],\n",
" select=[\"title\", \"text\", \"url\"],\n",
" query_type=QueryType.SEMANTIC,\n",
" semantic_configuration_name=\"my-semantic-config\",\n",
" query_caption=QueryCaptionType.EXTRACTIVE,\n",
" query_answer=QueryAnswerType.EXTRACTIVE,\n",
" top=3,\n",
")\n",
"\n",
"semantic_answers = results.get_answers()\n",
"for answer in semantic_answers:\n",
" if answer.highlights:\n",
" print(f\"Semantic Answer: {answer.highlights}\")\n",
" else:\n",
" print(f\"Semantic Answer: {answer.text}\")\n",
" print(f\"Semantic Answer Score: {answer.score}\\n\")\n",
"\n",
"for result in results:\n",
" print(f\"Title: {result['title']}\")\n",
" print(f\"Reranker Score: {result['@search.reranker_score']}\")\n",
" print(f\"URL: {result['url']}\")\n",
" captions = result[\"@search.captions\"]\n",
" if captions:\n",
" caption = captions[0]\n",
" if caption.highlights:\n",
" print(f\"Caption: {caption.highlights}\\n\")\n",
" else:\n",
" print(f\"Caption: {caption.text}\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
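
A note on the index definition in the notebook above: `text-embedding-3-small` returns 1536-dimensional vectors by default, so it is drop-in compatible with the `vector_search_dimensions=1536` fields built for the precomputed `text-embedding-ada-002` dataset. The v3 embedding models also accept a `dimensions` parameter if you want shorter vectors; a minimal sketch (the value 256 is illustrative, and the index would need matching dimensions):

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Request a shortened embedding; vector_search_dimensions in the index
# definition must match the value passed here.
response = client.embeddings.create(
    model="text-embedding-3-small",
    input="modern art in Europe",
    dimensions=256,  # illustrative; the default for this model is 1536
)
vector = response.data[0].embedding
assert len(vector) == 256
```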

@ -1,647 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Azure Cognitive Search as a vector database for OpenAI embeddings"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook provides step by step instuctions on using Azure Cognitive Search as a vector database with OpenAI embeddings. Azure Cognitive Search (formerly known as \"Azure Search\") is a cloud search service that gives developers infrastructure, APIs, and tools for building a rich search experience over private, heterogeneous content in web, mobile, and enterprise applications.\n",
"\n",
"## Prerequistites:\n",
"For the purposes of this exercise you must have the following:\n",
"- [Azure Cognitive Search Service](https://learn.microsoft.com/azure/search/)\n",
"- [OpenAI Key](https://platform.openai.com/account/api-keys) or [Azure OpenAI credentials](https://learn.microsoft.com/azure/cognitive-services/openai/)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install wget\n",
"! pip install azure-search-documents --pre "
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import required libraries"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import json \n",
"import openai\n",
"import wget\n",
"import pandas as pd\n",
"import zipfile\n",
"from azure.core.credentials import AzureKeyCredential \n",
"from azure.search.documents import SearchClient \n",
"from azure.search.documents.indexes import SearchIndexClient \n",
"from azure.search.documents.models import Vector \n",
"from azure.search.documents import SearchIndexingBufferedSender\n",
"from azure.search.documents.indexes.models import ( \n",
" SearchIndex, \n",
" SearchField, \n",
" SearchFieldDataType, \n",
" SimpleField, \n",
" SearchableField, \n",
" SearchIndex, \n",
" SemanticConfiguration, \n",
" PrioritizedFields, \n",
" SemanticField, \n",
" SearchField, \n",
" SemanticSettings, \n",
" VectorSearch, \n",
" HnswVectorSearchAlgorithmConfiguration, \n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configure OpenAI settings\n",
"\n",
"Configure your OpenAI or Azure OpenAI settings. For this example, we use Azure OpenAI."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"openai.api_type = \"azure\"\n",
"openai.api_base = \"YOUR_AZURE_OPENAI_ENDPOINT\"\n",
"openai.api_version = \"2023-05-15\"\n",
"openai.api_key = \"YOUR_AZURE_OPENAI_KEY\"\n",
"model: str = \"text-embedding-ada-002\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configure Azure Cognitive Search Vector Store settings\n",
"You can find this in the Azure Portal or using the [Search Management SDK](https://learn.microsoft.com/rest/api/searchmanagement/)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"search_service_endpoint: str = \"YOUR_AZURE_SEARCH_ENDPOINT\"\n",
"search_service_api_key: str = \"YOUR_AZURE_SEARCH_ADMIN_KEY\"\n",
"index_name: str = \"azure-cognitive-search-vector-demo\"\n",
"credential = AzureKeyCredential(search_service_api_key)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'vector_database_wikipedia_articles_embedded.zip'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings_url = \"https://cdn.openai.com/API/examples/data/vector_database_wikipedia_articles_embedded.zip\"\n",
"\n",
"# The file is ~700 MB so this will take some time\n",
"wget.download(embeddings_url)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"with zipfile.ZipFile(\"vector_database_wikipedia_articles_embedded.zip\",\"r\") as zip_ref:\n",
" zip_ref.extractall(\"../../data\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>url</th>\n",
" <th>title</th>\n",
" <th>text</th>\n",
" <th>title_vector</th>\n",
" <th>content_vector</th>\n",
" <th>vector_id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>https://simple.wikipedia.org/wiki/April</td>\n",
" <td>April</td>\n",
" <td>April is the fourth month of the year in the J...</td>\n",
" <td>[0.001009464613161981, -0.020700545981526375, ...</td>\n",
" <td>[-0.011253940872848034, -0.013491976074874401,...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>https://simple.wikipedia.org/wiki/August</td>\n",
" <td>August</td>\n",
" <td>August (Aug.) is the eighth month of the year ...</td>\n",
" <td>[0.0009286514250561595, 0.000820168002974242, ...</td>\n",
" <td>[0.0003609954728744924, 0.007262262050062418, ...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6</td>\n",
" <td>https://simple.wikipedia.org/wiki/Art</td>\n",
" <td>Art</td>\n",
" <td>Art is a creative activity that expresses imag...</td>\n",
" <td>[0.003393713850528002, 0.0061537534929811954, ...</td>\n",
" <td>[-0.004959689453244209, 0.015772193670272827, ...</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>8</td>\n",
" <td>https://simple.wikipedia.org/wiki/A</td>\n",
" <td>A</td>\n",
" <td>A or a is the first letter of the English alph...</td>\n",
" <td>[0.0153952119871974, -0.013759135268628597, 0....</td>\n",
" <td>[0.024894846603274345, -0.022186409682035446, ...</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>9</td>\n",
" <td>https://simple.wikipedia.org/wiki/Air</td>\n",
" <td>Air</td>\n",
" <td>Air refers to the Earth's atmosphere. Air is a...</td>\n",
" <td>[0.02224554680287838, -0.02044147066771984, -0...</td>\n",
" <td>[0.021524671465158463, 0.018522677943110466, -...</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id url title \\\n",
"0 1 https://simple.wikipedia.org/wiki/April April \n",
"1 2 https://simple.wikipedia.org/wiki/August August \n",
"2 6 https://simple.wikipedia.org/wiki/Art Art \n",
"3 8 https://simple.wikipedia.org/wiki/A A \n",
"4 9 https://simple.wikipedia.org/wiki/Air Air \n",
"\n",
" text \\\n",
"0 April is the fourth month of the year in the J... \n",
"1 August (Aug.) is the eighth month of the year ... \n",
"2 Art is a creative activity that expresses imag... \n",
"3 A or a is the first letter of the English alph... \n",
"4 Air refers to the Earth's atmosphere. Air is a... \n",
"\n",
" title_vector \\\n",
"0 [0.001009464613161981, -0.020700545981526375, ... \n",
"1 [0.0009286514250561595, 0.000820168002974242, ... \n",
"2 [0.003393713850528002, 0.0061537534929811954, ... \n",
"3 [0.0153952119871974, -0.013759135268628597, 0.... \n",
"4 [0.02224554680287838, -0.02044147066771984, -0... \n",
"\n",
" content_vector vector_id \n",
"0 [-0.011253940872848034, -0.013491976074874401,... 0 \n",
"1 [0.0003609954728744924, 0.007262262050062418, ... 1 \n",
"2 [-0.004959689453244209, 0.015772193670272827, ... 2 \n",
"3 [0.024894846603274345, -0.022186409682035446, ... 3 \n",
"4 [0.021524671465158463, 0.018522677943110466, -... 4 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"article_df = pd.read_csv('../../data/vector_database_wikipedia_articles_embedded.csv') \n",
" \n",
"# Read vectors from strings back into a list using json.loads \n",
"article_df[\"title_vector\"] = article_df.title_vector.apply(json.loads) \n",
"article_df[\"content_vector\"] = article_df.content_vector.apply(json.loads) \n",
"article_df['vector_id'] = article_df['vector_id'].apply(str) \n",
"article_df.head() \n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create an index"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"azure-cognitive-search-vector-demo created\n"
]
}
],
"source": [
"# Configure a search index\n",
"index_client = SearchIndexClient(\n",
" endpoint=search_service_endpoint, credential=credential)\n",
"fields = [\n",
" SimpleField(name=\"id\", type=SearchFieldDataType.String),\n",
" SimpleField(name=\"vector_id\", type=SearchFieldDataType.String, key=True),\n",
" SimpleField(name=\"url\", type=SearchFieldDataType.String),\n",
" SearchableField(name=\"title\", type=SearchFieldDataType.String),\n",
" SearchableField(name=\"text\", type=SearchFieldDataType.String),\n",
" SearchField(name=\"title_vector\", type=SearchFieldDataType.Collection(SearchFieldDataType.Single),\n",
" searchable=True, vector_search_dimensions=1536, vector_search_configuration=\"my-vector-config\"),\n",
" SearchField(name=\"content_vector\", type=SearchFieldDataType.Collection(SearchFieldDataType.Single),\n",
" searchable=True, vector_search_dimensions=1536, vector_search_configuration=\"my-vector-config\"),\n",
"]\n",
"\n",
"# Configure the vector search configuration\n",
"vector_search = VectorSearch(\n",
" algorithm_configurations=[\n",
" HnswVectorSearchAlgorithmConfiguration(\n",
" name=\"my-vector-config\",\n",
" kind=\"hnsw\",\n",
" parameters={\n",
" \"m\": 4,\n",
" \"efConstruction\": 400,\n",
" \"efSearch\": 500,\n",
" \"metric\": \"cosine\"\n",
" }\n",
" )\n",
" ]\n",
")\n",
"\n",
"# Optional: configure semantic reranking by passing your title, keywords, and content fields\n",
"semantic_config = SemanticConfiguration(\n",
" name=\"my-semantic-config\",\n",
" prioritized_fields=PrioritizedFields(\n",
" title_field=SemanticField(field_name=\"title\"),\n",
" prioritized_keywords_fields=[SemanticField(field_name=\"url\")],\n",
" prioritized_content_fields=[SemanticField(field_name=\"text\")]\n",
" )\n",
")\n",
"# Create the semantic settings with the configuration\n",
"semantic_settings = SemanticSettings(configurations=[semantic_config])\n",
"\n",
"# Create the index \n",
"index = SearchIndex(name=index_name, fields=fields,\n",
" vector_search=vector_search, semantic_settings=semantic_settings)\n",
"result = index_client.create_or_update_index(index)\n",
"print(f'{result.name} created')\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Insert text and embeddings into vector store\n",
"In this notebook, the wikipedia articles dataset provided by OpenAI, the embeddings are pre-computed. The code below takes the data frame and converts it into a dictionary list to upload to your Azure Search index.\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Uploaded 25000 documents in total\n"
]
}
],
"source": [
"# Convert the 'id' and 'vector_id' columns to string so one of them can serve as our key field \n",
"article_df['id'] = article_df['id'].astype(str) \n",
"article_df['vector_id'] = article_df['vector_id'].astype(str) \n",
" \n",
"# Convert the DataFrame to a list of dictionaries \n",
"documents = article_df.to_dict(orient='records') \n",
" \n",
"# Use SearchIndexingBufferedSender to upload the documents in batches optimized for indexing \n",
"with SearchIndexingBufferedSender(search_service_endpoint, index_name, AzureKeyCredential(search_service_api_key)) as batch_client: \n",
" # Add upload actions for all documents \n",
" batch_client.upload_documents(documents=documents) \n",
" \n",
"print(f\"Uploaded {len(documents)} documents in total\") "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If your dataset didn't already contain pre-computed embeddings, you can create embeddings by using the below function using the `openai` python library. You'll also notice the same function and model are being used to generate query embeddings for performing vector searches."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Content: April is the fourth month of the year in the Julian and Gregorian calendars, and comes between March\n",
"Content vector generated\n"
]
}
],
"source": [
"# Example function to generate document embedding \n",
"def generate_document_embeddings(text): \n",
" response = openai.Embedding.create( \n",
" input=text, engine=model) \n",
" embeddings = response['data'][0]['embedding'] \n",
" return embeddings \n",
" \n",
"# Sampling the first document content as an example \n",
"first_document_content = documents[0]['text'] \n",
"print(f\"Content: {first_document_content[:100]}\") \n",
" \n",
"# Generate the content vector using the `generate_document_embeddings` function \n",
"content_vector = generate_document_embeddings(first_document_content) \n",
"print(f\"Content vector generated\") \n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform a vector similarity search"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Title: Documenta\n",
"Score: 0.8599451\n",
"URL: https://simple.wikipedia.org/wiki/Documenta\n",
"\n",
"Title: Museum of Modern Art\n",
"Score: 0.85260946\n",
"URL: https://simple.wikipedia.org/wiki/Museum%20of%20Modern%20Art\n",
"\n",
"Title: Expressionism\n",
"Score: 0.85235393\n",
"URL: https://simple.wikipedia.org/wiki/Expressionism\n",
"\n"
]
}
],
"source": [
"# Function to generate query embedding\n",
"def generate_embeddings(text):\n",
" response = openai.Embedding.create(\n",
" input=text, engine=model)\n",
" embeddings = response['data'][0]['embedding']\n",
" return embeddings\n",
"\n",
"# Pure Vector Search\n",
"query = \"modern art in Europe\"\n",
" \n",
"search_client = SearchClient(search_service_endpoint, index_name, AzureKeyCredential(search_service_api_key)) \n",
"vector = Vector(value=generate_embeddings(query), k=3, fields=\"content_vector\") \n",
" \n",
"results = search_client.search( \n",
" search_text=None, \n",
" vectors=[vector], \n",
" select=[\"title\", \"text\", \"url\"] \n",
")\n",
" \n",
"for result in results: \n",
" print(f\"Title: {result['title']}\") \n",
" print(f\"Score: {result['@search.score']}\") \n",
" print(f\"URL: {result['url']}\\n\") "
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform a Hybrid Search"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Title: Wars of Scottish Independence\n",
"Score: 0.03306011110544205\n",
"URL: https://simple.wikipedia.org/wiki/Wars%20of%20Scottish%20Independence\n",
"\n",
"Title: Battle of Bannockburn\n",
"Score: 0.022253260016441345\n",
"URL: https://simple.wikipedia.org/wiki/Battle%20of%20Bannockburn\n",
"\n",
"Title: Scottish\n",
"Score: 0.016393441706895828\n",
"URL: https://simple.wikipedia.org/wiki/Scottish\n",
"\n"
]
}
],
"source": [
"# Hybrid Search\n",
"query = \"Famous battles in Scottish history\" \n",
" \n",
"search_client = SearchClient(search_service_endpoint, index_name, AzureKeyCredential(search_service_api_key)) \n",
"vector = Vector(value=generate_embeddings(query), k=3, fields=\"content_vector\") \n",
" \n",
"results = search_client.search( \n",
" search_text=query, \n",
" vectors=[vector],\n",
" select=[\"title\", \"text\", \"url\"],\n",
" top=3\n",
") \n",
" \n",
"for result in results: \n",
" print(f\"Title: {result['title']}\") \n",
" print(f\"Score: {result['@search.score']}\") \n",
" print(f\"URL: {result['url']}\\n\") "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform a Hybrid Search with Reranking (powered by Bing)\n",
"[Semantic search](https://learn.microsoft.com/azure/search/semantic-ranking) allows you to leverage deep neural networks from Microsoft Bing to further increase your search accuracy. Additionally, you can get captions, answers, and highlights. "
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Semantic Answer: The<em> Battle of Bannockburn,</em> fought on 23 and 24 June 1314, was an important Scottish victory in the Wars of Scottish Independence. A smaller Scottish army defeated a much larger and better armed English army. Background When King Alexander III of Scotland died in 1286, his heir was his granddaughter Margaret, Maid of Norway.\n",
"Semantic Answer Score: 0.8857421875\n",
"\n",
"Title: Wars of Scottish Independence\n",
"URL: https://simple.wikipedia.org/wiki/Wars%20of%20Scottish%20Independence\n",
"Caption: Important Figures Scotland King David II King John Balliol King Robert I the Bruce William Wallace England King Edward I King Edward II King Edward III Battles Battle of Bannockburn The Battle of Bannockburn (2324 June 1314) was an important Scottish victory. It was the decisive battle in the First War of Scottish Independence.\n",
"\n",
"Title: Battle of Bannockburn\n",
"URL: https://simple.wikipedia.org/wiki/Battle%20of%20Bannockburn\n",
"Caption: The Battle of Bannockburn, fought on 23 and 24 June 1314, was an important<em> Scottish</em> victory in the Wars of<em> Scottish</em> Independence. A smaller Scottish army defeated a much larger and better armed English army. Background When King Alexander III of Scotland died in 1286, his heir was his granddaughter Margaret, Maid of Norway.\n",
"\n",
"Title: First War of Scottish Independence\n",
"URL: https://simple.wikipedia.org/wiki/First%20War%20of%20Scottish%20Independence\n",
"Caption: The First War of<em> Scottish Independence</em> lasted from the outbreak of the war in 1296 until the 1328. The Scots were defeated at Dunbar on 27 April 1296. John Balliol abdicated in Montrose castle on 10 July 1296.\n",
"\n"
]
}
],
"source": [
"# Semantic Hybrid Search\n",
"query = \"Famous battles in Scottish history\" \n",
"\n",
"search_client = SearchClient(search_service_endpoint, index_name, AzureKeyCredential(search_service_api_key)) \n",
"vector = Vector(value=generate_embeddings(query), k=3, fields=\"content_vector\") \n",
"\n",
"results = search_client.search( \n",
" search_text=query, \n",
" vectors=[vector], \n",
" select=[\"title\", \"text\", \"url\"],\n",
" query_type=\"semantic\", query_language=\"en-us\", semantic_configuration_name='my-semantic-config', query_caption=\"extractive\", query_answer=\"extractive\",\n",
" top=3\n",
")\n",
"\n",
"semantic_answers = results.get_answers()\n",
"for answer in semantic_answers:\n",
" if answer.highlights:\n",
" print(f\"Semantic Answer: {answer.highlights}\")\n",
" else:\n",
" print(f\"Semantic Answer: {answer.text}\")\n",
" print(f\"Semantic Answer Score: {answer.score}\\n\")\n",
"\n",
"for result in results:\n",
" print(f\"Title: {result['title']}\")\n",
" print(f\"URL: {result['url']}\")\n",
" captions = result[\"@search.captions\"]\n",
" if captions:\n",
" caption = captions[0]\n",
" if caption.highlights:\n",
" print(f\"Caption: {caption.highlights}\\n\")\n",
" else:\n",
" print(f\"Caption: {caption.text}\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -138,7 +138,7 @@
"metadata": {},
"outputs": [
{
"name": "stdin",
"name": "stdout",
"output_type": "stream",
"text": [
"Please enter your API Endpoint: https://4f835778-ec78-42b0-9ae3-29e3cf45b596-us-east1.apps.astra.datastax.com\n",
@ -222,7 +222,7 @@
"metadata": {},
"outputs": [
{
"name": "stdin",
"name": "stdout",
"output_type": "stream",
"text": [
"Please enter your OpenAI API Key: ········\n"
@ -251,7 +251,7 @@
"outputs": [],
"source": [
"client = openai.OpenAI(api_key=OPENAI_API_KEY)\n",
"embedding_model_name = \"text-embedding-ada-002\"\n",
"embedding_model_name = \"text-embedding-3-small\"\n",
"\n",
"result = client.embeddings.create(\n",
" input=[\n",

@ -165,7 +165,7 @@
"metadata": {},
"outputs": [
{
"name": "stdin",
"name": "stdout",
"output_type": "stream",
"text": [
"Please provide the full path to your Secure Connect Bundle zipfile: /path/to/secure-connect-DatabaseName.zip\n",
@ -406,7 +406,7 @@
"metadata": {},
"outputs": [
{
"name": "stdin",
"name": "stdout",
"output_type": "stream",
"text": [
"Please enter your OpenAI API Key: ········\n"
@ -435,7 +435,7 @@
"outputs": [],
"source": [
"client = openai.OpenAI(api_key=OPENAI_API_KEY)\n",
"embedding_model_name = \"text-embedding-ada-002\"\n",
"embedding_model_name = \"text-embedding-3-small\"\n",
"\n",
"result = client.embeddings.create(\n",
" input=[\n",

@ -145,7 +145,7 @@
"metadata": {},
"outputs": [
{
"name": "stdin",
"name": "stdout",
"output_type": "stream",
"text": [
"Please enter your Astra token ('AstraCS:...') ········\n",
@ -223,7 +223,7 @@
"metadata": {},
"outputs": [
{
"name": "stdin",
"name": "stdout",
"output_type": "stream",
"text": [
"Please enter your OpenAI API Key: ········\n"
@ -252,7 +252,7 @@
"outputs": [],
"source": [
"client = openai.OpenAI(api_key=OPENAI_API_KEY)\n",
"embedding_model_name = \"text-embedding-ada-002\"\n",
"embedding_model_name = \"text-embedding-3-small\"\n",
"\n",
"result = client.embeddings.create(\n",
" input=[\n",

@ -265,7 +265,7 @@
"import chromadb\n",
"\n",
"# I've set this to our new embeddings model, this can be changed to the embedding model of your choice\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"\n",
"# Ignore unclosed SSL socket warnings - optional in case you get these errors\n",
"import warnings\n",

@ -306,7 +306,7 @@
"## Encode a question with OpenAI embedding model\n",
"\n",
"To perform kNN search, we need to encode queries with the same embedding model used to encode the documents at index time.\n",
"In this example, we need to use the `text-embedding-ada-002` model.\n",
"In this example, we need to use the `text-embedding-3-small` model.\n",
"\n",
"You'll need your OpenAI [API key](https://platform.openai.com/account/api-keys) to generate the embeddings."
]
@ -325,7 +325,7 @@
"openai.api_key = OPENAI_API_KEY\n",
"\n",
"# Define model\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"\n",
"# Define question\n",
"question = 'Is the Atlantic the biggest ocean in the world?'\n",

@ -314,7 +314,7 @@
"## Encode a question with OpenAI embedding model\n",
"\n",
"To perform semantic search, we need to encode queries with the same embedding model used to encode the documents at index time.\n",
"In this example, we need to use the `text-embedding-ada-002` model.\n",
"In this example, we need to use the `text-embedding-3-small` model.\n",
"\n",
"You'll need your OpenAI [API key](https://platform.openai.com/account/api-keys) to generate the embeddings."
]
@ -333,7 +333,7 @@
"openai.api_key = OPENAI_API_KEY\n",
"\n",
"# Define model\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"\n",
"# Define question\n",
"question = 'Is the Atlantic the biggest ocean in the world?'\n",

@ -707,7 +707,7 @@
"source": [
"## Search data\n",
"\n",
"Once the data is uploaded we will start querying the collection for the closest vectors. We may provide an additional parameter `vector_name` to switch from title to content based search. Since the precomputed embeddings were created with `text-embedding-ada-002` OpenAI model we also have to use it during search.\n"
"Once the data is uploaded we will start querying the collection for the closest vectors. We may provide an additional parameter `vector_name` to switch from title to content based search. Since the precomputed embeddings were created with `text-embedding-3-small` OpenAI model we also have to use it during search.\n"
]
},
{
@ -727,7 +727,7 @@
" # Creates embedding vector from user query\n",
" embedded_query = openai.Embedding.create(\n",
" input=query,\n",
" model=\"text-embedding-ada-002\",\n",
" model=\"text-embedding-3-small\",\n",
" )[\"data\"][0][\"embedding\"]\n",
"\n",
" # Convert the embedded_query to PostgreSQL compatible format\n",

File diff suppressed because one or more lines are too long

@ -175,7 +175,7 @@
"PORT = 19530\n",
"COLLECTION_NAME = 'movie_search'\n",
"DIMENSION = 1536\n",
"OPENAI_ENGINE = 'text-embedding-ada-002'\n",
"OPENAI_ENGINE = 'text-embedding-3-small'\n",
"openai.api_key = 'sk-your_key'\n",
"\n",
"INDEX_PARAM = {\n",

@ -275,7 +275,7 @@
"PORT = 19530\n",
"COLLECTION_NAME = 'book_search'\n",
"DIMENSION = 1536\n",
"OPENAI_ENGINE = 'text-embedding-ada-002'\n",
"OPENAI_ENGINE = 'text-embedding-3-small'\n",
"openai.api_key = 'sk-your_key'\n",
"\n",
"INDEX_PARAM = {\n",

@ -1,43 +1,29 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "flQYAhphJC5m"
},
"source": [
"\n",
"This notebook demonstrates how to build a semantic search application using OpenAI and [MongoDB Atlas vector search](https://www.mongodb.com/products/platform/atlas-vector-search)"
],
"metadata": {
"id": "flQYAhphJC5m"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iYMn0dXXdFbY",
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iYMn0dXXdFbY",
"outputId": "98dab421-f11b-40b8-8f82-6de42b25725a"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting pymongo\n",
" Downloading pymongo-4.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (677 kB)\n",
@ -80,6 +66,9 @@
},
{
"cell_type": "markdown",
"metadata": {
"id": "vLuKFvTwJXpu"
},
"source": [
"# Step 1: Setup the environment\n",
"\n",
@ -88,27 +77,18 @@
"1. **MongoDB Atlas cluster**: To create a forever free MongoDB Atlas cluster, first, you need to create a MongoDB Atlas account if you don't already have one. Visit the [MongoDB Atlas website](https://www.mongodb.com/atlas/database) and click on “Register.” Visit the [MongoDB Atlas](https://account.mongodb.com/account/login) dashboard and set up your cluster. In order to take advantage of the `$vectorSearch` operator in an aggregation pipeline, you need to run MongoDB Atlas 6.0.11 or higher. This tutorial can be built using a free cluster. When youre setting up your deployment, youll be prompted to set up a database user and rules for your network connection. Please ensure you save your username and password somewhere safe and have the correct IP address rules in place so your cluster can connect properly. If you need more help getting started, check out our [tutorial on MongoDB Atlas](https://www.mongodb.com/basics/mongodb-atlas-tutorial).\n",
"\n",
"2. **OpenAI API key** To create your OpenAI key, you'll need to create an account. Once you have that, visit the [OpenAI platform](https://platform.openai.com/). Click on your profile icon in the top right of the screen to get the dropdown menu and select “View API keys”.\n"
],
"metadata": {
"id": "vLuKFvTwJXpu"
}
]
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"\n",
"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n",
"OPENAI_API_KEY = getpass.getpass(\"OpenAI API Key:\")\n"
],
"execution_count": null,
"metadata": {
"id": "qJHHIIKjIFUZ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qJHHIIKjIFUZ",
"outputId": "57ad72d4-8afb-4e34-aad1-1fea6eb3645b"
},
"execution_count": null,
"outputs": [
{
"name": "stdout",
@ -118,29 +98,40 @@
"OpenAI API Key:··········\n"
]
}
],
"source": [
"import getpass\n",
"\n",
"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n",
"OPENAI_API_KEY = getpass.getpass(\"OpenAI API Key:\")\n"
]
},
{
"cell_type": "markdown",
"source": [
"Note: After executing the step above you will be prompted to enter the credentials."
],
"metadata": {
"id": "Sarx9wdxb4Rr"
}
},
"source": [
"Note: After executing the step above you will be prompted to enter the credentials."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sk1xXoyxMfil"
},
"source": [
"For this tutorial, we will be using the\n",
"[MongoDB sample dataset](https://www.mongodb.com/docs/atlas/sample-data/). Load the sample dataset using the Atlas UI. We'll be using the “sample_mflix” database, which contains a “movies” collection where each document contains fields like title, plot, genres, cast, directors, etc.\n"
],
"metadata": {
"id": "sk1xXoyxMfil"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k-G6WhNFdIvW"
},
"outputs": [],
"source": [
"import openai\n",
"import pymongo\n",
@ -150,72 +141,56 @@
"collection = db.movies\n",
"\n",
"openai.api_key = OPENAI_API_KEY"
],
"metadata": {
"id": "k-G6WhNFdIvW"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"source": [
"ATLAS_VECTOR_SEARCH_INDEX_NAME = \"default\"\n",
"EMBEDDING_FIELD_NAME = \"embedding_openai_nov19_23\""
],
"execution_count": null,
"metadata": {
"id": "On9e13ASwReq"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"ATLAS_VECTOR_SEARCH_INDEX_NAME = \"default\"\n",
"EMBEDDING_FIELD_NAME = \"embedding_openai_nov19_23\""
]
},
{
"cell_type": "markdown",
"source": [
"# Step 2: Setup embeddings generation function"
],
"metadata": {
"id": "X-9gl2s-uGtw"
}
},
"source": [
"# Step 2: Setup embeddings generation function"
]
},
{
"cell_type": "code",
"source": [
"model = \"text-embedding-ada-002\"\n",
"def generate_embedding(text: str) -> list[float]:\n",
" return openai.embeddings.create(input = [text], model=model).data[0].embedding\n"
],
"execution_count": null,
"metadata": {
"id": "BMnE4BxSOCtH"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"model = \"text-embedding-3-small\"\n",
"def generate_embedding(text: str) -> list[float]:\n",
" return openai.embeddings.create(input = [text], model=model).data[0].embedding\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "snSjiSNKwX6Z"
},
"source": [
"# Step 3: Create and store embeddings\n",
"\n",
"Each document in the sample dataset sample_mflix.movies corresponds to a movie; we will execute an operation to create a vector embedding for the data in the \"plot\" field and store it in the database. Creating vector embeddings using OpenAI embeddings endpoint is necessary for performing a similarity search based on intent."
],
"metadata": {
"id": "snSjiSNKwX6Z"
}
]
},
{
"cell_type": "code",
"source": [
"from pymongo import ReplaceOne\n",
"\n",
"# Update the collection with the embeddings\n",
"requests = []\n",
"\n",
"for doc in collection.find({'plot':{\"$exists\": True}}).limit(500):\n",
" doc[EMBEDDING_FIELD_NAME] = generate_embedding(doc['plot'])\n",
" requests.append(ReplaceOne({'_id': doc['_id']}, doc))\n",
"\n",
"collection.bulk_write(requests)"
],
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -223,42 +198,56 @@
"id": "t4i9gQM2xUFF",
"outputId": "ae558b67-9b06-4c83-c52a-a8047ecd40d5"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"BulkWriteResult({'writeErrors': [], 'writeConcernErrors': [], 'nInserted': 0, 'nUpserted': 0, 'nMatched': 50, 'nModified': 50, 'nRemoved': 0, 'upserted': []}, acknowledged=True)"
]
},
"execution_count": 15,
"metadata": {},
"execution_count": 15
"output_type": "execute_result"
}
],
"source": [
"from pymongo import ReplaceOne\n",
"\n",
"# Update the collection with the embeddings\n",
"requests = []\n",
"\n",
"for doc in collection.find({'plot':{\"$exists\": True}}).limit(500):\n",
" doc[EMBEDDING_FIELD_NAME] = generate_embedding(doc['plot'])\n",
" requests.append(ReplaceOne({'_id': doc['_id']}, doc))\n",
"\n",
"collection.bulk_write(requests)"
]
},
{
"cell_type": "markdown",
"source": [
"After executing the above, the documents in \"movies\" collection will contain an additional field of \"embedding\", as defined by the `EMBEDDDING_FIELD_NAME` variable, apart from already existing fields like title, plot, genres, cast, directors, etc."
],
"metadata": {
"id": "ShPbxQPaPvHD"
}
},
"source": [
"After executing the above, the documents in \"movies\" collection will contain an additional field of \"embedding\", as defined by the `EMBEDDDING_FIELD_NAME` variable, apart from already existing fields like title, plot, genres, cast, directors, etc."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Coq0tyjXyNIu"
},
"source": [
"Note: We are restricting this to just 500 documents in the interest of time. If you want to do this over the entire dataset of 23,000+ documents in our sample_mflix database, it will take a little while. Alternatively, you can use the [sample_mflix.embedded_movies collection](https://www.mongodb.com/docs/atlas/sample-data/sample-mflix/#sample_mflix.embedded_movies) which includes a pre-populated `plot_embedding` field that contains embeddings created using `OpenAI's text-embedding-ada-002` embedding model that you can use with the Atlas Search vector search feature.\n",
"Note: We are restricting this to just 500 documents in the interest of time. If you want to do this over the entire dataset of 23,000+ documents in our sample_mflix database, it will take a little while. Alternatively, you can use the [sample_mflix.embedded_movies collection](https://www.mongodb.com/docs/atlas/sample-data/sample-mflix/#sample_mflix.embedded_movies) which includes a pre-populated `plot_embedding` field that contains embeddings created using OpenAI's `text-embedding-3-small` embedding model that you can use with the Atlas Search vector search feature.\n",
"\n",
"\n"
]
},
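{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you do embed a larger slice of the dataset, rate limits become likely. One mitigation (a sketch, not part of the original notebook) is to wrap `generate_embedding` in a retry loop, mirroring the try/except pattern used elsewhere in this cookbook; the retry count and sleep interval below are arbitrary choices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def generate_embedding_with_retry(text, retries=5, wait_seconds=5):\n",
"    # Retry on transient failures such as rate-limit errors\n",
"    for _ in range(retries):\n",
"        try:\n",
"            return generate_embedding(text)\n",
"        except Exception:\n",
"            time.sleep(wait_seconds)\n",
"    raise RuntimeError(\"Embedding request failed after retries\")"
]
},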
{
"cell_type": "markdown",
"metadata": {
"id": "rCRCK6QOskqo"
},
"source": [
"# Step 4: Create a vector search index\n",
"\n",
@ -266,13 +255,13 @@
"We will cover 2 ways to create this index - Atlas UI and using MongoDB python driver.\n",
"\n",
"(Optional) [Documentation: Create a Vector Search Index ](https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ymRTaFb1X5Tq"
},
"source": [
"Now head over to [Atlas UI](cloud.mongodb.com) and create an Atlas Vector Search index using the steps descibed [here](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-tutorial/#create-the-atlas-vector-search-index). The 'dimensions' field with value 1536, corresponds to openAI text-embedding-ada002.\n",
"\n",
@ -292,36 +281,21 @@
" }\n",
"}\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2l5BzUgncjiq"
},
"source": [
"(Optional) Alternatively, we can use [pymongo driver to create these vector search indexes programatically](https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.create_search_index)\n",
"The python command given in the cell below will create the index (this only works for the most recent version of the Python Driver for MongoDB and MongoDB server version 7.0+ Atlas cluster)."
]
},
{
"cell_type": "code",
"source": [
"collection.create_search_index(\n",
" {\"definition\":\n",
" {\"mappings\": {\"dynamic\": True, \"fields\": {\n",
" EMBEDDING_FIELD_NAME : {\n",
" \"dimensions\": 1536,\n",
" \"similarity\": \"dotProduct\",\n",
" \"type\": \"knnVector\"\n",
" }}}},\n",
" \"name\": ATLAS_VECTOR_SEARCH_INDEX_NAME\n",
" }\n",
")"
],
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@ -330,38 +304,55 @@
"id": "54OWgiaPcmD0",
"outputId": "2cb9d1d8-4515-49ad-9fe7-5b4fa3c6c86b"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'default'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'default'"
]
},
"execution_count": 16,
"metadata": {},
"execution_count": 16
"output_type": "execute_result"
}
],
"source": [
"collection.create_search_index(\n",
" {\"definition\":\n",
" {\"mappings\": {\"dynamic\": True, \"fields\": {\n",
" EMBEDDING_FIELD_NAME : {\n",
" \"dimensions\": 1536,\n",
" \"similarity\": \"dotProduct\",\n",
" \"type\": \"knnVector\"\n",
" }}}},\n",
" \"name\": ATLAS_VECTOR_SEARCH_INDEX_NAME\n",
" }\n",
")"
]
},
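{
"cell_type": "markdown",
"metadata": {},
"source": [
"Search index builds are asynchronous, so the index may not be queryable immediately after `create_search_index` returns. With a recent pymongo (4.5+) against an Atlas cluster you can poll its status; note that the `queryable` field checked below is an assumption about the status documents Atlas returns, so verify it against your driver and cluster versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Poll until Atlas reports the vector search index as ready to serve queries\n",
"while True:\n",
"    indexes = list(collection.list_search_indexes(ATLAS_VECTOR_SEARCH_INDEX_NAME))\n",
"    if indexes and indexes[0].get(\"queryable\"):\n",
"        break\n",
"    time.sleep(5)"
]
},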
{
"cell_type": "markdown",
"metadata": {
"id": "6V9QKgm8caNb"
},
"source": [
"# Step 5: Query your data\n",
"\n",
"The results for the query here finds movies which have semantically similar plots to the text captured in the query string, rather than being based on the keyword search.\n",
"\n",
"(Optional) [Documentation: Run Vector Search Queries](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "34tib9TrMPg4"
},
"outputs": [],
"source": [
"\n",
"def query_results(query, k):\n",
@ -377,28 +368,37 @@
" }\n",
" ])\n",
" return results"
]
},
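{
"cell_type": "markdown",
"metadata": {},
"source": [
"Part of the `query_results` pipeline is elided in the diff above. One plausible implementation uses the `$search` aggregation stage with the `knnBeta` operator, which pairs with the `knnVector` index created in Step 4; treat this as a sketch rather than the notebook's exact pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def query_results(query, k):\n",
"    # Embed the query text and run an approximate nearest-neighbour search over the plot embeddings\n",
"    results = collection.aggregate([\n",
"        {\n",
"            \"$search\": {\n",
"                \"index\": ATLAS_VECTOR_SEARCH_INDEX_NAME,\n",
"                \"knnBeta\": {\n",
"                    \"vector\": generate_embedding(query),\n",
"                    \"path\": EMBEDDING_FIELD_NAME,\n",
"                    \"k\": k\n",
"                }\n",
"            }\n",
"        }\n",
"    ])\n",
"    return results"
]
},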
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "kTqrip-hWULK"
},
"outputs": [],
"source": [
"query=\"imaginary characters from outerspace at war with earthlings\"\n",
"movies = query_results(query, 5)\n",
"\n",
"for movie in movies:\n",
" print(f'Movie Name: {movie[\"title\"]},\\nMovie Plot: {movie[\"plot\"]}\\n')"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

@ -73,429 +73,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-16T12:05:05.730338Z",
"start_time": "2023-02-16T12:05:05.723351Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<OpenAIObject list at 0x118768f40> JSON: {\n",
" \"data\": [\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"babbage\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"davinci\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-davinci-edit-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"babbage-code-search-code\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-similarity-babbage-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"code-davinci-edit-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-davinci-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-davinci-003\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-internal\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"ada\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"babbage-code-search-text\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"babbage-similarity\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"code-search-babbage-text-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-curie-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"whisper-1\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-internal\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"code-search-babbage-code-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-ada-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-embedding-ada-002\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-internal\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-similarity-ada-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"curie-instruct-beta\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"ada-code-search-code\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"ada-similarity\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"code-search-ada-text-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-search-ada-query-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"davinci-search-document\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"gpt-3.5-turbo-0301\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"ada-code-search-text\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-search-ada-doc-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"davinci-instruct-beta\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-similarity-curie-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"code-search-ada-code-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"ada-search-query\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-search-davinci-query-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"curie-search-query\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"davinci-search-query\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"babbage-search-document\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"ada-search-document\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-search-curie-query-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-search-babbage-doc-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"curie-search-document\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-search-curie-doc-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"babbage-search-query\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-babbage-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-search-davinci-doc-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"gpt-3.5-turbo\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-search-babbage-query-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"curie-similarity\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"curie\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-similarity-davinci-001\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"text-davinci-002\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" },\n",
" {\n",
" \"created\": null,\n",
" \"id\": \"davinci-similarity\",\n",
" \"object\": \"engine\",\n",
" \"owner\": \"openai-dev\",\n",
" \"permissions\": null,\n",
" \"ready\": true\n",
" }\n",
" ],\n",
" \"object\": \"list\"\n",
"}"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"import openai\n",
"\n",
@ -744,7 +329,7 @@
"# creates embedding vector from user query\n",
"embed = openai.Embedding.create(\n",
" input=query,\n",
" model=\"text-embedding-ada-002\",\n",
" model=\"text-embedding-3-small\",\n",
")[\"data\"][0][\"embedding\"]\n",
"\n",
"# query the database to find the top K similar content to the given query\n",

@ -74,7 +74,7 @@
"import clickhouse_connect\n",
"\n",
"# I've set this to our new embeddings model, this can be changed to the embedding model of your choice\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"\n",
"# Ignore unclosed SSL socket warnings - optional in case you get these errors\n",
"import warnings\n",
@ -476,7 +476,7 @@
"# creates embedding vector from user query\n",
"embed = openai.Embedding.create(\n",
" input=query,\n",
" model=\"text-embedding-ada-002\",\n",
" model=\"text-embedding-3-small\",\n",
")[\"data\"][0][\"embedding\"]\n",
"\n",
"# query the database to find the top K similar content to the given query\n",

@ -391,7 +391,7 @@
"\n",
"After the data is stored in your Neon database, you can query the data for nearest neighbors. \n",
"\n",
"Start by defining the `query_neon` function, which is executed when you run the vector similarity search. The function creates an embedding based on the user's query, prepares the SQL query, and runs the SQL query with the embedding. The pre-computed embeddings that you loaded into your database were created with `text-embedding-ada-002` OpenAI model, so you must use the same model to create an embedding for the similarity search.\n",
"Start by defining the `query_neon` function, which is executed when you run the vector similarity search. The function creates an embedding based on the user's query, prepares the SQL query, and runs the SQL query with the embedding. The pre-computed embeddings that you loaded into your database were created with `text-embedding-3-small` OpenAI model, so you must use the same model to create an embedding for the similarity search.\n",
"\n",
"A `vector_name` parameter is provided that allows you to search based on \"title\" or \"content\"."
]
@ -407,7 +407,7 @@
" # Create an embedding vector from the user query\n",
" embedded_query = openai.Embedding.create(\n",
" input=query,\n",
" model=\"text-embedding-ada-002\",\n",
" model=\"text-embedding-3-small\",\n",
" )[\"data\"][0][\"embedding\"]\n",
"\n",
" # Convert the embedded_query to PostgreSQL compatible format\n",

@ -387,7 +387,7 @@
"source": [
"## Initialize Embedding Model\n",
"\n",
"We use `text-embedding-ada-002` as the embedding model. We can embed text like so:"
"We use `text-embedding-3-small` as the embedding model. We can embed text like so:"
]
},
{
@ -401,9 +401,9 @@
"import openai\n",
"\n",
"# initialize openai API key\n",
"openai.api_key = \"sk-...\" #platform.openai.com\n",
"openai.api_key = \"sk-...\"\n",
"\n",
"embed_model = \"text-embedding-ada-002\"\n",
"embed_model = \"text-embedding-3-small\"\n",
"\n",
"res = openai.Embedding.create(\n",
" input=[\n",
@ -456,7 +456,7 @@
"id": "zStnHFpkNVIU"
},
"source": [
"Inside `'data'` we will find two records, one for each of the two sentences we just embedded. Each vector embedding contains `1536` dimensions (the output dimensionality of the `text-embedding-ada-002` model."
"Inside `'data'` we will find two records, one for each of the two sentences we just embedded. Each vector embedding contains `1536` dimensions (the output dimensionality of the `text-embedding-3-small` model."
]
},
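{
"cell_type": "markdown",
"metadata": {},
"source": [
"(A quick sanity check, assuming `res` holds the embedding response from the call above:)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Each embedding should have 1536 dimensions, the default output size of text-embedding-3-small\n",
"print([len(record['embedding']) for record in res['data']])"
]
},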
{
@ -598,7 +598,7 @@
"id": "ezSTzN2rPa2o"
},
"source": [
"We can see the index is currently empty with a `total_vector_count` of `0`. We can begin populating it with OpenAI `text-embedding-ada-002` built embeddings like so:"
"We can see the index is currently empty with a `total_vector_count` of `0`. We can begin populating it with OpenAI `text-embedding-3-small` built embeddings like so:"
]
},
{

@ -104,9 +104,9 @@
"source": [
"query = \"who was the 12th person on the moon and when did they land?\"\n",
"\n",
"# now query text-davinci-003 WITHOUT context\n",
"# now query `gpt-3.5-turbo-instruct` WITHOUT context\n",
"res = openai.Completion.create(\n",
" engine='text-davinci-003',\n",
" engine='gpt-3.5-turbo-instruct',\n",
" prompt=query,\n",
" temperature=0,\n",
" max_tokens=400,\n",
@ -138,9 +138,8 @@
"outputs": [],
"source": [
"def complete(prompt):\n",
" # query text-davinci-003\n",
" res = openai.Completion.create(\n",
" engine='text-davinci-003',\n",
" engine='gpt-3.5-turbo-instruct',\n",
" prompt=prompt,\n",
" temperature=0,\n",
" max_tokens=400,\n",
@ -240,7 +239,7 @@
"\n",
"We will need to retrieve information that is semantically related to our queries, to do this we need to use _\"dense vector embeddings\"_. These can be thought of as numerical representations of the *meaning* behind our sentences.\n",
"\n",
"To create these dense vectors we use the `text-embedding-ada-002` model.\n",
"To create these dense vectors we use the `text-embedding-3-small` model.\n",
"\n",
"We have already authenticated our OpenAI connection, to create an embedding we just do:"
]
@ -304,7 +303,7 @@
"id": "MwSk-wiK62KO"
},
"source": [
"Inside `'data'` we will find two records, one for each of the two sentences we just embedded. Each vector embedding contains `1536` dimensions (the output dimensionality of the `text-embedding-ada-002` model."
"Inside `'data'` we will find two records, one for each of the two sentences we just embedded. Each vector embedding contains `1536` dimensions (the output dimensionality of the `text-embedding-3-small` model."
]
},
{
@ -615,7 +614,7 @@
"id": "nELBmqxxzeqL"
},
"source": [
"We can see the index is currently empty with a `total_vector_count` of `0`. We can begin populating it with OpenAI `text-embedding-ada-002` built embeddings like so:"
"We can see the index is currently empty with a `total_vector_count` of `0`. We can begin populating it with OpenAI `text-embedding-3-small` built embeddings like so:"
]
},
{
@ -673,13 +672,13 @@
" # get texts to encode\n",
" texts = [x['text'] for x in meta_batch]\n",
" # create embeddings (try-except added to avoid RateLimitError)\n",
" done = False\n",
" while not done:\n",
" try:\n",
" res = openai.Embedding.create(input=texts, engine=embed_model)\n",
" done = True\n",
" except:\n",
" sleep(5)\n",
" done = False\n",
" while not done:\n",
" try:\n",
" res = openai.Embedding.create(input=texts, engine=embed_model)\n",
" done = True\n",
" except:\n",
" sleep(5)\n",
" embeds = [record['embedding'] for record in res['data']]\n",
" # cleanup metadata\n",
" meta_batch = [{\n",

File diff suppressed because one or more lines are too long

@ -105,7 +105,7 @@
"import pinecone\n",
"\n",
"# I've set this to our new embeddings model, this can be changed to the embedding model of your choice\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"\n",
"# Ignore unclosed SSL socket warnings - optional in case you get these errors\n",
"import warnings\n",

@ -60,10 +60,10 @@
"Successfully installed certifi-2023.5.7 grpcio-1.56.0 grpcio-tools-1.56.0 h11-0.14.0 h2-4.1.0 hpack-4.0.0 httpcore-0.17.2 httpx-0.24.1 hyperframe-6.0.1 numpy-1.25.0 portalocker-2.7.0 protobuf-4.23.3 pydantic-1.10.9 qdrant-client-1.3.1 typing-extensions-4.5.0 urllib3-1.26.16\r\n",
"Collecting wget\r\n",
" Using cached wget-3.2.zip (10 kB)\r\n",
" Preparing metadata (setup.py) ... \u001B[?25ldone\r\n",
"\u001B[?25hBuilding wheels for collected packages: wget\r\n",
" Building wheel for wget (setup.py) ... \u001B[?25ldone\r\n",
"\u001B[?25h Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9657 sha256=eb5f15f12150fc304e7b14973424f696fa8d95225772bc0cbc0b318bf92e04b9\r\n",
" Preparing metadata (setup.py) ... \u001b[?25ldone\r\n",
"\u001b[?25hBuilding wheels for collected packages: wget\r\n",
" Building wheel for wget (setup.py) ... \u001b[?25ldone\r\n",
"\u001b[?25h Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9657 sha256=eb5f15f12150fc304e7b14973424f696fa8d95225772bc0cbc0b318bf92e04b9\r\n",
" Stored in directory: /home/user/.cache/pip/wheels/04/5f/3e/46cc37c5d698415694d83f607f833f83f0149e49b3af9d0f38\r\n",
"Successfully built wget\r\n",
"Installing collected packages: wget\r\n",
@ -104,7 +104,7 @@
"import qdrant_client\n",
"\n",
"# I've set this to our new embeddings model, this can be changed to the embedding model of your choice\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"\n",
"# Ignore unclosed SSL socket warnings - optional in case you get these errors\n",
"import warnings\n",
@ -136,7 +136,9 @@
"outputs": [
{
"data": {
"text/plain": "'vector_database_wikipedia_articles_embedded.zip'"
"text/plain": [
"'vector_database_wikipedia_articles_embedded.zip'"
]
},
"execution_count": 5,
"metadata": {},
@ -700,9 +702,9 @@
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"display_name": "Python 3 (ipykernel)"
"name": "python3"
},
"language_info": {
"codemirror_mode": {

@ -74,7 +74,7 @@
"import redis\n",
"\n",
"# I've set this to our new embeddings model, this can be changed to the embedding model of your choice\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"\n",
"# Ignore unclosed SSL socket warnings - optional in case you get these errors\n",
"import warnings\n",

@ -518,7 +518,7 @@
"\n",
" # Creates embedding vector from user query\n",
" embedded_query = openai.Embedding.create(input=user_query,\n",
" model=\"text-embedding-ada-002\",\n",
" model=\"text-embedding-3-small\",\n",
" )[\"data\"][0]['embedding']\n",
"\n",
" # Prepare the Query\n",

@ -298,7 +298,6 @@
}
],
"source": [
"# imports\n",
"import pandas as pd\n",
"import numpy as np\n",
"from typing import List\n",
@ -311,8 +310,7 @@
" indices_of_nearest_neighbors_from_distances,\n",
")\n",
"\n",
"# constants\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"EMBEDDING_MODEL = \"text-embedding-3-small\"\n",
"\n",
"# load in data and clean data types and drop null rows\n",
"df = pd.read_csv(\"../../data/styles_2k.csv\", on_bad_lines='skip')\n",
@ -633,7 +631,7 @@
"\n",
" # Use OpenAI to create embedding vector from user query\n",
" embedded_query = openai.Embedding.create(input=user_query,\n",
" model=\"text-embedding-ada-002\",\n",
" model=\"text-embedding-3-small\",\n",
" )[\"data\"][0]['embedding']\n",
"\n",
" # Prepare the Query\n",

@ -75,7 +75,7 @@
"load_dotenv()\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"\n",
"def get_vector(text, model=\"text-embedding-ada-002\"):\n",
"def get_vector(text, model=\"text-embedding-3-small\"):\n",
" text = text.replace(\"\\n\", \" \")\n",
" return openai.Embedding.create(input = [text], model = model)['data'][0]['embedding']\n",
"\n",

@ -251,7 +251,7 @@
"import openai\n",
"\n",
"directory = './assets/'\n",
"model='text-embedding-ada-002'\n",
"model='text-embedding-3-small'\n",
"i = 1\n",
"for file in os.listdir(directory):\n",
" with open(os.path.join(directory, file)) as f:\n",

@ -41,7 +41,7 @@ create table documents (
Since Supabase is built on Postgres, we're just using regular SQL here. You can modify this table however you like to better fit your application. If you have existing database tables, you can simply add a new `vector` column to the appropriate table.
The important piece to understand is the `vector` data type, which is a new data type that became available when we enabled the pgvector extension earlier. The size of the vector (1536 here) represents the number of dimensions in the embedding. Since we're using OpenAI's `text-embedding-3-small` model in this example, we set the vector size to 1536.
Let's go ahead and create a vector index on this table so that future queries remain performant as the table grows:
@ -74,27 +74,27 @@ npm install openai
then import it:
```js
import OpenAI from "openai";
```
If you're using Deno or Supabase Edge Functions, you can import `openai` directly from a URL:
```js
import OpenAI from "https://esm.sh/openai@4";
```
> In this example we import from https://esm.sh which is a CDN that automatically fetches the respective NPM module for you and serves it over HTTP.
Next we'll generate an OpenAI embedding using [`text-embedding-3-small`](https://platform.openai.com/docs/guides/embeddings/embedding-models):
```js
const openai = new OpenAI();
const input = "The cat chases the mouse";
const result = await openai.embeddings.create({
input,
model: "text-embedding-3-small",
});
const [{ embedding }] = result.data;
@ -104,7 +104,7 @@ Remember that you will need an [OpenAI API key](https://platform.openai.com/api-
```js
const openai = new OpenAI({
apiKey: "<openai-api-key>",
});
```
@ -125,13 +125,13 @@ npm install @supabase/supabase-js
then import it:
```js
import { createClient } from "@supabase/supabase-js";
```
If you're using Deno or Supabase Edge Functions, you can import `@supabase/supabase-js` directly from a URL:
```js
import { createClient } from "https://esm.sh/@supabase/supabase-js@2";
```
Next we'll instantiate our Supabase client and configure it so that it points to your Supabase project. In this guide we'll store a reference to your Supabase URL and key in a `.env` file, but feel free to modify this based on how your application handles configuration.
@ -158,32 +158,32 @@ npm install dotenv
And retrieve the environment variables from `process.env`:
```js
import { config } from "dotenv";
// Load .env file
config();
const supabaseUrl = process.env["SUPABASE_URL"];
const supabaseServiceRoleKey = process.env["SUPABASE_SERVICE_ROLE_KEY"];
```
In Deno, load the `.env` file using the `dotenv` standard library:
```js
import { load } from "https://deno.land/std@0.208.0/dotenv/mod.ts";
// Load .env file
const env = await load();
const supabaseUrl = env["SUPABASE_URL"];
const supabaseServiceRoleKey = env["SUPABASE_SERVICE_ROLE_KEY"];
```
In Supabase Edge Functions, simply load the injected environment variables directly:
```js
const supabaseUrl = Deno.env.get("SUPABASE_URL");
const supabaseServiceRoleKey = Deno.env.get("SUPABASE_SERVICE_ROLE_KEY");
```
Next let's instantiate our `supabase` client:
@ -197,7 +197,7 @@ const supabase = createClient(supabaseUrl, supabaseServiceRoleKey, {
From here we use the `supabase` client to insert our text and embedding (generated earlier) into the database:
```js
const { error } = await supabase.from("documents").insert({
content: input,
embedding,
});
@ -241,23 +241,23 @@ Inside the function we implement the query which does two things:
Now we can call this function from our application using the `supabase.rpc()` method:
```js
const query = "What does the cat chase?";
// First create an embedding on the query itself
const result = await openai.embeddings.create({
input: query,
model: "text-embedding-3-small",
});
const [{ embedding }] = result.data;
// Then use this embedding to search for matches
const { data: documents, error: matchError } = await supabase
.rpc("match_documents", {
query_embedding: embedding,
match_threshold: 0.8,
})
.select("content")
.limit(5);
```
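
If you prefer Python, a rough equivalent of the query above using the `openai` and `supabase` Python packages might look like the sketch below; `create_client` and `.rpc(...).execute()` are assumptions about a recent supabase-py, and the RPC arguments simply mirror the JavaScript call.

```python
import os

import openai
from supabase import create_client

# Reuse the same environment variables as the JavaScript examples above
supabase = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_ROLE_KEY"])

query = "What does the cat chase?"

# Embed the query with the same model used for the stored documents
embedding = openai.Embedding.create(
    input=query,
    model="text-embedding-3-small",
)["data"][0]["embedding"]

# Call the match_documents Postgres function via RPC
response = supabase.rpc(
    "match_documents",
    {"query_embedding": embedding, "match_threshold": 0.8},
).execute()

documents = response.data
```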

Some files were not shown because too many files have changed in this diff.
