openai-cookbook/examples/third_party_examples/Visualizing_embeddings_in_Kangas.ipynb
Douglas Blank d6acc8894f Added notebook example for visualizing embeddings in Kangas (#469)
* Added notebook example for visualizing embeddings in Kangas

Plots UMAP projection space, one per row, in open source Kangas DataGrid.

For more information about Kangas, see: https://github.com/comet-ml/kangas

* Moved notebook to third_party_examples
2023-07-11 17:11:16 -07:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "0wjP9mrldJsd"
},
"source": [
"## Visualizing the embeddings in Kangas\n",
"\n",
"In this Jupyter notebook, we construct a Kangas DataGrid containing the review data and projections of the embeddings into 2 dimensions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4tPKQqqldJsj"
},
"source": [
"## What is Kangas?\n",
"\n",
"[Kangas](https://github.com/comet-ml/kangas/) is an open-source, mixed-media, dataframe-like tool for data scientists. It was developed by [Comet](https://comet.com/), a company designed to help reduce the friction of moving models into production."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6sNsB2iFdJsk"
},
"source": [
"### 1. Setup\n",
"\n",
"To get started, we install Kangas with pip and import it."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "N8gi529adL-f",
"outputId": "c12e9973-a179-41e3-c5a8-f241804d99ad"
},
"outputs": [],
"source": [
"%pip install kangas --quiet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "htxjXThodRxD"
},
"outputs": [],
"source": [
"import kangas as kg"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Constructing a Kangas DataGrid\n",
"\n",
"We create a Kangas DataGrid with the original data and the embeddings. The data is composed of rows of reviews, and each embedding is composed of 1536 floating-point values. In this example, we get the data directly from GitHub, in case you aren't running this notebook inside OpenAI's repo.\n",
"\n",
"We use Kangas to read the CSV file into a DataGrid for further processing."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0SxWlRTrdVJq",
"outputId": "d36c3a14-2e80-4315-e285-f39f6b008976"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading CSV file 'fine_food_reviews_with_embeddings_1k.csv'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1001it [00:00, 2412.90it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2899.16it/s]\n"
]
}
],
"source": [
"data = kg.read_csv(\"https://raw.githubusercontent.com/openai/openai-cookbook/main/examples/data/fine_food_reviews_with_embeddings_1k.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can review the fields of the CSV file:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bzhQgoRGeMCp",
"outputId": "791c4e40-fb28-409e-d1e9-20b753fb1215"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataGrid (in memory)\n",
" Name : fine_food_reviews_with_embeddings_1k\n",
" Rows : 1,000\n",
" Columns: 9\n",
"# Column Non-Null Count DataGrid Type \n",
"--- -------------------- --------------- --------------------\n",
"1 Column 1 1,000 INTEGER \n",
"2 ProductId 1,000 TEXT \n",
"3 UserId 1,000 TEXT \n",
"4 Score 1,000 INTEGER \n",
"5 Summary 1,000 TEXT \n",
"6 Text 1,000 TEXT \n",
"7 combined 1,000 TEXT \n",
"8 n_tokens 1,000 INTEGER \n",
"9 embedding 1,000 TEXT \n"
]
}
],
"source": [
"data.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And get a glimpse of the first and last rows:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 349
},
"id": "Q95N832aeaBr",
"outputId": "aaea2816-e5a1-4e52-f228-c3e6aca6fa3e"
},
"outputs": [
{
"data": {
"text/html": [
"<table><th colspan='1' > row-id </th> <th colspan='1' > Column 1 </th> <th colspan='1' > ProductId </th> <th colspan='1' > UserId </th> <th colspan='1' > Score </th> <th colspan='1' > Summary </th> <th colspan='1' > Text </th> <th colspan='1' > combined </th> <th colspan='1' > n_tokens </th> <th colspan='1' > embedding </th> <tr>\n",
"<td colspan='1' > 1 </td> <td colspan='1' > 0 </td> <td colspan='1' > B003XPF9BO </td> <td colspan='1' > A3R7JR3FMEBXQB </td> <td colspan='1' > 5 </td> <td colspan='1' > where does one </td> <td colspan='1' > Wanted to save </td> <td colspan='1' > Title: where do </td> <td colspan='1' > 52 </td> <td colspan='1' > [0.007018072064 </td> <tr>\n",
"<td colspan='1' > 2 </td> <td colspan='1' > 297 </td> <td colspan='1' > B003VXHGPK </td> <td colspan='1' > A21VWSCGW7UUAR </td> <td colspan='1' > 4 </td> <td colspan='1' > Good, but not W </td> <td colspan='1' > Honestly, I hav </td> <td colspan='1' > Title: Good, bu </td> <td colspan='1' > 178 </td> <td colspan='1' > [-0.00314055196 </td> <tr>\n",
"<td colspan='1' > 3 </td> <td colspan='1' > 296 </td> <td colspan='1' > B008JKTTUA </td> <td colspan='1' > A34XBAIFT02B60 </td> <td colspan='1' > 1 </td> <td colspan='1' > Should advertis </td> <td colspan='1' > First, these sh </td> <td colspan='1' > Title: Should a </td> <td colspan='1' > 78 </td> <td colspan='1' > [-0.01757248118 </td> <tr>\n",
"<td colspan='1' > 4 </td> <td colspan='1' > 295 </td> <td colspan='1' > B000LKTTTW </td> <td colspan='1' > A14MQ40CCU8B13 </td> <td colspan='1' > 5 </td> <td colspan='1' > Best tomato sou </td> <td colspan='1' > I have a hard t </td> <td colspan='1' > Title: Best tom </td> <td colspan='1' > 111 </td> <td colspan='1' > [-0.00139322795 </td> <tr>\n",
"<td colspan='1' > 5 </td> <td colspan='1' > 294 </td> <td colspan='1' > B001D09KAM </td> <td colspan='1' > A34XBAIFT02B60 </td> <td colspan='1' > 1 </td> <td colspan='1' > Should advertis </td> <td colspan='1' > First, these sh </td> <td colspan='1' > Title: Should a </td> <td colspan='1' > 78 </td> <td colspan='1' > [-0.01757248118 </td> <tr>\n",
"<tr><td colspan='10' style='text-align: left;'>...</td></tr><td colspan='1' > 996 </td> <td colspan='1' > 623 </td> <td colspan='1' > B0000CFXYA </td> <td colspan='1' > A3GS4GWPIBV0NT </td> <td colspan='1' > 1 </td> <td colspan='1' > Strange inflamm </td> <td colspan='1' > Truthfully wasn </td> <td colspan='1' > Title: Strange </td> <td colspan='1' > 110 </td> <td colspan='1' > [0.000110913533 </td> <tr>\n",
"<td colspan='1' > 997 </td> <td colspan='1' > 624 </td> <td colspan='1' > B0001BH5YM </td> <td colspan='1' > A1BZ3HMAKK0NC </td> <td colspan='1' > 5 </td> <td colspan='1' > My favorite and </td> <td colspan='1' > You've just got </td> <td colspan='1' > Title: My favor </td> <td colspan='1' > 80 </td> <td colspan='1' > [-0.02086931467 </td> <tr>\n",
"<td colspan='1' > 998 </td> <td colspan='1' > 625 </td> <td colspan='1' > B0009ET7TC </td> <td colspan='1' > A2FSDQY5AI6TNX </td> <td colspan='1' > 5 </td> <td colspan='1' > My furbabies LO </td> <td colspan='1' > Shake the conta </td> <td colspan='1' > Title: My furba </td> <td colspan='1' > 47 </td> <td colspan='1' > [-0.00974910240 </td> <tr>\n",
"<td colspan='1' > 999 </td> <td colspan='1' > 619 </td> <td colspan='1' > B007PA32L2 </td> <td colspan='1' > A15FF2P7RPKH6G </td> <td colspan='1' > 5 </td> <td colspan='1' > got this for th </td> <td colspan='1' > all i have hear </td> <td colspan='1' > Title: got this </td> <td colspan='1' > 50 </td> <td colspan='1' > [-0.00521062919 </td> <tr>\n",
"<td colspan='1' > 1000 </td> <td colspan='1' > 999 </td> <td colspan='1' > B001EQ5GEO </td> <td colspan='1' > A3VYU0VO6DYV6I </td> <td colspan='1' > 5 </td> <td colspan='1' > I love Maui Cof </td> <td colspan='1' > My first experi </td> <td colspan='1' > Title: I love M </td> <td colspan='1' > 118 </td> <td colspan='1' > [-0.00605782261 </td> <tr>\n",
"<tr>\n",
"<td colspan='10' style=\"text-align: left;\"> [1000 rows x 9 columns] </td> <tr>\n",
"<tr><td colspan='10' style='text-align: left;'></td></tr><tr><td colspan='10' style='text-align: left;'>* Use DataGrid.save() to save to disk</td></tr><tr><td colspan='10' style='text-align: left;'>** Use DataGrid.show() to start user interface</td></tr></table>"
],
"text/plain": [
"<th colspan='1' > row-id </th> <th colspan='1' > Column 1 </th> <th colspan='1' > ProductId </th> <th colspan='1' > UserId </th> <th colspan='1' > Score </th> <th colspan='1' > Summary </th> <th colspan='1' > Text </th> <th colspan='1' > combined </th> <th colspan='1' > n_tokens </th> <th colspan='1' > embedding </th> <tr>\n",
"<td colspan='1' > 1 </td> <td colspan='1' > 0 </td> <td colspan='1' > B003XPF9BO </td> <td colspan='1' > A3R7JR3FMEBXQB </td> <td colspan='1' > 5 </td> <td colspan='1' > where does one </td> <td colspan='1' > Wanted to save </td> <td colspan='1' > Title: where do </td> <td colspan='1' > 52 </td> <td colspan='1' > [0.007018072064 </td> <tr>\n",
"<td colspan='1' > 2 </td> <td colspan='1' > 297 </td> <td colspan='1' > B003VXHGPK </td> <td colspan='1' > A21VWSCGW7UUAR </td> <td colspan='1' > 4 </td> <td colspan='1' > Good, but not W </td> <td colspan='1' > Honestly, I hav </td> <td colspan='1' > Title: Good, bu </td> <td colspan='1' > 178 </td> <td colspan='1' > [-0.00314055196 </td> <tr>\n",
"<td colspan='1' > 3 </td> <td colspan='1' > 296 </td> <td colspan='1' > B008JKTTUA </td> <td colspan='1' > A34XBAIFT02B60 </td> <td colspan='1' > 1 </td> <td colspan='1' > Should advertis </td> <td colspan='1' > First, these sh </td> <td colspan='1' > Title: Should a </td> <td colspan='1' > 78 </td> <td colspan='1' > [-0.01757248118 </td> <tr>\n",
"<td colspan='1' > 4 </td> <td colspan='1' > 295 </td> <td colspan='1' > B000LKTTTW </td> <td colspan='1' > A14MQ40CCU8B13 </td> <td colspan='1' > 5 </td> <td colspan='1' > Best tomato sou </td> <td colspan='1' > I have a hard t </td> <td colspan='1' > Title: Best tom </td> <td colspan='1' > 111 </td> <td colspan='1' > [-0.00139322795 </td> <tr>\n",
"<td colspan='1' > 5 </td> <td colspan='1' > 294 </td> <td colspan='1' > B001D09KAM </td> <td colspan='1' > A34XBAIFT02B60 </td> <td colspan='1' > 1 </td> <td colspan='1' > Should advertis </td> <td colspan='1' > First, these sh </td> <td colspan='1' > Title: Should a </td> <td colspan='1' > 78 </td> <td colspan='1' > [-0.01757248118 </td> <tr>\n",
"...\n",
"<td colspan='1' > 996 </td> <td colspan='1' > 623 </td> <td colspan='1' > B0000CFXYA </td> <td colspan='1' > A3GS4GWPIBV0NT </td> <td colspan='1' > 1 </td> <td colspan='1' > Strange inflamm </td> <td colspan='1' > Truthfully wasn </td> <td colspan='1' > Title: Strange </td> <td colspan='1' > 110 </td> <td colspan='1' > [0.000110913533 </td> <tr>\n",
"<td colspan='1' > 997 </td> <td colspan='1' > 624 </td> <td colspan='1' > B0001BH5YM </td> <td colspan='1' > A1BZ3HMAKK0NC </td> <td colspan='1' > 5 </td> <td colspan='1' > My favorite and </td> <td colspan='1' > You've just got </td> <td colspan='1' > Title: My favor </td> <td colspan='1' > 80 </td> <td colspan='1' > [-0.02086931467 </td> <tr>\n",
"<td colspan='1' > 998 </td> <td colspan='1' > 625 </td> <td colspan='1' > B0009ET7TC </td> <td colspan='1' > A2FSDQY5AI6TNX </td> <td colspan='1' > 5 </td> <td colspan='1' > My furbabies LO </td> <td colspan='1' > Shake the conta </td> <td colspan='1' > Title: My furba </td> <td colspan='1' > 47 </td> <td colspan='1' > [-0.00974910240 </td> <tr>\n",
"<td colspan='1' > 999 </td> <td colspan='1' > 619 </td> <td colspan='1' > B007PA32L2 </td> <td colspan='1' > A15FF2P7RPKH6G </td> <td colspan='1' > 5 </td> <td colspan='1' > got this for th </td> <td colspan='1' > all i have hear </td> <td colspan='1' > Title: got this </td> <td colspan='1' > 50 </td> <td colspan='1' > [-0.00521062919 </td> <tr>\n",
"<td colspan='1' > 1000 </td> <td colspan='1' > 999 </td> <td colspan='1' > B001EQ5GEO </td> <td colspan='1' > A3VYU0VO6DYV6I </td> <td colspan='1' > 5 </td> <td colspan='1' > I love Maui Cof </td> <td colspan='1' > My first experi </td> <td colspan='1' > Title: I love M </td> <td colspan='1' > 118 </td> <td colspan='1' > [-0.00605782261 </td> <tr>\n",
"<tr>\n",
"<td colspan='10' style=\"text-align: left;\"> [1000 rows x 9 columns] </td> <tr>\n",
"\n",
"* Use DataGrid.save() to save to disk\n",
"** Use DataGrid.show() to start user interface"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we create a new DataGrid, converting each row's embedding (stored in the CSV as a string of numbers) into a Kangas Embedding:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "Bu0erP68dvLU"
},
"outputs": [],
"source": [
"import ast  # to convert a string holding a list of numbers into a list of numbers\n",
"\n",
"# New DataGrid with the same columns; Score is converted to text\n",
"dg = kg.DataGrid(\n",
"    name=\"openai_embeddings\",\n",
"    columns=data.get_columns(),\n",
"    converters={\"Score\": str},\n",
")\n",
"for row in data:\n",
"    # row[8] is the embedding column: parse the string into a list of floats\n",
"    embedding = ast.literal_eval(row[8])\n",
"    row[8] = kg.Embedding(\n",
"        embedding,\n",
"        name=str(row[3]),  # row[3] is the Score\n",
"        text=\"%s - %.10s\" % (row[3], row[4]),  # Score and first 10 characters of Summary\n",
"        projection=\"umap\",\n",
"    )\n",
"    dg.append(row)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The new DataGrid now has an embedding column with the proper datatype, EMBEDDING-ASSET:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gd6Od4Bmhijy",
"outputId": "9aa38221-0272-4a63-e393-706e0a0c5879"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataGrid (in memory)\n",
" Name : openai_embeddings\n",
" Rows : 1,000\n",
" Columns: 9\n",
"# Column Non-Null Count DataGrid Type \n",
"--- -------------------- --------------- --------------------\n",
"1 Column 1 1,000 INTEGER \n",
"2 ProductId 1,000 TEXT \n",
"3 UserId 1,000 TEXT \n",
"4 Score 1,000 TEXT \n",
"5 Summary 1,000 TEXT \n",
"6 Text 1,000 TEXT \n",
"7 combined 1,000 TEXT \n",
"8 n_tokens 1,000 INTEGER \n",
"9 embedding 1,000 EMBEDDING-ASSET \n"
]
}
],
"source": [
"dg.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We simply save the DataGrid, and we're done."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dg.save()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Render 2D Projections\n",
"\n",
"To render the data directly in the notebook, simply show it. Note that each row contains an embedding projection.\n",
"\n",
"Scroll to the far right to see the embedding projection for each row.\n",
"\n",
"The color of the point in projection space represents the Score."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 771
},
"id": "Z8j-GdpiijU0",
"outputId": "20a0b1ca-3059-4384-cd8c-b32b1aa1c270"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"750px\"\n",
" src=\"http://127.0.1.1:4000/?datagrid=openai_embeddings.datagrid&timestamp=1685559502.7515423\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7fcdd16bfbe0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dg.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Group by \"Score\" to see the rows of each group, showing only the Score and embedding columns:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"750px\"\n",
" src=\"http://127.0.1.1:4000/?datagrid=openai_embeddings.datagrid&timestamp=1685559502.7515423&group=Score&sort=Score&rows=5&select=Score%2Cembedding\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7fcd06209180>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dg.show(group=\"Score\", sort=\"Score\", rows=5, select=\"Score,embedding\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vLIxfmK5dJsq"
},
"source": [
"An example of this datagrid is hosted here: https://kangas.comet.com/?datagrid=/data/openai_embeddings.datagrid"
]
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"gpuType": "V100",
"machine_shape": "hm",
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}