openai-cookbook/examples/Classification_using_embeddings.ipynb

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification using embeddings\n",
"\n",
"There are many ways to classify text. This notebook shows one of them: classifying text using embeddings. For many text classification tasks, we've seen fine-tuned models outperform classifiers trained on embeddings; see [Fine-tuned_classification.ipynb](Fine-tuned_classification.ipynb) for an example of classification with a fine-tuned model. We also recommend having more training examples than embedding dimensions, which we don't quite achieve here.\n",
"\n",
"In this text classification task, we predict the score of a food review (1 to 5 stars) from the embedding of the review's text. We hold out 20% of the reviews as a test set, so that all the evaluations below measure performance on data the model has not seen. The dataset is created in the [Get_embeddings_from_dataset notebook](Get_embeddings_from_dataset.ipynb).\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"              precision    recall  f1-score   support\n",
"\n",
"           1       0.89      0.40      0.55        20\n",
"           2       1.00      0.38      0.55         8\n",
"           3       1.00      0.18      0.31        11\n",
"           4       1.00      0.26      0.41        27\n",
"           5       0.75      1.00      0.86       134\n",
"\n",
"    accuracy                           0.77       200\n",
"   macro avg       0.93      0.44      0.53       200\n",
"weighted avg       0.82      0.77      0.72       200\n",
"\n"
]
}
],
"source": [
"# imports\n",
"import pandas as pd\n",
"import numpy as np\n",
"from ast import literal_eval\n",
"\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import classification_report, accuracy_score\n",
"\n",
"# load data\n",
"datafile_path = \"data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"df[\"embedding\"] = df.embedding.apply(literal_eval).apply(np.array) # convert string to array\n",
"\n",
"# split data into train and test\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    list(df.embedding.values), df.Score, test_size=0.2, random_state=42\n",
")\n",
"\n",
"# train random forest classifier\n",
"clf = RandomForestClassifier(n_estimators=100)\n",
"clf.fit(X_train, y_train)\n",
"preds = clf.predict(X_test)\n",
"probas = clf.predict_proba(X_test)\n",
"\n",
"report = classification_report(y_test, preds)\n",
"print(report)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
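"The model has learnt to distinguish between the categories reasonably well, reaching an accuracy of 0.77. 5-star reviews show the best performance overall (recall 1.00, F1-score 0.86), which is not too surprising, since they are by far the most common class (134 of the 200 test reviews). For the rarer 1- to 4-star classes the picture is different: precision is high but recall is low, so when the model predicts one of these scores it is usually right, but it misses most of the reviews that actually belong to them."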
]
},
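{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the counts behind the report can be recovered from the printed numbers alone: since recall is the number of correct predictions divided by the support, recall × support gives the number of correctly classified reviews per class. The next cell is a minimal sketch that uses only values copied by hand from the report above (it does not use `clf` or `y_test`), and it also computes the accuracy of a hypothetical baseline that always predicts 5 stars. With these numbers the classifier gets 154 of the 200 test reviews right (0.77), compared with 134 of 200 (0.67) for the baseline, so it does better than simply exploiting the class imbalance, even though recall on the rarer classes stays low."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Values copied by hand from the classification report above (rounded to 2 decimals)\n",
"support = {1: 20, 2: 8, 3: 11, 4: 27, 5: 134}\n",
"recall = {1: 0.40, 2: 0.38, 3: 0.18, 4: 0.26, 5: 1.00}\n",
"\n",
"# recall = correct / support, so correct is approximately recall * support,\n",
"# rounded back to a whole number of reviews\n",
"correct = {star: round(recall[star] * support[star]) for star in support}\n",
"n_test = sum(support.values())\n",
"accuracy = sum(correct.values()) / n_test\n",
"\n",
"# Hypothetical majority-class baseline: always predict the most common score (5 stars)\n",
"baseline_accuracy = max(support.values()) / n_test\n",
"\n",
"print(f\"correct predictions per class: {correct}\")\n",
"print(f\"accuracy implied by the report: {accuracy:.2f}\")\n",
"print(f\"always-predict-5-stars baseline: {baseline_accuracy:.2f}\")\n"
]
},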
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RandomForestClassifier() - Average precision score over all classes: 0.87\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 900x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from utils.embeddings_utils import plot_multiclass_precision_recall\n",
"\n",
"plot_multiclass_precision_recall(probas, y_test, [1, 2, 3, 4, 5], clf)\n"
]
},
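{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The single number printed above (0.87) summarises the precision-recall curves in one score. The `utils.embeddings_utils` helper is not shown in this notebook, so the next cell is only a sketch, assuming it follows the usual one-vs-rest approach: each star rating is treated as its own binary problem, scored by the matching column of `probas`, and the micro-average pools all (label, score) pairs before computing a single average precision (AP). The `average_precision` function is a hypothetical pure-Python stand-in for scikit-learn's `average_precision_score`, using the same non-interpolated definition $\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n$, where $P_n$ and $R_n$ are the precision and recall at the $n$-th distinct score threshold. The labels and probabilities are toy values made up for illustration, not results from this dataset. Running the same per-class loop over `probas` and `y_test` should give one AP per star rating, which makes it easier to see which scores the classifier struggles to rank."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def average_precision(y_true, scores):\n",
"    # Non-interpolated AP: sum over distinct score thresholds of (R_n - R_{n-1}) * P_n\n",
"    n_pos = sum(y_true)\n",
"    if n_pos == 0:\n",
"        raise ValueError(\"average precision needs at least one positive example\")\n",
"    # Rank examples from highest to lowest score\n",
"    ranked = sorted(zip(scores, y_true), key=lambda pair: pair[0], reverse=True)\n",
"    tp = fp = 0\n",
"    prev_recall = ap = 0.0\n",
"    i = 0\n",
"    while i < len(ranked):\n",
"        threshold = ranked[i][0]\n",
"        # Examples with tied scores share one threshold, so count them together\n",
"        while i < len(ranked) and ranked[i][0] == threshold:\n",
"            if ranked[i][1]:\n",
"                tp += 1\n",
"            else:\n",
"                fp += 1\n",
"            i += 1\n",
"        recall = tp / n_pos\n",
"        precision = tp / (tp + fp)\n",
"        ap += (recall - prev_recall) * precision\n",
"        prev_recall = recall\n",
"    return ap\n",
"\n",
"\n",
"# Toy data (made up for illustration): true star ratings of 6 reviews and the\n",
"# predicted probability of each class, one column per entry in `classes`\n",
"classes = [1, 3, 5]\n",
"y_toy = [5, 1, 5, 3, 5, 1]\n",
"probas_toy = [\n",
"    [0.1, 0.2, 0.7],\n",
"    [0.6, 0.3, 0.1],\n",
"    [0.2, 0.3, 0.5],\n",
"    [0.3, 0.2, 0.5],\n",
"    [0.1, 0.1, 0.8],\n",
"    [0.2, 0.5, 0.3],\n",
"]\n",
"\n",
"# One-vs-rest: one binary problem per class (\"is this review k stars?\"),\n",
"# scored by that class's probability column\n",
"per_class_ap = {}\n",
"for j, star in enumerate(classes):\n",
"    labels = [int(y == star) for y in y_toy]\n",
"    scores = [row[j] for row in probas_toy]\n",
"    per_class_ap[star] = average_precision(labels, scores)\n",
"\n",
"# Micro-average: pool every (label, score) pair across all classes, then compute one AP\n",
"pooled_labels = [int(y == star) for y in y_toy for star in classes]\n",
"pooled_scores = [p for row in probas_toy for p in row]\n",
"micro_ap = average_precision(pooled_labels, pooled_scores)\n",
"\n",
"print({star: round(ap, 2) for star, ap in per_class_ap.items()})\n",
"print(f\"micro-averaged AP: {micro_ap:.2f}\")\n"
]
},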
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Unsurprisingly, 5-star and 1-star reviews seem to be the easiest to predict. Perhaps with more data the nuances between 2 and 4 stars could be predicted better, but there is probably also more subjectivity in how people use the in-between scores."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}