{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification using embeddings\n",
"\n",
"There are many ways to classify text. This notebook shares an example of text classification using embeddings. For many text classification tasks, we've seen fine-tuned models do better than embeddings. See an example of fine-tuned models for classification in [Fine-tuned_classification.ipynb](Fine-tuned_classification.ipynb). We also recommend having more examples than embedding dimensions, which we don't quite achieve here.\n",
"\n",
"In this text classification task, we predict the score of a food review (1 to 5) based on the embedding of the review's text. We split the dataset into a training and a testing set for all the following tasks, so we can realistically evaluate performance on unseen data. The dataset is created in the [Get_embeddings_from_dataset Notebook](Get_embeddings_from_dataset.ipynb).\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"              precision    recall  f1-score   support\n",
"\n",
"           1       0.90      0.45      0.60        20\n",
"           2       1.00      0.38      0.55         8\n",
"           3       1.00      0.18      0.31        11\n",
"           4       0.88      0.26      0.40        27\n",
"           5       0.76      1.00      0.86       134\n",
"\n",
"    accuracy                           0.78       200\n",
"   macro avg       0.91      0.45      0.54       200\n",
"weighted avg       0.81      0.78      0.73       200\n",
"\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from ast import literal_eval\n",
"\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import classification_report\n",
"\n",
"datafile_path = \"data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"df[\"embedding\"] = df.embedding.apply(literal_eval).apply(np.array) # convert string to array\n",
"\n",
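"# split data into train (80%) and test (20%); random_state fixes the split so it is reproducible\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    list(df.embedding.values), df.Score, test_size=0.2, random_state=42\n",
")\n",
"\n",
"# train a random forest classifier on the embedding vectors\n",
"clf = RandomForestClassifier(n_estimators=100)\n",
"clf.fit(X_train, y_train)\n",
"preds = clf.predict(X_test)  # predicted star rating for each test review\n",
"probas = clf.predict_proba(X_test)  # per-class probabilities, used for precision-recall curves\n",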
"\n",
"report = classification_report(y_test, preds)\n",
"print(report)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
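"The model distinguishes between the categories reasonably well, reaching 78% accuracy on the 200 test reviews. 5-star reviews show the best performance overall (recall 1.00, F1-score 0.86). This is not too surprising, since they are the most common class and make up 134 of the 200 test reviews. For comparison, always predicting 5 stars would already give 67% accuracy. Recall for the other classes is much lower (0.18 to 0.45), so the model leans heavily toward the majority class."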
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RandomForestClassifier() - Average precision score over all classes: 0.90\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 900x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from utils.embeddings_utils import plot_multiclass_precision_recall\n",
"\n",
"plot_multiclass_precision_recall(probas, y_test, [1, 2, 3, 4, 5], clf)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Unsurprisingly, 5-star and 1-star reviews seem to be easier to predict. Perhaps with more data, the nuances between 2 and 4 stars could be predicted better, but there is probably also more subjectivity in how people use the in-between scores."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}