{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## K-means Clustering in Python using OpenAI\n",
"\n",
"We use a simple k-means algorithm to demonstrate how clustering can be done. Clustering can help discover valuable, hidden groupings within the data. The dataset is created in the [Get_embeddings_from_dataset Notebook](Get_embeddings_from_dataset.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 1536)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# imports\n",
"import numpy as np\n",
"import pandas as pd\n",
"from ast import literal_eval\n",
"\n",
"# load data\n",
"datafile_path = \"./data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"df[\"embedding\"] = df.embedding.apply(literal_eval).apply(np.array) # convert string to numpy array\n",
"matrix = np.vstack(df.embedding.values)\n",
"matrix.shape\n"
]
},
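{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check before clustering, sketched under the assumption that the embeddings are unit-normalized (OpenAI documents its embeddings as normalized to length 1). For unit vectors, the squared Euclidean distance equals `2 - 2 * cosine_similarity`, so pairwise Euclidean distances between reviews order them exactly as cosine similarity does. The cell below uses a small random matrix `demo` as a hypothetical stand-in so it runs on its own. To check the loaded data, run `np.allclose(np.linalg.norm(matrix, axis=1), 1.0)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical stand-in for the (1000, 1536) embedding matrix loaded above:\n",
"# random vectors rescaled to unit length.\n",
"rng = np.random.default_rng(0)\n",
"demo = rng.normal(size=(8, 1536))\n",
"demo /= np.linalg.norm(demo, axis=1, keepdims=True)\n",
"\n",
"# Check that every row has (approximately) unit L2 norm.\n",
"norms = np.linalg.norm(demo, axis=1)\n",
"print(np.allclose(norms, 1.0))\n",
"\n",
"# For unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b).\n",
"a, b = demo[0], demo[1]\n",
"sq_dist = np.sum((a - b) ** 2)\n",
"cos_sim = a @ b\n",
"print(np.isclose(sq_dist, 2 - 2 * cos_sim))"
]
},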
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Find the clusters using K-means"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We show the simplest use of K-means. You can pick the number of clusters that fits your use case best."
]
},
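{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learn's `KMeans` (used in the next cell) adds k-means++ seeding and multiple restarts on top of the basic algorithm. As a rough illustration of what fitting does, here is a minimal NumPy sketch of the core Lloyd iteration only, with plain random initialization. `lloyd_kmeans` is a hypothetical helper, not a scikit-learn function, and the toy points are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def lloyd_kmeans(X, k, n_iter=100, seed=0):\n",
"    # Minimal Lloyd's k-means sketch: random init, then alternate assign/update steps.\n",
"    rng = np.random.default_rng(seed)\n",
"    # Initialize centroids with k distinct rows of X.\n",
"    centroids = X[rng.choice(len(X), size=k, replace=False)]\n",
"    for _ in range(n_iter):\n",
"        # Assignment step: nearest centroid by squared Euclidean distance.\n",
"        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)\n",
"        labels = dists.argmin(axis=1)\n",
"        # Update step: move each centroid to the mean of its points;\n",
"        # keep the old centroid if a cluster ends up empty.\n",
"        new_centroids = np.array([\n",
"            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]\n",
"            for j in range(k)\n",
"        ])\n",
"        if np.allclose(new_centroids, centroids):\n",
"            break\n",
"        centroids = new_centroids\n",
"    return labels, centroids\n",
"\n",
"\n",
"# Two well-separated toy groups in 2D.\n",
"X_toy = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],\n",
"                  [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]])\n",
"toy_labels, toy_centroids = lloyd_kmeans(X_toy, k=2)\n",
"print(toy_labels)"
]
},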
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Cluster\n",
"0 4.105691\n",
"1 4.191176\n",
"2 4.215613\n",
"3 4.306590\n",
"Name: Score, dtype: float64"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.cluster import KMeans\n",
"\n",
"n_clusters = 4\n",
"\n",
"# n_init is set explicitly; 10 matches the previous scikit-learn default.\n",
"kmeans = KMeans(n_clusters=n_clusters, init=\"k-means++\", n_init=10, random_state=42)\n",
"kmeans.fit(matrix)\n",
"labels = kmeans.labels_\n",
"df[\"Cluster\"] = labels\n",
"\n",
"df.groupby(\"Cluster\").Score.mean().sort_values()\n"
]
},
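{
"cell_type": "markdown",
"metadata": {},
"source": [
"The number of clusters above is fixed at 4. One common heuristic for comparing choices (not used in this notebook) is the silhouette score. You fit K-means for several values of k and prefer values with higher scores. The sketch below runs on synthetic `make_blobs` data so it is quick and self-contained. The k range of 2 to 6 is an arbitrary assumption, and the highest silhouette is not necessarily the most useful grouping for your use case. To score the review embeddings instead, pass `matrix` in place of `X_demo`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import KMeans\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.metrics import silhouette_score\n",
"\n",
"# Toy data with 3 true groups; pass `matrix` instead to score the review embeddings.\n",
"X_demo, _ = make_blobs(n_samples=300, centers=3, cluster_std=0.5, random_state=42)\n",
"\n",
"scores = {}\n",
"for k in range(2, 7):\n",
"    km = KMeans(n_clusters=k, init=\"k-means++\", n_init=10, random_state=42)\n",
"    labels_k = km.fit_predict(X_demo)\n",
"    # Silhouette lies in [-1, 1]; higher means tighter, better-separated clusters.\n",
"    scores[k] = silhouette_score(X_demo, labels_k)\n",
"\n",
"best_k = max(scores, key=scores.get)\n",
"print(scores)\n",
"print(\"best k by silhouette:\", best_k)"
]
},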
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Clusters identified visualized in language 2d using t-SNE')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGzCAYAAAABsTylAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd5zcdZ3/n98yfWZnZ7bOlmST3VQCpNC7GghFPVCqegKW4yx4nO3kTj25Q1Hv/NlRUQ48PQ8FwVOkI0UkIAkJENJ3sy0726b3mW/5/fFhtm+ym0KW5PvkMQ+y3/nO9/v5fuc738/r+66SaZomFhYWFhYWFhZzEPlID8DCwsLCwsLCYjosoWJhYWFhYWExZ7GEioWFhYWFhcWcxRIqFhYWFhYWFnMWS6hYWFhYWFhYzFksoWJhYWFhYWExZ7GEioWFhYWFhcWcxRIqFhYWFhYWFnMWS6hYWFhYWFhYzFksoTILWlpauO666470MN5U7r77biRJorOzc7/rHunzI0kSX/nKV8Yte+mllzjjjDPweDxIksTmzZv5yle+giRJh3Tf5513Huedd94h295Ux3IkmDiO2VwPh5KZXluH+rxdd911tLS0HLLtWRw4h/o3NhOs739uYAkVoL29nRtuuIGFCxfidDqpqKjgzDPP5Lvf/S65XO5NGUM2m+UrX/kKTz/99Juyv7cqDz300IwnolKpxBVXXEE0GuXb3/42v/jFL5g/f/7hHaCFhcUktm/fzuc//3lWrlyJz+cjFApxySWXsGHDhiM9tDnL7bffzt133z2rz7z22mtcfvnlzJ8/H6fTSWNjI+effz7f//73x63X0tKCJEnceOONk7bx9NNPI0kS991338iy8gPKdK8XXnjhgI5xpqiHdetvAf74xz9yxRVX4HA4+OAHP8iKFSsoFos899xzfO5zn+P111/njjvuOOzjyGaz3HLLLQBv+lPDvvjbv/1brr76ahwOx5EeCiCEyg9/+MMpxUoul0NVRy/p9vZ2urq6+OlPf8pHPvKRkeVf/OIX+cIXvvBmDPeAmXgsc4W5dj1MZK6et2Odn/3sZ9x55528973v5eMf/ziJRIKf/OQnnHbaaTzyyCOsXbv2SA9xSn76059iGMYR2fftt99OdXX1jK3Uzz//PG9729uYN28eH/3oR6mvr6enp4cXXniB7373u1OKkp/+9KfcfPPNNDQ0zGgf//Zv/8aCBQsmLW9ra5vR5w+UY/oXvWfPHq6++mrmz5/Pn/70J0Kh0Mh7n/jEJ9i9ezd//OMfj+AID55MJoPH4zngzyuKgqIoh3BEhw+n0znu78HBQQAqKyvHLVdVdc5PZhOPZa4w16+HuXrejnWuueYavvKVr+D1ekeWfehDH2LZsmV85StfmbNCxWazHekhzJivfvWr+P1+XnrppUn3vPK9cCzHHXccO3bs4Otf/zrf+973ZrSPiy66iJNOOulQDHdWHNOun29+85uk02nuvPPOcSKlTFtbG//wD/8w7eeni3WYyo+/YcMG1q1bR3V1NS6XiwULFvChD30IgM7OTmpqagC45ZZbRsxpY60G27dv5/LLLycYDOJ0OjnppJP4/e9/P+V+n3nmGT7+8Y9TW1tLU1MTAKlUiptuuomWlhYcDge1tbWcf/75vPzyy/s8R1Mdi2ma3HrrrTQ1NeF2u3nb297G66+/PuXn4/E4N910E83NzTgcDtra2vjGN74x7imls7MTSZL4z//8T+644w5aW1txOBycfPLJvPTSSyPrXXfddfzwhz8EGGd2LDP2nF133XWce+65AFxxxRVIkjRiqZrue/vlL3/JmjVrcLlcBINBrr76anp6eiatVx6jy+XilFNO4c9//vM+z2GZFStW8La3vW3ScsMwaGxs5PLLL5/yWGBm3990cRwTffvFYpEvf/nLrFmzBr/fj8fj4eyzz+app57a7zFMvB7K53Kq19ixGIbBd77zHY477jicTi
d1dXXccMMNxGKxcdufzbU1FRPPW3l8u3fv5rrrrqOyshK/38/1119PNpud8XbH8p//+Z+cccYZVFVV4XK5WLNmzTgz+dixfPKTn+R3v/sdK1aswOFwcNxxx/HII49MWvfpp5/mpJNOwul00trayk9+8pNJ12n5dzKVO2DicXd1dfHxj3+cJUuW4HK5qKqq4oorrpgytujVV1/l3HPPxeVy0dTUxK233spdd901ZSzSww8/zNlnn43H48Hn83HJJZfM6PtZs2bNOJECUFVVxdlnn822bdsmrX+gv7HZnKOZ/KYmxqjM9F5V5t5772X58uU4nU5WrFjBAw88MKO4l5aWFl5//XWeeeaZkd/T/izt7e3tHHfccZNECkBtbe2U+/jgBz/IT3/6U/r6+va57SPN3H6sPMz84Q9/YOHChZxxxhmHdT+Dg4NccMEF1NTU8IUvfIHKyko6Ozu5//77AaipqeFHP/oRH/vYx7jssst4z3veA8AJJ5wAwOuvv86ZZ55JY2MjX/jCF/B4PPzmN7/h0ksv5be//S2XXXbZuP19/OMfp6amhi9/+ctkMhkA/v7v/5777ruPT37ykyxfvpxIJMJzzz3Htm3bWL169ayO58tf/jK33norF198MRdffDEvv/wyF1xwAcVicdx62WyWc889l71793LDDTcwb948nn/+eW6++WbC4TDf+c53xq3/q1/9ilQqxQ033IAkSXzzm9/kPe95Dx0dHdhsNm644Qb6+vp4/PHH+cUvfrHPMd5www00Njbyta99jU996lOcfPLJ1NXVTbv+V7/6Vb70pS9x5ZVX8pGPfIShoSG+//3vc84557Bp06aRH/+dd97JDTfcwBlnnMFNN91ER0cH7373uwkGgzQ3N+9zTFdddRVf+cpX6O/vp76+fmT5c889R19fH1dfffW0nz2U318ymeRnP/sZ11xzDR/96EdJpVLceeedrFu3jr/+9a+sXLlyxtt6z3veM8nsu3HjRr7zne+MuznecMMN3H333Vx//fV86lOfYs+ePfzgBz9g06ZN/OUvfxl5cp3ptTVbrrzyShYsWMBtt93Gyy+/zM9+9jNqa2v5xje+Mettffe73+Xd734373//+ykWi9xzzz1cccUVPPjgg1xyySXj1n3uuee4//77+fjHP47P5+N73/se733ve+nu7qaqqgqATZs2ceGFFxIKhbjlllvQdZ1/+7d/G3l4ORBeeuklnn/+ea6++mqampro7OzkRz/6Eeeddx5bt27F7XYDsHfvXt72trchSRI333wzHo+Hn/3sZ1O69n7xi19w7bXXsm7dOr7xjW+QzWb50Y9+xFlnncWmTZsOKOi0v7+f6urqccsO5jc2Gw7mN7W/exWIsIKrrrqK448/nttuu41YLMaHP/xhGhsb9zu273znO9x44414vV7+5V/+BWCf9y+A+fPns379erZs2cKKFStmdA7+5V/+hf/+7/+esVUlkUgwPDw8bpkkSSPX8mHDPEZJJBImYP7N3/zNjD8zf/5889prrx35+1//9V/NqU7hXXfdZQLmnj17TNM0zQceeMAEzJdeemnabQ8NDZmA+a//+q+T3nvHO95hHn/88WY+nx9ZZhiGecYZZ5iLFi2atN+zzjrL1DRt3Db8fr/5iU98YoZHOv2xDA4Omna73bzkkktMwzBG1vvnf/5nExh3fv793//d9Hg85s6dO8dt8wtf+IKpKIrZ3d1tmqZp7tmzxwTMqqoqMxqNjqz3f//3fyZg/uEPfxhZ9olPfGLKc26a5qTz99RTT5mAee+9945bb+L31tnZaSqKYn71q18dt95rr71mqqo6srxYLJq1tbXmypUrzUKhMLLeHXfcYQLmueeeO+W4yuzYscMEzO9///vjln/84x83vV6vmc1mpz2WmXx/E6/PMueee+64sWmaNm78pmmasVjMrKurMz/0oQ+NWz5xHBOvh4kMDQ2Z8+bNM48//ngznU6bpmmaf/7zn03A/J
//+Z9x6z7yyCPjls/m2pqOieMtf9cTj+uyyy4zq6qq9ru9a6+91pw/f/64ZWO/J9MU18WKFSvMt7/97ZPGYrfbzd2
2 years ago
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.manifold import TSNE\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"tsne = TSNE(n_components=2, perplexity=15, random_state=42, init=\"random\", learning_rate=200)\n",
"vis_dims2 = tsne.fit_transform(matrix)\n",
"\n",
"x = [x for x, y in vis_dims2]\n",
"y = [y for x, y in vis_dims2]\n",
"\n",
"for category, color in enumerate([\"purple\", \"green\", \"red\", \"blue\"]):\n",
" xs = np.array(x)[df.Cluster == category]\n",
" ys = np.array(y)[df.Cluster == category]\n",
" plt.scatter(xs, ys, color=color, alpha=0.3)\n",
"\n",
" avg_x = xs.mean()\n",
" avg_y = ys.mean()\n",
"\n",
" plt.scatter(avg_x, avg_y, marker=\"x\", color=color, s=100)\n",
"plt.title(\"Clusters identified visualized in language 2d using t-SNE\")\n"
]
},
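{
"cell_type": "markdown",
"metadata": {},
"source": [
"t-SNE preserves local neighborhoods but not global distances, so how far apart clusters appear in the 2D plot is only a rough guide. A simple check in the original embedding space is to compare the cluster centroids directly. `centroid_cosine_similarity` is a hypothetical helper sketched for illustration, and the toy points stand in for `matrix` and `df.Cluster`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def centroid_cosine_similarity(X, labels):\n",
"    # Cosine similarity between the mean vectors of every pair of clusters.\n",
"    ks = np.unique(labels)\n",
"    centroids = np.vstack([X[labels == k].mean(axis=0) for k in ks])\n",
"    centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)\n",
"    return centroids @ centroids.T\n",
"\n",
"\n",
"# Toy example; in this notebook you could call\n",
"# centroid_cosine_similarity(matrix, df.Cluster.values) instead.\n",
"X_pair = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])\n",
"labels_pair = np.array([0, 0, 1, 1])\n",
"sim = centroid_cosine_similarity(X_pair, labels_pair)\n",
"print(np.round(sim, 3))"
]
},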
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualization of clusters in a 2d projection. In this run, the green cluster (#1) seems quite different from the others. Let's see a few samples from each cluster."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Text samples in the clusters & naming the clusters\n",
"\n",
"Let's show random samples from each cluster. We'll use gpt-4 to name the clusters, based on a random sample of 5 reviews from that cluster."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cluster 0 Theme: The theme of these customer reviews is food products purchased on Amazon.\n",
"5, Loved these gluten free healthy bars, saved $$ ordering on Amazon: These Kind Bars are so good and healthy & gluten free. My daughter ca\n",
"1, Should advertise coconut as an ingredient more prominently: First, these should be called Mac - Coconut bars, as Coconut is the #2\n",
"5, very good!!: just like the runts<br />great flavor, def worth getting<br />I even o\n",
"5, Excellent product: After scouring every store in town for orange peels and not finding an\n",
"5, delicious: Gummi Frogs have been my favourite candy that I have ever tried. of co\n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 1 Theme: Pet food reviews\n",
"2, Messy and apparently undelicious: My cat is not a huge fan. Sure, she'll lap up the gravy, but leaves th\n",
"4, The cats like it: My 7 cats like this food but it is a little yucky for the human. Piece\n",
"5, cant get enough of it!!!: Our lil shih tzu puppy cannot get enough of it. Everytime she sees the\n",
"1, Food Caused Illness: I switched my cats over from the Blue Buffalo Wildnerness Food to this\n",
"5, My furbabies LOVE these!: Shake the container and they come running. Even my boy cat, who isn't \n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 2 Theme: All the reviews are about different types of coffee.\n",
"5, Fog Chaser Coffee: This coffee has a full body and a rich taste. The price is far below t\n",
"5, Excellent taste: This is to me a great coffee, once you try it you will enjoy it, this \n",
"4, Good, but not Wolfgang Puck good: Honestly, I have to admit that I expected a little better. That's not \n",
"5, Just My Kind of Coffee: Coffee Masters Hazelnut coffee used to be carried in a local coffee/pa\n",
"5, Rodeo Drive is Crazy Good Coffee!: Rodeo Drive is my absolute favorite and I'm ready to order more! That\n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 3 Theme: The theme of these customer reviews is food and drink products.\n",
"5, Wonderful alternative to soda pop: This is a wonderful alternative to soda pop. It's carbonated for thos\n",
"5, So convenient, for so little!: I needed two vanilla beans for the Love Goddess cake that my husbands \n",
"2, bot very cheesy: Got this about a month ago.first of all it smells horrible...it tastes\n",
"5, Delicious!: I am not a huge beer lover. I do enjoy an occasional Blue Moon (all o\n",
"3, Just ok: I bought this brand because it was all they had at Ranch 99 near us. I\n",
"----------------------------------------------------------------------------------------------------\n"
]
}
],
"source": [
"from openai import OpenAI\n",
"import os\n",
"\n",
"client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
"\n",
"# Reading a review which belong to each group.\n",
"rev_per_cluster = 5\n",
2 years ago
"\n",
"for i in range(n_clusters):\n",
" print(f\"Cluster {i} Theme:\", end=\" \")\n",
"\n",
" reviews = \"\\n\".join(\n",
" df[df.Cluster == i]\n",
" .combined.str.replace(\"Title: \", \"\")\n",
" .str.replace(\"\\n\\nContent: \", \": \")\n",
" .sample(rev_per_cluster, random_state=42)\n",
" .values\n",
" )\n",
"\n",
" messages = [\n",
" {\"role\": \"user\", \"content\": f'What do the following customer reviews have in common?\\n\\nCustomer reviews:\\n\"\"\"\\n{reviews}\\n\"\"\"\\n\\nTheme:'}\n",
" ]\n",
"\n",
" response = client.chat.completions.create(\n",
" model=\"gpt-4\",\n",
" messages=messages,\n",
" temperature=0,\n",
" max_tokens=64,\n",
" top_p=1,\n",
" frequency_penalty=0,\n",
" presence_penalty=0)\n",
" print(response.choices[0].message.content.replace(\"\\n\", \"\"))\n",
"\n",
" sample_cluster_rows = df[df.Cluster == i].sample(rev_per_cluster, random_state=42)\n",
" for j in range(rev_per_cluster):\n",
" print(sample_cluster_rows.Score.values[j], end=\", \")\n",
" print(sample_cluster_rows.Summary.values[j], end=\": \")\n",
" print(sample_cluster_rows.Text.str[:70].values[j])\n",
"\n",
" print(\"-\" * 100)\n"
]
},
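{
"cell_type": "markdown",
"metadata": {},
"source": [
"The prompt in the cell above can be factored into a small function so it can be inspected or tested without calling the API. `build_theme_messages` is a hypothetical helper that reproduces the same prompt text. It only builds the `messages` list and does not call `client.chat.completions.create`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def build_theme_messages(reviews):\n",
"    # Build the chat `messages` list used above from a list of review strings.\n",
"    joined = \"\\n\".join(reviews)\n",
"    content = (\n",
"        \"What do the following customer reviews have in common?\\n\\n\"\n",
"        f'Customer reviews:\\n\"\"\"\\n{joined}\\n\"\"\"\\n\\nTheme:'\n",
"    )\n",
"    return [{\"role\": \"user\", \"content\": content}]\n",
"\n",
"\n",
"# Two made-up reviews in the same \"title: content\" shape as the cell above.\n",
"demo_messages = build_theme_messages([\n",
"    \"Great coffee: Rich and smooth.\",\n",
"    \"Too bitter: Not my favorite roast.\",\n",
"])\n",
"print(demo_messages[0][\"content\"])"
]
},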
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"It's important to note that clusters will not necessarily match what you intend to use them for. A larger amount of clusters will focus on more specific patterns, whereas a small number of clusters will usually focus on largest discrepencies in the data."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}