mirror of
https://github.com/openai/openai-cookbook
synced 2024-11-04 06:00:33 +00:00
269 lines
282 KiB
Plaintext
269 lines
282 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "983ef639-fbf4-4912-b593-9cf08aeb11cd",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Visualizing the embeddings in 3D"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "9c9ea9a8-675d-4e3a-a8f7-6f4563df84ad",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"The example uses [PCA](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html) to reduce the dimensionality of the embeddings from 2048 to 3, so that the data points can be visualized in a 3D plot. The small dataset `dbpedia_samples.jsonl` was curated by randomly sampling 200 samples from the [DBpedia validation dataset](https://www.kaggle.com/danofer/dbpedia-classes?select=DBPEDIA_val.csv)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "8df5f2c3-ddbb-4cc4-9205-4c0af1670562",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### 1. Load the dataset and query embeddings"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "133dfc2a-9dbd-4a5a-96fa-477272f7af5a",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Categories of DBpedia samples: Artist 21\n",
|
||
|
"Film 19\n",
|
||
|
"Plant 19\n",
|
||
|
"OfficeHolder 18\n",
|
||
|
"Company 17\n",
|
||
|
"NaturalPlace 16\n",
|
||
|
"Athlete 16\n",
|
||
|
"Village 12\n",
|
||
|
"WrittenWork 11\n",
|
||
|
"Building 11\n",
|
||
|
"Album 11\n",
|
||
|
"Animal 11\n",
|
||
|
"EducationalInstitution 10\n",
|
||
|
"MeanOfTransportation 8\n",
|
||
|
"Name: category, dtype: int64\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<div>\n",
|
||
|
"<style scoped>\n",
|
||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||
|
" vertical-align: middle;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe tbody tr th {\n",
|
||
|
" vertical-align: top;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe thead th {\n",
|
||
|
" text-align: right;\n",
|
||
|
" }\n",
|
||
|
"</style>\n",
|
||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||
|
" <thead>\n",
|
||
|
" <tr style=\"text-align: right;\">\n",
|
||
|
" <th></th>\n",
|
||
|
" <th>text</th>\n",
|
||
|
" <th>category</th>\n",
|
||
|
" </tr>\n",
|
||
|
" </thead>\n",
|
||
|
" <tbody>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>0</th>\n",
|
||
|
" <td>Morada Limited is a textile company based in ...</td>\n",
|
||
|
" <td>Company</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>1</th>\n",
|
||
|
" <td>The Armenian Mirror-Spectator is a newspaper ...</td>\n",
|
||
|
" <td>WrittenWork</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>2</th>\n",
|
||
|
" <td>Mt. Kinka (金華山 Kinka-zan) also known as Kinka...</td>\n",
|
||
|
" <td>NaturalPlace</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>3</th>\n",
|
||
|
" <td>Planning the Play of a Bridge Hand is a book ...</td>\n",
|
||
|
" <td>WrittenWork</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>4</th>\n",
|
||
|
" <td>Wang Yuanping (born 8 December 1976) is a ret...</td>\n",
|
||
|
" <td>Athlete</td>\n",
|
||
|
" </tr>\n",
|
||
|
" </tbody>\n",
|
||
|
"</table>\n",
|
||
|
"</div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
" text category\n",
|
||
|
"0 Morada Limited is a textile company based in ... Company\n",
|
||
|
"1 The Armenian Mirror-Spectator is a newspaper ... WrittenWork\n",
|
||
|
"2 Mt. Kinka (金華山 Kinka-zan) also known as Kinka... NaturalPlace\n",
|
||
|
"3 Planning the Play of a Bridge Hand is a book ... WrittenWork\n",
|
||
|
"4 Wang Yuanping (born 8 December 1976) is a ret... Athlete"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import pandas as pd\n",
|
||
|
"samples = pd.read_json(\"dbpedia_samples.jsonl\", lines=True)\n",
|
||
|
"categories = sorted(samples[\"category\"].unique())\n",
|
||
|
"print(\"Categories of DBpedia samples:\", samples[\"category\"].value_counts())\n",
|
||
|
"samples.head()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "19874e3e-a216-48cc-a27b-acb73854d832",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from openai.embeddings_utils import get_embeddings\n",
|
||
|
"# NOTE: The following code will send a query of batch size 200 to /embeddings, cost about $0.2\n",
|
||
|
"matrix = get_embeddings(samples[\"text\"].to_list(), engine=\"text-similarity-babbage-001\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "d410c268-d8a7-4979-887c-45b1d382dda9",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### 2. Reduce the embedding dimensionality"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "f5410068-f3da-490c-8576-48e84a8728de",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.decomposition import PCA\n",
|
||
|
"pca = PCA(n_components=3)\n",
|
||
|
"vis_dims = pca.fit_transform(matrix)\n",
|
||
|
"samples[\"embed_vis\"] = vis_dims.tolist()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "b6565f57-59c6-4d36-a094-3cbbd9ddeb4c",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### 3. Plot the lower-dimensional embeddings"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "b17caad3-f0de-4115-83eb-55434a132acc",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<matplotlib.legend.Legend at 0x14b5df760>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"application/vnd.jupyter.widget-view+json": {
|
||
|
"model_id": "864488447fdd46b4ae1f338d3b0afded",
|
||
|
"version_major": 2,
|
||
|
"version_minor": 0
|
||
|
},
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA+gAAAH0CAYAAACuKActAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXycdbn//9c9W2bJZGayJ23adEn3Jd3bFKEKQhU5llOxCqigcs4BwYPoUTkoBz3iguAPEcXv0WrFBVCPIAcUxMre0hZo0jRJ0zRNmjT7MllmJrPfvz/C3CRpkmaZSdL2ej4ePJRk5r7vyca87+vzuS5FVVUVIYQQQgghhBBCTCvddF+AEEIIIYQQQgghJKALIYQQQgghhBAzggR0IYQQQgghhBBiBpCALoQQQgghhBBCzAAS0IUQQgghhBBCiBlAAroQQgghhBBCCDEDSEAXQgghhBBCCCFmAAnoQgghhBBCCCHEDCABXQghhBBCCCGEmAEkoAshhBBCCCGEEDOABHQhhBBCCCGEEGIGkIAuhBBCCCGEEELMABLQhRBCCCGEEEKIGUACuhBCCCGEEEIIMQNIQBdCCCGEEEIIIWYACehCCCGEEEIIIcQMIAFdCCGEEEIIIYSYASSgCyGEEEIIIYQQM4AEdCGEEEIIIYQQYgaQgC6EEEIIIYQQQswAEtCFEEIIIYQQQogZQAK6EEIIIYQQQggxA0hAF0IIIYQQQgghZgAJ6EIIIYQQQgghxAwgAV0IIYQQQgghhJgBJKALIYQQQgghhBAzgAR0IYQQQgghhBBiBpCALoQQQgghhBBCzAAS0IUQQgghhBBCiBlAAroQQgghhBBCCDEDSEAXQgghhBBCCCFmAAnoQgghhBBCCCHEDCABXQghhBBCCCGEmAEkoAshhBBCCCGEEDOABHQhhBBCCCGEEGIGkIAuhBBCCCGEEELMABLQhRBCCCGEEEKIGUACuhBCCCGEEEIIMQNIQBdCCCGEEEIIIWYACehCCCGEEEIIIcQMIAFdCCGEEEIIIYSYASSgCyGEEEIIIYQQM4AEdCGEEEIIIYQQYgaQgC6EEEIIIYQQQswAEtCFEEIIIYQQQogZQAK6EEIIIYQQQggxA0hAF0IIIYQQQgghZgAJ6EIIIYQQQgghxAwgAV0IIYQQQgghhJgBJKALIYQQQgghhBAzgAR0IYQQQgghhBBiBpCALoQQQgghhBBCzAAS0IUQQgghhBBCiBlAAroQQgghhBBCCDEDSEAXQgghhBBCCCFmAAnoQgghhBBCCCHEDCABXQghhBBCCCGEmAEkoAshhBBCCCGEEDOABHQhhBBCCCGEEGIGkIAuhBBCCCGEEELMABLQhRBCCCGEEEKIGUACuhBCCCGEEEIIMQNIQBdCCCGEEEIIIWYACehCCCGEEEIIIcQMIAFdCCGEEEIIIYSYAQzTfQFCCCGEEEKIsYtEIoRCoem+DCFmLKPRiF6vn+7LmBAJ6EIIIYQQQpwDVFWlubmZrq6u6b4UIWY8p9NJdnY2iqJM96WMiwR0IYQQQgghzgGxcJ6ZmYnVaj3ngocQU0FVVXw+H62trQDk5ORM8xWNjwR0IYQQQgghZrhIJKKF87S0tOm+HCFmNIvFAkBrayuZmZnn1HJ3aRInhBBCCCHEDBfbc261Wqf5SoQ4N8R+V861fg0S0IUQQgghhDhHyLJ2IcbmXP1dkYAuhBBCCCGEEELMABLQhRBCCCGEENPmpZdeQlEUrTv9nj17cDqd03pNQkwXCehCCCGEEEKIhNu/fz96vZ4rr7xyui9FiBlLAroQQgghhBAi4Xbv3s1tt93GK6+8QmNj43RfjhAzkgR0IYQQQgghLhDVbR4+/9hhVt3zPOu/9QLfeqacbl/iu1x7PB6eeOIJbr75Zq688kr27Nlz1uc89dRTFBQUYDabueKKK6ivr9c+d8MNN7Bjx45Bj7/99tvZtm2b9u/btm3jtttu4/
bbb8flcpGVlcXPfvYzvF4vN954I3a7nYULF/LXv/41Tq9SiMmTgC6EEEIIIcQFoKbdy4cffp1nS5vo8Ydp9wT55es1XPP/9tEXjCT03L///e9ZsmQJixcv5vrrr+cXv/gFqqqO+Hifz8e9997Lo48+yuuvv05XVxcf+9jHxn3eX/3qV6Snp3Pw4EFuu+02br75Zq655hqKiop4++23ufzyy/nEJz6Bz+ebzMsTIm4koAshhBBCCHEBePgfVfSFIkSi7wbjiArHWzw8ebghoefevXs3119/PQDbt2+nu7ubl19+ecTHh0IhHn74YbZs2cK6dev41a9+xb59+zh48OC4zrt69Wq+9rWvUVBQwJ133onZbCY9PZ2bbrqJgoIC7r77bjo6Ojhy5MikXp8Q8SIBXQghhBBCiAvAy8fbBoXzGJ0Cr59oT9h5KysrOXjwIB//+McBMBgM7Nq1i927d4/4HIPBwIYNG7R/X7JkCU6nk4qKinGde9WqVdr/1+v1pKWlsXLlSu1jWVlZALS2to7ruEIkimG6L0AIIYQQQgiReFaTAQie8XFFUbCY9Ak77+7duwmHw+Tm5mofU1WVpKQkHn744QkdU6fTnbFEPhQ6cy+90Wgc9O+Kogz6mKIoAESj0QldhxDxJhV0IYQQQgghLgBXr5mFTjnz45GoylWrc8/8RByEw2EeffRRHnjgAYqLi7V/SkpKyM3N5bHHHhvxeW+++ab275WVlXR1dbF06VIAMjIyaGpqGvSc4uLihLwGIaaSBHQhhBBCCCEuAP96yXxWz3YCoNcp6N9J69dtmsPFBekJOeczzzyD2+3mM5/5DCtWrBj0z86dO0dc5m40Grnttts4cOAAb731FjfccAObN29m48aNALzvfe/jzTff5NFHH6Wqqor/+q//4ujRowl5DUJMJQnoQgghhBBCXACsJgNP/OsWfvixQnYUzmLXhjx+99lNfGvHCm2pd7zt3r2byy67DIfDccbndu7cyZtvvjlsgzar1cpXvvIVrr32WrZu3UpycjJPPPGE9vkrrriCr3/963z5y19mw4YN9Pb28slPfjIhr0GIqaSoo803EEIIIYQQQkw7v99PTU0N8+bNw2w2T/flCDHjnau/M1JBF0IIIYQQQgghZgAJ6EIIIYQQQgghxAwgAV0IIYQQQgghhJgBJKALIYQQQgghhBAzgGG6L0AIIYQ4H6mqSjQaJRAIAP0jg/R6PYqiJKxbshBCCCHObRLQhRBCiDhTVZVwOEw4HCYQCKCqKoFAAEVR0Ov1WljX6/XodLKYTQghhBD9JKALIYQQcRSNRgmFQkSjUQD0er32uVhwD4VCWiVdArsQQgghYiSgCyGEEHEQW9IeC+dDg3YskMc+rqrqoMAOoNPpMBgMGAwGCexCCCHEBUj+qy+EEEJMkqqqhEIhgsEgqqqi0+nOus88FtYNBgNGoxGDwYCiKHR3d/Pyyy/j8Xjo6enB4/Hg9/sHVeWFEOJCc88991BYWHjenEeIkUhAF0IIISYhGo0SDAYJh8Na6J5IE7jYc/V6PaFQSFsaHwqF6OvrGzawq6oa75cjhBAJs3//fvR6PVdeeeW4n/ulL32JvXv3JuCqhJhZJKALIYQQExBbnh4IBIhEIiMG8/GGdUVRUFVV258+cMk79Ad2n8+Hx+Ohu7tbC+zhcFgCuxBiRtu9eze33XYbr7zyCo2NjeN6bnJyMmlpaQm6MiFmDgnoQgghxDjFlrQP3Ds+XBBvbm5m//79lJSUUFdXR09Pz5hC9HCPGRjYY03l4N3A3tvbq1XYA4GABHYhxIzi8Xh44oknuPnmm7nyyivZs2eP9rmXXnoJRVHYu3cv69evx2q1UlRURGVlpfaYoUvPb7jhBnbs2MG3v/1tsrKycDqdfPOb3yQcDvMf//EfpKamMnv2bH75y18Ouo6vfOUrLFq0CKvVyvz58/n617+u/S0XYiaQJnFCCCHEOEQikUGN4IYL5pFIhGPHjt
HU1MTChQsJh8N0dXVRU1ODoig4nU5cLhculwubzTboGGOtuMcCe8zAJnXBYHDQHveBTedkBrsQF7Yub4jyeg+t3UF
|
||
|
"text/html": [
|
||
|
"\n",
|
||
|
" <div style=\"display: inline-block;\">\n",
|
||
|
" <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
|
||
|
" Figure\n",
|
||
|
" </div>\n",
|
||
|
" <img src='
|
||
|
" </div>\n",
|
||
|
" "
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%matplotlib widget\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import numpy as np\n",
|
||
|
"\n",
|
||
|
"fig = plt.figure(figsize=(10, 5))\n",
|
||
|
"ax = fig.add_subplot(projection='3d')\n",
|
||
|
"cmap = plt.get_cmap(\"tab20\")\n",
|
||
|
"\n",
|
||
|
"# Plot each sample category individually such that we can set label name.\n",
|
||
|
"for i, cat in enumerate(categories):\n",
|
||
|
" sub_matrix = np.array(samples[samples[\"category\"] == cat][\"embed_vis\"].to_list())\n",
|
||
|
" x=sub_matrix[:, 0]\n",
|
||
|
" y=sub_matrix[:, 1]\n",
|
||
|
" z=sub_matrix[:, 2]\n",
|
||
|
" colors = [cmap(i/len(categories))] * len(sub_matrix)\n",
|
||
|
" ax.scatter(x, y, zs=z, zdir='z', c=colors, label=cat)\n",
|
||
|
"\n",
|
||
|
"ax.set_xlabel('x')\n",
|
||
|
"ax.set_ylabel('y')\n",
|
||
|
"ax.set_zlabel('z')\n",
|
||
|
"ax.legend(bbox_to_anchor=(1.1, 1))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "a8868043-9889-4a0b-b23d-79bb3823bdc7",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.9"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|