{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing the embeddings in 2D\n",
"\n",
"We will use t-SNE to reduce the dimensionality of the embeddings from 2048 to 2. Once the embeddings are reduced to two dimensions, we can plot them in a 2D scatter plot. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Reduce dimensionality\n",
"\n",
"We reduce the embeddings to 2 dimensions using t-SNE (t-distributed Stochastic Neighbor Embedding)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 2)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from ast import literal_eval\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.manifold import TSNE\n",
"\n",
"# Load the embeddings\n",
"datafile_path = \"https://cdn.openai.com/API/examples/data/fine_food_reviews_with_embeddings_1k.csv\" # for your convenience, we precomputed the embeddings\n",
"df = pd.read_csv(datafile_path)\n",
"\n",
"# Each embedding is stored as the string form of a list; parse it safely with literal_eval\n",
"# (rather than eval) and stack the results into a (n_reviews, 2048) float array\n",
"matrix = np.array(df.babbage_similarity.apply(literal_eval).to_list())\n",
"\n",
"# Create a t-SNE model and project the embeddings to 2D\n",
"tsne = TSNE(n_components=2, perplexity=15, random_state=42, init='random', learning_rate=200)\n",
"vis_dims = tsne.fit_transform(matrix)\n",
"vis_dims.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Plotting the embeddings\n",
"\n",
"We colour each review by its star rating, from red (1 star) to dark green (5 stars), and mark the average position of each rating with a cross."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even after reduction to two dimensions, the reviews are reasonably well separated by rating, and there appears to be a cluster of mostly negative reviews."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Amazon ratings visualized in language using t-SNE')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib.colors import ListedColormap\n",
"\n",
"colors = [\"red\", \"darkorange\", \"gold\", \"turquoise\", \"darkgreen\"]\n",
"x = vis_dims[:, 0]\n",
"y = vis_dims[:, 1]\n",
"color_indices = df.Score.values - 1  # map 1-5 star ratings to colour indices 0-4\n",
"\n",
"colormap = ListedColormap(colors)\n",
"plt.scatter(x, y, c=color_indices, cmap=colormap, alpha=0.3)\n",
"\n",
"# Mark the average 2D position of each rating with a large cross\n",
"for score in range(5):\n",
"    mask = color_indices == score\n",
"    plt.scatter(x[mask].mean(), y[mask].mean(), marker='x', color=colors[score], s=100)\n",
"plt.title(\"Amazon ratings visualized in language using t-SNE\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.9 ('openai')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}