{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification using embeddings\n",
"\n",
"There are many ways to classify text. This notebook walks through one of them: text classification using embeddings. For many text classification tasks, we've seen fine-tuned models outperform embeddings; see [Fine-tuned_classification.ipynb](Fine-tuned_classification.ipynb) for an example. We also recommend having more examples than embedding dimensions, which we don't quite achieve here.\n",
"\n",
"In this text classification task, we predict the score of a food review (1 to 5) from the embedding of the review's text. We split the dataset into a training set and a test set for all the following tasks, so that we can realistically evaluate performance on unseen data. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb).\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"              precision    recall  f1-score   support\n",
"\n",
"           1       0.88      0.35      0.50        20\n",
"           2       1.00      0.38      0.55         8\n",
"           3       1.00      0.18      0.31        11\n",
"           4       1.00      0.26      0.41        27\n",
"           5       0.74      1.00      0.85       134\n",
"\n",
"    accuracy                           0.77       200\n",
"   macro avg       0.92      0.43      0.52       200\n",
"weighted avg       0.82      0.77      0.72       200\n",
"\n"
]
}
],
"source": [
"# imports\n",
"from ast import literal_eval\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import classification_report, accuracy_score\n",
"\n",
"# load data\n",
"datafile_path = \"data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"# parse each stored embedding string into a numpy array (literal_eval avoids running arbitrary code)\n",
"df[\"embedding\"] = df.embedding.apply(literal_eval).apply(np.array)\n",
"\n",
"# split data into train and test\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    list(df.embedding.values), df.Score, test_size=0.2, random_state=42\n",
")\n",
"\n",
"# train random forest classifier\n",
"clf = RandomForestClassifier(n_estimators=100)\n",
"clf.fit(X_train, y_train)\n",
"preds = clf.predict(X_test)\n",
"probas = clf.predict_proba(X_test)\n",
"\n",
"report = classification_report(y_test, preds)\n",
"print(report)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
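"We can see that the model has learnt to distinguish between the categories to a degree. Precision is high for every class, but recall for 1- to 4-star reviews is low (0.18 to 0.38). 5-star reviews show the best performance overall (F1 of 0.85), which is not too surprising, since they are the most common in the dataset: 134 of the 200 test reviews."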
]
},
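{
"cell_type": "markdown",
"metadata": {},
"source": [
"The class imbalance can be checked directly. The cell below is a minimal sketch that assumes `y_test` and `preds` from the cell above are still in scope; it prints the share of each score in the test set and a confusion matrix, which shows where the misclassified reviews end up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# share of each star rating in the held-out test set\n",
"print(y_test.value_counts(normalize=True).sort_index())\n",
"\n",
"# rows are true scores, columns are predicted scores (1 to 5)\n",
"print(confusion_matrix(y_test, preds, labels=[1, 2, 3, 4, 5]))\n"
]
},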
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RandomForestClassifier() - Average precision score over all classes: 0.87\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 648x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from openai.embeddings_utils import plot_multiclass_precision_recall\n",
"\n",
"plot_multiclass_precision_recall(probas, y_test, [1, 2, 3, 4, 5], clf)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unsurprisingly, 5-star and 1-star reviews seem to be easier to predict. Perhaps with more data, the nuances between 2-, 3-, and 4-star reviews could be better predicted, but there is also probably more subjectivity in how people use the in-between scores."
]
}
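,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Short of collecting more data, one inexpensive variation to try is reweighting the classes so that the rarer scores count for more during training. The cell below is a sketch under assumptions: it reuses `X_train`, `X_test`, `y_train`, and `y_test` from the first cell, and the name `clf_balanced` and the choice of scikit-learn's `class_weight` option are illustrative. Whether this improves recall on the in-between scores is something to verify on the data, not a result established above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative variation: weight classes inversely to their frequency in y_train\n",
"clf_balanced = RandomForestClassifier(\n",
"    n_estimators=100, class_weight=\"balanced\", random_state=42\n",
")\n",
"clf_balanced.fit(X_train, y_train)\n",
"\n",
"# compare this report against the unweighted one from the first cell\n",
"print(classification_report(y_test, clf_balanced.predict(X_test)))\n"
]
}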
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}