{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing the embeddings in 2D\n",
"\n",
"We will use t-SNE to reduce the dimensionality of the embeddings from 1536 to 2. Once the embeddings are reduced to two dimensions, we can plot them in a 2D scatter plot. The dataset is created in the [Get_embeddings_from_dataset Notebook](Get_embeddings_from_dataset.ipynb)."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Reduce dimensionality\n",
"\n",
"We reduce the dimensionality to 2 dimensions using t-SNE decomposition."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 2)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"from sklearn.manifold import TSNE\n",
"import numpy as np\n",
"from ast import literal_eval\n",
"\n",
"# Load the embeddings\n",
"datafile_path = \"data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"df = pd.read_csv(datafile_path)\n",
"\n",
"# Convert to a list of lists of floats\n",
"matrix = np.array(df.embedding.apply(literal_eval).to_list())\n",
"\n",
"# Create a t-SNE model and transform the data\n",
"tsne = TSNE(n_components=2, perplexity=15, random_state=42, init='random', learning_rate=200)\n",
"vis_dims = tsne.fit_transform(matrix)\n",
"vis_dims.shape"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Plotting the embeddings\n",
"\n",
"We colour each review by its star rating, ranging from red to green."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We can observe a decent data separation even in the reduced 2 dimensions."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Amazon ratings visualized in language using t-SNE')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGzCAYAAAABsTylAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9d5hd1XWw/552e5+50/uod1QQsgDRMcUYDAb82Q4QO3ZCYkPiktifPxvsJKTYjhN3xw5u+IdjjBs2VSBRBep9Rhpper+9l1N+f1zpotGMpBFISILz6rnPozlln7X3aeusvYpgGIaBiYmJiYmJiclZiHimBTAxMTExMTExORamomJiYmJiYmJy1mIqKiYmJiYmJiZnLaaiYmJiYmJiYnLWYioqJiYmJiYmJmctpqJiYmJiYmJictZiKiomJiYmJiYmZy2momJiYmJiYmJy1mIqKiYmJiYmJiZnLaaiYvK25b777kMQhDMtxptGEATuu+++My3GJDl+/OMfIwgCPT09b6kcLS0t3HnnnSfc7lSP25133klLS8spa8/k1HC23B8mpw9TUXmL+M53voMgCKxcufJMi/K2IpPJcN9997Fu3bozLYqJick7iKGhIe677z62bds27X10XeenP/0pK1euJBAI4Ha7mTVrFn/2Z3/Ghg0bytutW7cOQRAQBIHNmzdPaufOO+/E5XJNWHbJJZeU9zn6N2fOnDfcz7MB+UwL8E7hoYceoqWlhddee42uri5mzJhxpkV6W5DJZLj//vuB0o16JF/4whf4h3/4hzMg1aklm80iy2ffrfrhD3+Y22+/HavVeqZFmZKzddxMTi1n6jwPDQ1x//3309LSwpIlS6a1zyc/+Um+/e1v8973vpcPfvCDyLJMZ2cnjz/+OG1tbVxwwQWT9rnvvvv4wx/+MK32GxoaeOCBByYt93q909r/bMW8i98Curu7efnll3n00Uf5+Mc/zkMPPcSXvvSlMy3WWYmqqui6jsViedNtybL8tnhR2Wy2My3ClEiShCRJZ1qMY3K2jpvJqeVcOc+jo6N85zvf4S/+4i/4wQ9+MGHdN77xDcbHxyfts2TJEh577DG2bNnC0qVLT3gMr9fLhz70oVMm89mCOfXzFvDQQw/h9/u57rrruOWWW3jooYcmbdPT04MgCHz1q1/l29/+Nm1tbTgcDq666ir6+/sxDIOvfOUrNDQ0YLfbee9730skEpnQxu9+9zuuu+466urqsFqttLe385WvfAVN08rbHPYrmOp3pEVCVVW+8pWv0N7ejtVqpaWlhc9//vPk8/kJx2xpaeH666/nxRdf5Pzzz8dms9HW1sZPf/rTE47LkX3+xje+UT7Wnj17KBQKfPGLX2TZsmV4vV6cTicXXXQRzz333IT9g8EgAPfff3+5H4fnq6fyUREEgb/5m7/ht7/9LQsWLMBqtTJ//nyeeOKJSfKtW7eO5cuXY7PZaG9v5/vf//6UbT799NNceOGF+Hw+XC4Xs2fP5vOf//xx+75gwQIuvfTSSct1Xae+vp5bbrllgsxHzsEnk0nuvfdeWlpasFqtVFVVceWVV7Jly5byNsfy47jkkksmnOfpjPOxONpH5fDYTPU7UhZd1/nGN77B/PnzsdlsVFdX8/GPf5xoNDqhfcMw+Md//EcaGhpwOBxceuml7N69+4RyHebocTssX1dXF3feeSc+nw+v18tdd91FJpOZdrtH8tWvfpV3vetdVFRUYLfbWbZsGY888siUspzK6+7wvfPjH//4hP3u7e3l7rvvZvbs2djtdioqKnj/+98/pW/Rjh07WLNmDXa7nYaGBv7xH/+RBx98cEpfpMcff5yLLroIp9OJ2+3muuuum9b5OZbv2FQ+T5s2beLqq6+msrISu91Oa2srf/7nf37c/p7Mec5ms3zyk5+ksrISt9vNDTfcwODg4An9XtatW8eKFSsAuOuuu8rX+VTn4zDd3d0YhsHq1asnrR
MEgaqqqknLP/GJT+D3+9/xPjjn/ufmOcBDDz3E+973PiwWCx/4wAf47ne/y8aNG8sX+tHbFgoFPvGJTxCJRPi3f/s3br31Vi677DLWrVvH3//939PV1cU3v/lNPv3pT/M///M/5X1//OMf43K5+Lu/+ztcLhfPPvssX/ziF0kkEvz7v/87ABdffDE/+9nPJhyzt7eXL3zhCxNulI9+9KP85Cc/4ZZbbuFTn/oUr776Kg888AB79+7lN7/5zYT9u7q6uOWWW/jIRz7CHXfcwf/8z/9w5513smzZMubPn3/C8XnwwQfJ5XJ87GMfw2q1EggESCQS/PCHP+QDH/gAf/EXf0EymeRHP/oRV199Na+99hpLliwhGAzy3e9+l7/6q7/ipptu4n3vex8AixYtOu7xXnzxRR599FHuvvtu3G43//Vf/8XNN99MX18fFRUVAGzdupV3v/vd1NbWcv/996NpGl/+8pfLitFhdu/ezfXXX8+iRYv48pe/jNVqpauri5deeum4Mtx2223cd999jIyMUFNTM0G2oaEhbr/99mPu+5d/+Zc88sgj/M3f/A3z5s0jHA7z4osvsnfv3ml9dR3JdMZ5urzvfe+bNKW5efNmvvGNb0y4tj7+8Y/z4x//mLvuuotPfvKTdHd3861vfYutW7fy0ksvoSgKAF/84hf5x3/8R6699lquvfZatmzZwlVXXUWhUDipPh7NrbfeSmtrKw888ABbtmzhhz/8IVVVVfzrv/7rSbf1n//5n9xwww188IMfpFAo8PDDD/P+97+fxx57jOuuu27CtqfyujsZNm7cyMsvv8ztt99OQ0MDPT09fPe73+WSSy5hz549OBwOAAYHB7n00ksRBIHPfe5zOJ1OfvjDH045tfezn/2MO+64g6uvvpp//dd/JZPJ8N3vfpcLL7yQrVu3nhKn47GxMa666iqCwSD/8A//gM/no6enh0cffXRa+0/nPN9555387//+Lx/+8Ie54IILWL9+/aTzNhVz587ly1/+Ml/84hf52Mc+xkUXXQTAu971rmPu09zcDMCvfvUr3v/+95fH/Xh4PB7+9m//li9+8YvTsqpomkYoFJq03G6343Q6T3i8sxbD5LSyadMmAzCefvppwzAMQ9d1o6GhwbjnnnsmbNfd3W0ARjAYNGKxWHn55z73OQMwFi9ebBSLxfLyD3zgA4bFYjFyuVx5WSaTmXT8j3/844bD4Ziw3ZFks1lj2bJlRl1dnTE8PGwYhmFs27bNAIyPfvSjE7b99Kc/bQDGs88+W17W3NxsAMbzzz9fXjY2NmZYrVbjU5/61HHH5nCfPR6PMTY2NmGdqqpGPp+fsCwajRrV1dXGn//5n5eXjY+PG4DxpS99aVL7X/rSl4yjL3HAsFgsRldXV3nZ9u3bDcD45je/WV72nve8x3A4HMbg4GB52f79+w1Zlie0+R//8R8GYIyPjx+3r0fT2dk56ZiGYRh333234XK5JpzLo/vn9XqNv/7rvz5u+83NzcYdd9wxafmaNWuMNWvWlP+e7jhPJceDDz5oAEZ3d/eUMoyPjxtNTU3GwoULjVQqZRiGYbzwwgsGYDz00EMTtn3iiScmLB8bGzMsFotx3XXXGbqul7f7/Oc/bwBT9u1ojpb38PVwdL9uuukmo6Ki4oTt3XHHHUZzc/OEZUffc4VCwViwYIFx2WWXTZLlVF53h++dBx988IT9nuq58MorrxiA8dOf/rS87BOf+IQhCIKxdevW8rJwOGwEAoEJ5zmZTBo+n8/4i7/4iwltjoyMGF6vd9Lyo5nqvjSMydfTb37zGwMwNm7ceNz23uh53rx5swEY995774Tt7rzzzmM+U45k48aNxzwHx+LP/uzPDMDw+/3GTTfdZHz1q1819u7dO2m75557zgCMX/3qV0YsFjP8fr9xww03lNffcccdhtPpnLDPmjVrDGDK38c//vFpy3g2Yk79nGYeeughqqury2Z+QRC47bbbePjhhydMyR
zm/e9//wTHp8NRQh/60Icm+FusXLmSQqHA4OBgeZndbi//P5lMEgqFuOiii8hkMnR0dEwp3913383OnTv59a9/Xf6
2 years ago
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"import numpy as np\n",
"\n",
"colors = [\"red\", \"darkorange\", \"gold\", \"turquoise\", \"darkgreen\"]\n",
"x = [x for x,y in vis_dims]\n",
"y = [y for x,y in vis_dims]\n",
"color_indices = df.Score.values - 1\n",
"\n",
"colormap = matplotlib.colors.ListedColormap(colors)\n",
"plt.scatter(x, y, c=color_indices, cmap=colormap, alpha=0.3)\n",
"for score in [0,1,2,3,4]:\n",
" avg_x = np.array(x)[df.Score-1==score].mean()\n",
" avg_y = np.array(y)[df.Score-1==score].mean()\n",
" color = colors[score]\n",
" plt.scatter(avg_x, avg_y, marker='x', color=color, s=100)\n",
"\n",
"plt.title(\"Amazon ratings visualized in language using t-SNE\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}