{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification using the embeddings\n",
"\n",
"In a classification task, we predict one of a set of predefined categories for a given input. Here we predict the review score from the embedding of the review's text; a prediction counts as correct only if it matches the exact number of stars. We split the dataset into a training set and a test set so that we can realistically evaluate performance on unseen data. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb).\n",
"\n",
"In the following example we're predicting the number of stars in a review, from 1 to 5."
]
},
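{
"cell_type": "markdown",
"metadata": {},
"source": [
"The precomputed embeddings used below are stored in a CSV file, where each vector is serialized as text and must be parsed back into numbers before training. The next cell is a minimal sketch of that parsing step on a made-up two-row frame; the `toy` frame, its `embedding` column, and its values are hypothetical and only stand in for the real data. It uses `ast.literal_eval`, which parses Python literals without executing arbitrary code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ast\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical toy frame: each embedding is serialized as a string, as it would be in a CSV.\n",
"toy = pd.DataFrame({\"embedding\": [\"[0.1, 0.2, 0.3]\", \"[0.4, 0.5, 0.6]\"], \"Score\": [5, 1]})\n",
"\n",
"# Parse each string into a list of floats, then convert it to a NumPy array.\n",
"toy[\"embedding\"] = toy.embedding.apply(ast.literal_eval).apply(np.array)\n",
"\n",
"# Stack the per-row vectors into an (n_samples, n_features) matrix for scikit-learn.\n",
"X = np.vstack(list(toy.embedding))\n",
"print(X.shape)  # (2, 3)\n"
]
},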
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 1 0.52 0.72 0.60 18\n",
" 2 1.00 0.35 0.52 17\n",
" 3 0.50 0.12 0.20 8\n",
" 4 0.75 0.35 0.47 26\n",
" 5 0.84 0.99 0.91 131\n",
"\n",
" accuracy 0.80 200\n",
" macro avg 0.72 0.51 0.54 200\n",
"weighted avg 0.80 0.80 0.76 200\n",
"\n"
]
}
],
"source": [
"import ast\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import classification_report\n",
"\n",
"# For convenience, the embeddings are precomputed and stored in this CSV.\n",
"datafile_path = \"https://cdn.openai.com/API/examples/data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"df = pd.read_csv(datafile_path)\n",
"# Each embedding is stored as a stringified list; literal_eval parses it without executing arbitrary code.\n",
"df[\"babbage_similarity\"] = df.babbage_similarity.apply(ast.literal_eval).apply(np.array)\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" list(df.babbage_similarity.values), df.Score, test_size=0.2, random_state=42\n",
")\n",
"\n",
"clf = RandomForestClassifier(n_estimators=100)\n",
"clf.fit(X_train, y_train)\n",
"preds = clf.predict(X_test)\n",
"probas = clf.predict_proba(X_test)\n",
"\n",
"report = classification_report(y_test, preds)\n",
"print(report)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
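"The model distinguishes between the categories reasonably well, reaching 80% accuracy on the 200 test reviews. 5-star reviews show the best performance overall (F1 of 0.91), which is not too surprising, since they are by far the most common class and make up 131 of the 200 test reviews. The rarer classes fare worse, with 3-star reviews at an F1 of only 0.20."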
]
},
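{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because 5-star reviews dominate the data, it helps to compare the classifier against a trivial baseline that always predicts the most frequent class. The next cell is a rough sketch of that comparison: it rebuilds the test labels from the support counts in the classification report above and scores scikit-learn's `DummyClassifier` on them. The names `support`, `y_true`, and `X_dummy` are introduced here for illustration, and the baseline is fitted on the same labels it is scored on, which is harmless for a constant predictor but is not how a real model should be evaluated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyClassifier\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Test-set support per star rating, copied from the classification report above.\n",
"support = {1: 18, 2: 17, 3: 8, 4: 26, 5: 131}\n",
"y_true = np.repeat(list(support.keys()), list(support.values()))\n",
"\n",
"# The baseline ignores its features, so a single column of zeros is enough.\n",
"X_dummy = np.zeros((len(y_true), 1))\n",
"baseline = DummyClassifier(strategy=\"most_frequent\").fit(X_dummy, y_true)\n",
"\n",
"# Always predicting 5 stars is right for 131 of the 200 reviews.\n",
"print(accuracy_score(y_true, baseline.predict(X_dummy)))  # 131 / 200 = 0.655\n"
]
},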
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RandomForestClassifier() - Average precision score over all classes: 0.91\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 648x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
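"# Note: openai.embeddings_utils is only available in openai<1.0; the module was removed from the package in v1.0.\n",
"from openai.embeddings_utils import plot_multiclass_precision_recall\n",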
"\n",
"plot_multiclass_precision_recall(probas, y_test, [1, 2, 3, 4, 5], clf)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unsurprisingly, 5-star and 1-star reviews seem to be easier to predict. Perhaps with more data, the nuances between 2-4 stars could be better predicted, but there is probably also more subjectivity in how people use the in-between scores."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.9 ('openai')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}