{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing the embeddings in 2D\n",
"\n",
"We use t-SNE to reduce the embeddings from 1536 dimensions to 2, which lets us plot them as a 2D scatter plot. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Reduce dimensionality\n",
"\n",
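"We reduce the embeddings to 2 dimensions using t-SNE (t-distributed Stochastic Neighbor Embedding), a non-linear technique that tries to keep points that are close in the original space close together in 2D."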
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 2)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from ast import literal_eval\n",
"from sklearn.manifold import TSNE\n",
"\n",
"# Load the reviews and their embeddings\n",
"datafile_path = \"data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"df = pd.read_csv(datafile_path)\n",
"\n",
"# The CSV stores each embedding as a string: parse it safely with literal_eval\n",
"# and stack the resulting lists into a (n_reviews, 1536) NumPy array\n",
"matrix = np.array(df.embedding.apply(literal_eval).to_list())\n",
"\n",
"# Create a t-SNE model and transform the data\n",
"tsne = TSNE(n_components=2, perplexity=15, random_state=42, init='random', learning_rate=200)\n",
"vis_dims = tsne.fit_transform(matrix)\n",
"vis_dims.shape"
]
},
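{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough check of how much local structure t-SNE keeps, here is a minimal sketch on synthetic data. It assumes a `make_blobs` dataset as a stand-in for the review embeddings and scores the projection with scikit-learn's `trustworthiness`, which lies in [0, 1]: values near 1 mean that points which are neighbours in 2D were mostly neighbours in the original space too. To apply it to the reviews, one could pass `matrix` and `vis_dims` instead of the synthetic arrays."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_blobs\n",
"from sklearn.manifold import TSNE, trustworthiness\n",
"\n",
"# Assumption: synthetic stand-in for the review embeddings (300 points, 50 dimensions, 5 clusters)\n",
"X_demo, _ = make_blobs(n_samples=300, n_features=50, centers=5, random_state=0)\n",
"\n",
"# Same t-SNE settings as used for the reviews above\n",
"tsne_demo = TSNE(n_components=2, perplexity=15, random_state=42, init='random', learning_rate=200)\n",
"X_demo_2d = tsne_demo.fit_transform(X_demo)\n",
"\n",
"# Score in [0, 1]; near 1 means few points became 2D neighbours without being neighbours originally\n",
"trust = trustworthiness(X_demo, X_demo_2d, n_neighbors=5)\n",
"print(X_demo_2d.shape, round(trust, 3))"
]
},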
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Plotting the embeddings\n",
"\n",
"We colour each review by its star rating, ranging from red to green."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Even in just two dimensions, the reviews show a reasonable degree of separation by star rating."
]
},
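{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put a number on this separation, one option is the silhouette score of the 2D points grouped by star rating. It ranges from -1 to 1, and higher values mean points lie closer to others with the same rating than to points with other ratings. The helper `rating_separation` below is hypothetical and is demonstrated on synthetic points; in this notebook it could be called as `rating_separation(vis_dims, df.Score.values)`. Because t-SNE does not preserve distances between clusters, treat the 2D score as a rough indicator only; passing `matrix` instead of `vis_dims` measures separation in the original embedding space."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import silhouette_score\n",
"\n",
"def rating_separation(coords, ratings):\n",
"    # Hypothetical helper: silhouette score of points grouped by rating, in [-1, 1]\n",
"    return silhouette_score(np.asarray(coords), np.asarray(ratings))\n",
"\n",
"# Synthetic demo: two well-separated groups of 2D points labelled with ratings 1 and 5\n",
"rng = np.random.default_rng(0)\n",
"demo_coords = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(10, 1, (50, 2))])\n",
"demo_ratings = np.array([1] * 50 + [5] * 50)\n",
"print(round(rating_separation(demo_coords, demo_ratings), 3))"
]
},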
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Amazon ratings visualized in language using t-SNE')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from matplotlib.colors import ListedColormap\n",
"\n",
"# One colour per star rating, from 1 star (red) to 5 stars (dark green)\n",
"colors = [\"red\", \"darkorange\", \"gold\", \"turquoise\", \"darkgreen\"]\n",
"x = vis_dims[:, 0]\n",
"y = vis_dims[:, 1]\n",
"color_indices = df.Score.values - 1  # map ratings 1-5 to colour indices 0-4\n",
"\n",
"colormap = ListedColormap(colors)\n",
"plt.scatter(x, y, c=color_indices, cmap=colormap, alpha=0.3)\n",
"\n",
"# Mark the mean 2D position of each rating with a large cross\n",
"for score in range(5):\n",
"    mask = color_indices == score\n",
"    plt.scatter(x[mask].mean(), y[mask].mean(), marker='x', color=colors[score], s=100)\n",
"\n",
"plt.title(\"Amazon ratings visualized in language using t-SNE\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}