{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Clustering\n",
"\n",
"We use a simple k-means algorithm to demonstrate how clustering can be done. Clustering can help discover valuable, hidden groupings within the data. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 1536)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# imports\n",
"from ast import literal_eval\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# load data\n",
"datafile_path = \"./data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"# embeddings are stored as stringified lists; parse them safely (instead of eval) and convert to numpy arrays\n",
"df[\"embedding\"] = df.embedding.apply(literal_eval).apply(np.array)\n",
"matrix = np.vstack(df.embedding.values)\n",
"matrix.shape\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Find the clusters using K-means"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We show the simplest use of K-means. You can pick the number of clusters that fits your use case best."
]
},
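{
"cell_type": "markdown",
"metadata": {},
"source": [
"One rough way to compare candidate cluster counts is the silhouette score (`sklearn.metrics.silhouette_score`), which ranges from -1 to 1; higher values mean points are, on average, closer to their own cluster than to the nearest other cluster. The cell below is a minimal sketch that assumes checking 2 to 8 clusters is enough for this dataset. Treat the score as a heuristic: the right number still depends on what you want the clusters to capture."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import KMeans\n",
"from sklearn.metrics import silhouette_score\n",
"\n",
"# assumption: a candidate range of 2 to 8 clusters, chosen for illustration only\n",
"for k in range(2, 9):\n",
"    km = KMeans(n_clusters=k, init=\"k-means++\", n_init=10, random_state=42)\n",
"    cluster_labels = km.fit_predict(matrix)\n",
"    # mean silhouette over all points: closer to 1 means better-separated clusters\n",
"    print(k, round(silhouette_score(matrix, cluster_labels), 4))\n"
]
},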
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Cluster\n",
"0 4.105691\n",
"1 4.191176\n",
"2 4.215613\n",
"3 4.306590\n",
"Name: Score, dtype: float64"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.cluster import KMeans\n",
"\n",
"n_clusters = 4\n",
"\n",
"# set n_init explicitly; 10 matches the old default, so results don't change when the default becomes 'auto'\n",
"kmeans = KMeans(n_clusters=n_clusters, init=\"k-means++\", n_init=10, random_state=42)\n",
"kmeans.fit(matrix)\n",
"labels = kmeans.labels_\n",
"df[\"Cluster\"] = labels\n",
"\n",
"# average review score per cluster\n",
"df.groupby(\"Cluster\").Score.mean().sort_values()\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Clusters identified visualized in language 2d using t-SNE')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.manifold import TSNE\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# project the 1536-dimensional embeddings to 2D for plotting\n",
"tsne = TSNE(n_components=2, perplexity=15, random_state=42, init=\"random\", learning_rate=200)\n",
"vis_dims2 = tsne.fit_transform(matrix)\n",
"\n",
"x = vis_dims2[:, 0]\n",
"y = vis_dims2[:, 1]\n",
"\n",
"for category, color in enumerate([\"purple\", \"green\", \"red\", \"blue\"]):\n",
"    # plot the points assigned to this cluster\n",
"    xs = x[df.Cluster == category]\n",
"    ys = y[df.Cluster == category]\n",
"    plt.scatter(xs, ys, color=color, alpha=0.3)\n",
"\n",
"    # mark the mean position of the cluster's points in the 2D projection\n",
"    avg_x = xs.mean()\n",
"    avg_y = ys.mean()\n",
"\n",
"    plt.scatter(avg_x, avg_y, marker=\"x\", color=color, s=100)\n",
"plt.title(\"Clusters identified visualized in language 2d using t-SNE\")\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualization of the clusters in a 2D projection. In this run, the green cluster (#1) seems quite different from the others. Let's look at a few samples from each cluster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Text samples in the clusters & naming the clusters\n",
"\n",
"Let's show random samples from each cluster. We'll use text-davinci-003 to name each cluster based on a random sample of 5 of its reviews."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cluster 0 Theme: All of the reviews are positive and the customers are satisfied with the product they purchased.\n",
"5, Loved these gluten free healthy bars, saved $$ ordering on Amazon: These Kind Bars are so good and healthy & gluten free. My daughter ca\n",
"1, Should advertise coconut as an ingredient more prominently: First, these should be called Mac - Coconut bars, as Coconut is the #2\n",
"5, very good!!: just like the runts<br />great flavor, def worth getting<br />I even o\n",
"5, Excellent product: After scouring every store in town for orange peels and not finding an\n",
"5, delicious: Gummi Frogs have been my favourite candy that I have ever tried. of co\n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 1 Theme: All of the reviews are about pet food.\n",
"2, Messy and apparently undelicious: My cat is not a huge fan. Sure, she'll lap up the gravy, but leaves th\n",
"4, The cats like it: My 7 cats like this food but it is a little yucky for the human. Piece\n",
"5, cant get enough of it!!!: Our lil shih tzu puppy cannot get enough of it. Everytime she sees the\n",
"1, Food Caused Illness: I switched my cats over from the Blue Buffalo Wildnerness Food to this\n",
"5, My furbabies LOVE these!: Shake the container and they come running. Even my boy cat, who isn't \n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 2 Theme: All of the reviews are positive and express satisfaction with the product.\n",
"5, Fog Chaser Coffee: This coffee has a full body and a rich taste. The price is far below t\n",
"5, Excellent taste: This is to me a great coffee, once you try it you will enjoy it, this \n",
"4, Good, but not Wolfgang Puck good: Honestly, I have to admit that I expected a little better. That's not \n",
"5, Just My Kind of Coffee: Coffee Masters Hazelnut coffee used to be carried in a local coffee/pa\n",
"5, Rodeo Drive is Crazy Good Coffee!: Rodeo Drive is my absolute favorite and I'm ready to order more! That\n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 3 Theme: All of the reviews are about food or drink products.\n",
"5, Wonderful alternative to soda pop: This is a wonderful alternative to soda pop. It's carbonated for thos\n",
"5, So convenient, for so little!: I needed two vanilla beans for the Love Goddess cake that my husbands \n",
"2, bot very cheesy: Got this about a month ago.first of all it smells horrible...it tastes\n",
"5, Delicious!: I am not a huge beer lover. I do enjoy an occasional Blue Moon (all o\n",
"3, Just ok: I bought this brand because it was all they had at Ranch 99 near us. I\n",
"----------------------------------------------------------------------------------------------------\n"
]
}
],
"source": [
"import openai\n",
"\n",
"# Read a random sample of reviews from each cluster and ask the model what they have in common.\n",
"rev_per_cluster = 5\n",
"\n",
"for i in range(n_clusters):\n",
"    print(f\"Cluster {i} Theme:\", end=\" \")\n",
"\n",
"    reviews = \"\\n\".join(\n",
"        df[df.Cluster == i]\n",
"        .combined.str.replace(\"Title: \", \"\")\n",
"        .str.replace(\"\\n\\nContent: \", \": \")\n",
"        .sample(rev_per_cluster, random_state=42)\n",
"        .values\n",
"    )\n",
"    response = openai.Completion.create(\n",
"        engine=\"text-davinci-003\",\n",
"        prompt=f'What do the following customer reviews have in common?\\n\\nCustomer reviews:\\n\"\"\"\\n{reviews}\\n\"\"\"\\n\\nTheme:',\n",
"        temperature=0,\n",
"        max_tokens=64,\n",
"        top_p=1,\n",
"        frequency_penalty=0,\n",
"        presence_penalty=0,\n",
"    )\n",
"    print(response[\"choices\"][0][\"text\"].replace(\"\\n\", \"\"))\n",
"\n",
"    # print the score, summary, and first 70 characters of each sampled review\n",
"    sample_cluster_rows = df[df.Cluster == i].sample(rev_per_cluster, random_state=42)\n",
"    for j in range(rev_per_cluster):\n",
"        print(sample_cluster_rows.Score.values[j], end=\", \")\n",
"        print(sample_cluster_rows.Summary.values[j], end=\": \")\n",
"        print(sample_cluster_rows.Text.str[:70].values[j])\n",
"\n",
"    print(\"-\" * 100)\n"
]
},
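{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above uses the pre-1.0 `openai` Python interface (`openai.Completion.create`), which was removed in version 1.0 of the library, and `text-davinci-003` has since been retired. The cell below is a minimal sketch of the same naming step with the `openai>=1.0` chat completions client. The model name `gpt-4o-mini` is an assumption; substitute any chat model you have access to. Its output may differ from the themes shown above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"\n",
"client = OpenAI()  # reads the OPENAI_API_KEY environment variable\n",
"\n",
"for i in range(n_clusters):\n",
"    # same sampling and formatting as the cell above\n",
"    reviews = \"\\n\".join(\n",
"        df[df.Cluster == i]\n",
"        .combined.str.replace(\"Title: \", \"\")\n",
"        .str.replace(\"\\n\\nContent: \", \": \")\n",
"        .sample(rev_per_cluster, random_state=42)\n",
"        .values\n",
"    )\n",
"    response = client.chat.completions.create(\n",
"        model=\"gpt-4o-mini\",  # assumption: replace with a chat model available to you\n",
"        messages=[\n",
"            {\n",
"                \"role\": \"user\",\n",
"                \"content\": f'What do the following customer reviews have in common?\\n\\nCustomer reviews:\\n\"\"\"\\n{reviews}\\n\"\"\"\\n\\nTheme:',\n",
"            }\n",
"        ],\n",
"        temperature=0,\n",
"        max_tokens=64,\n",
"    )\n",
"    print(f\"Cluster {i} Theme:\", response.choices[0].message.content.strip())\n"
]
},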
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's important to note that clusters will not necessarily match what you intend to use them for. A larger number of clusters will focus on more specific patterns, whereas a small number of clusters will usually focus on the largest discrepancies in the data."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}