{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing the embeddings in 2D\n",
"\n",
"We will use t-SNE to reduce the dimensionality of the embeddings from 2048 to 2. Once the embeddings are reduced to two dimensions, we can plot them in a 2D scatter plot. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Reduce dimensionality\n",
"\n",
"We reduce the embeddings to 2 dimensions using t-SNE (t-distributed stochastic neighbor embedding)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 2)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"\n",
"# Load the embeddings\n",
"# If you have not run the \"Obtain_dataset.ipynb\" notebook, you can download the datafile from here: https://cdn.openai.com/API/examples/data/fine_food_reviews_with_embeddings_1k.csv\n",
"datafile_path = \"./data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"df = pd.read_csv(datafile_path)\n",
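"\n",
"# Parse each stringified embedding into a list of floats and stack them into a 2D array.\n",
"# ast.literal_eval only accepts Python literals, so it is safer than eval on file contents.\n",
"from ast import literal_eval\n",
"matrix = np.array(df.ada_similarity.apply(literal_eval).to_list())\n",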
"\n",
"# Create a t-SNE model and transform the data\n",
"tsne = TSNE(n_components=2, perplexity=15, random_state=42, init='random', learning_rate=200)\n",
"vis_dims = tsne.fit_transform(matrix)\n",
"vis_dims.shape"
]
},
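{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on the step above, the following sketch prints the input dimensionality actually found in the data and confirms the output shape. It is an illustrative addition, not part of the original notebook, and assumes the previous cell has run so that `matrix`, `tsne` and `vis_dims` exist. The perplexity check reflects scikit-learn's requirement that perplexity be smaller than the number of samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity checks (assumes matrix, tsne and vis_dims from the cell above)\n",
"n_samples, n_features = matrix.shape\n",
"print(f\"{n_samples} embeddings with {n_features} dimensions each\")\n",
"\n",
"# scikit-learn's TSNE expects perplexity to be smaller than the number of samples\n",
"assert tsne.perplexity < n_samples\n",
"\n",
"# One 2D point per input embedding\n",
"assert vis_dims.shape == (n_samples, 2)"
]
},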
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Plotting the embeddings\n",
"\n",
"We colour each review by its star rating, ranging from red (1 star) to green (5 stars)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even in two dimensions, the ratings separate reasonably well. There appears to be a distinct cluster of mostly negative reviews."
]
},
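{
"cell_type": "markdown",
"metadata": {},
"source": [
"One rough way to put a number on this separation is to compare the mean 2D position of each rating. The sketch below is an illustrative addition and assumes `vis_dims`, `df`, `pd` and `np` from the first code cell; the `coords`, `centroids` and `gap` names are hypothetical. Because t-SNE does not preserve global distances, the result should be read as a loose indicator rather than a precise measure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the separation (assumes vis_dims and df from the first code cell)\n",
"coords = pd.DataFrame(vis_dims, columns=[\"x\", \"y\"])\n",
"coords[\"Score\"] = df.Score.values\n",
"\n",
"# Mean 2D position of the reviews for each star rating\n",
"centroids = coords.groupby(\"Score\")[[\"x\", \"y\"]].mean()\n",
"print(centroids)\n",
"\n",
"# Euclidean distance between the 1-star and 5-star centroids in the projection\n",
"gap = np.linalg.norm(centroids.loc[1] - centroids.loc[5])\n",
"print(f\"Distance between 1-star and 5-star centroids: {gap:.2f}\")"
]
},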
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Amazon ratings visualized in language using t-SNE')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"\n",
"# One colour per star rating, from red (1 star) to dark green (5 stars)\n",
"colors = [\"red\", \"darkorange\", \"gold\", \"turquoise\", \"darkgreen\"]\n",
"x = vis_dims[:, 0]\n",
"y = vis_dims[:, 1]\n",
"# Scores run from 1 to 5; shift them to 0-4 so they index into the colormap\n",
"color_indices = df.Score.values - 1\n",
"\n",
"colormap = matplotlib.colors.ListedColormap(colors)\n",
"plt.scatter(x, y, c=color_indices, cmap=colormap, alpha=0.3)\n",
"# Mark the mean position of each rating with an 'x' in the matching colour\n",
"for score in range(5):\n",
"    mask = color_indices == score\n",
"    plt.scatter(x[mask].mean(), y[mask].mean(), marker='x', color=colors[score], s=100)\n",
"\n",
"plt.title(\"Amazon ratings visualized in language using t-SNE\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai-cookbook",
"language": "python",
"name": "openai-cookbook"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}