{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing the embeddings in 2D\n",
"\n",
"We use t-SNE to reduce the embeddings from 2048 dimensions to 2, so that they can be shown in a 2D scatter plot. The dataset is created in the [Obtain_dataset notebook](Obtain_dataset.ipynb)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Reduce dimensionality\n",
"\n",
"We reduce the embeddings to 2 dimensions using t-SNE, a nonlinear manifold-learning method."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 2)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from ast import literal_eval\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.manifold import TSNE\n",
"\n",
"# Load the embeddings\n",
"df = pd.read_csv('output/embedded_1k_reviews.csv')\n",
"\n",
"# Parse each stored embedding string into a list of floats (literal_eval avoids\n",
"# running arbitrary code, unlike eval) and stack the rows into a 2D array\n",
"matrix = np.array(df.babbage_similarity.apply(literal_eval).to_list())\n",
"\n",
"# Create a t-SNE model and project the embeddings to 2 dimensions\n",
"tsne = TSNE(n_components=2, perplexity=15, random_state=42, init='random', learning_rate=200)\n",
"vis_dims = tsne.fit_transform(matrix)\n",
"vis_dims.shape"
]
},
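{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the call above, here is a minimal sketch that runs the same t-SNE settings on synthetic data. The `toy_matrix` array is a hypothetical stand-in (two Gaussian blobs in 64 dimensions), not the review embeddings; it only illustrates that `fit_transform` returns one 2D point per input row."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.manifold import TSNE\n",
"\n",
"# Hypothetical toy data: two well-separated Gaussian blobs, 50 rows each, 64 dimensions\n",
"rng = np.random.default_rng(0)\n",
"toy_matrix = np.vstack([\n",
"    rng.normal(loc=0.0, scale=1.0, size=(50, 64)),\n",
"    rng.normal(loc=5.0, scale=1.0, size=(50, 64)),\n",
"])\n",
"\n",
"# Same settings as in the cell above; perplexity must stay below the number of rows\n",
"toy_tsne = TSNE(n_components=2, perplexity=15, random_state=42, init='random', learning_rate=200)\n",
"toy_dims = toy_tsne.fit_transform(toy_matrix)\n",
"toy_dims.shape"
]
},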
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Plotting the embeddings\n",
"\n",
"We colour each review by its star rating, from red (1 star) to dark green (5 stars)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even after reduction to 2 dimensions, the ratings separate reasonably well, and there appears to be a cluster of mostly negative reviews."
]
},
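{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to check such a separation numerically is to compare the mean 2D position of each rating. The sketch below does this with a pandas `groupby` on a small hypothetical table (`toy_points` and its values are made up for illustration, not taken from the dataset)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical 2D coordinates with a star rating attached to each point\n",
"toy_points = pd.DataFrame({\n",
"    \"x\": [-2.0, -1.5, 0.1, 1.8, 2.2],\n",
"    \"y\": [0.5, 0.3, 0.0, -0.4, -0.6],\n",
"    \"Score\": [1, 1, 3, 5, 5],\n",
"})\n",
"\n",
"# Mean position (centroid) of the points for each rating\n",
"toy_points.groupby(\"Score\")[[\"x\", \"y\"]].mean()"
]
},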
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Amazon ratings visualized in language using t-SNE')"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# One colour per star rating, from 1 star (red) to 5 stars (dark green)\n",
"colors = [\"red\", \"darkorange\", \"gold\", \"turquoise\", \"darkgreen\"]\n",
"x = [x for x,y in vis_dims]\n",
"y = [y for x,y in vis_dims]\n",
"# Scores run from 1 to 5, so subtract 1 to get colormap indices 0 to 4\n",
"color_indices = df.Score.values - 1\n",
"\n",
"colormap = matplotlib.colors.ListedColormap(colors)\n",
"plt.scatter(x, y, c=color_indices, cmap=colormap, alpha=0.3)\n",
"# Mark the mean position of each rating with an 'x' in the matching colour\n",
"for score in [0,1,2,3,4]:\n",
"    avg_x = np.array(x)[df.Score-1==score].mean()\n",
"    avg_y = np.array(y)[df.Score-1==score].mean()\n",
"    color = colors[score]\n",
"    plt.scatter(avg_x, avg_y, marker='x', color=color, s=100)\n",
"plt.title(\"Amazon ratings visualized in language using t-SNE\")"
]
}
],
"metadata": {
"interpreter": {
"hash": "be4b5d5b73a21c599de40d6deb1129796d12dc1cc33a738f7bac13269cfcafe8"
},
"kernelspec": {
"display_name": "Python 3.7.3 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}