{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification using the embeddings\n",
"\n",
"In a classification task, we predict one of a set of predefined categories for a given input. Here we predict the review score from the embedding of the review's text; a prediction counts as correct only if it matches the exact number of stars. We split the dataset into a training set and a test set so that we can realistically evaluate performance on unseen data. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb).\n",
"\n",
"In the following example we're predicting the number of stars in a review, from 1 to 5."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 1 0.67 0.30 0.41 20\n",
" 2 1.00 0.38 0.55 8\n",
" 3 1.00 0.18 0.31 11\n",
" 4 1.00 0.26 0.41 27\n",
" 5 0.75 1.00 0.86 134\n",
"\n",
" accuracy 0.76 200\n",
" macro avg 0.88 0.42 0.51 200\n",
"weighted avg 0.80 0.76 0.71 200\n",
"\n"
]
}
],
"source": [
"from ast import literal_eval\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import classification_report, accuracy_score\n",
"\n",
"# If you have not run the \"Obtain_dataset.ipynb\" notebook, you can download the datafile from here: https://cdn.openai.com/API/examples/data/fine_food_reviews_with_embeddings_1k.csv\n",
"datafile_path = \"./data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"# Each embedding is stored as a stringified list; parse it without executing arbitrary code\n",
"df[\"ada_similarity\"] = df.ada_similarity.apply(literal_eval).apply(np.array)\n",
"\n",
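"# Hold out 20% of the reviews as a test set\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    list(df.ada_similarity.values), df.Score, test_size=0.2, random_state=42\n",
")\n",
"\n",
"# Train a random forest on the embedding vectors\n",
"clf = RandomForestClassifier(n_estimators=100)\n",
"clf.fit(X_train, y_train)\n",
"preds = clf.predict(X_test)\n",
"probas = clf.predict_proba(X_test)  # class probabilities, used for the precision-recall plot\n",
"\n",
"# Per-class precision, recall and F1 on the test set\n",
"report = classification_report(y_test, preds)\n",
"print(report)\n"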
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
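"In the run shown above, the model reaches 76% accuracy overall, but performance is uneven across classes. 5-star reviews are predicted best (F1 of 0.86), which is not too surprising, since they are by far the most common class in the dataset (134 of the 200 test reviews). For 1- to 4-star reviews precision is high but recall is low (0.18 to 0.38), so the model misses most of those reviews."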
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RandomForestClassifier() - Average precision score over all classes: 0.87\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 900x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from openai.embeddings_utils import plot_multiclass_precision_recall\n",
"\n",
"plot_multiclass_precision_recall(probas, y_test, [1, 2, 3, 4, 5], clf)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unsurprisingly, 5-star and 1-star reviews seem to be easier to predict. Perhaps with more data the nuances between 2-4 stars could be predicted better, but there is probably also more subjectivity in how people use the in-between scores."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai-cookbook",
"language": "python",
"name": "openai-cookbook"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}