{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression using the embeddings\n",
"\n",
"Regression means predicting a number, rather than one of the categories. We will predict the score based on the embedding of the review's text. We split the dataset into a training and a testing set for all of the following tasks, so we can realistically evaluate performance on unseen data. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb).\n",
"\n",
"We're predicting the score of the review, which is a number between 1 and 5 (1-star being negative and 5-star positive)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ada similarity embedding performance on 1k Amazon reviews: mse=0.60, mae=0.51\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
"\n",
"# If you have not run the \"Obtain_dataset.ipynb\" notebook, you can download the datafile from here: https://cdn.openai.com/API/examples/data/fine_food_reviews_with_embeddings_1k.csv\n",
"datafile_path = \"./data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"# The embeddings are stored as strings in the CSV; convert them back into numpy arrays.\n",
"df[\"ada_similarity\"] = df.ada_similarity.apply(eval).apply(np.array)\n",
"\n",
"# Hold out 20% of the reviews so we can evaluate on unseen data.\n",
"X_train, X_test, y_train, y_test = train_test_split(list(df.ada_similarity.values), df.Score, test_size=0.2, random_state=42)\n",
"\n",
"rfr = RandomForestRegressor(n_estimators=100)\n",
"rfr.fit(X_train, y_train)\n",
"preds = rfr.predict(X_test)\n",
"\n",
"mse = mean_squared_error(y_test, preds)\n",
"mae = mean_absolute_error(y_test, preds)\n",
"\n",
"print(f\"Ada similarity embedding performance on 1k Amazon reviews: mse={mse:.2f}, mae={mae:.2f}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dummy mean prediction performance on Amazon reviews: mse=1.73, mae=1.03\n"
]
}
],
"source": [
"bmse = mean_squared_error(y_test, np.repeat(y_test.mean(), len(y_test)))\n",
"bmae = mean_absolute_error(y_test, np.repeat(y_test.mean(), len(y_test)))\n",
"print(\n",
" f\"Dummy mean prediction performance on Amazon reviews: mse={bmse:.2f}, mae={bmae:.2f}\"\n",
")\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the embeddings are able to predict the scores with an average absolute error of 0.51 per score prediction. This is roughly equivalent to predicting 1 out of 3 reviews perfectly and 1 out of 2 reviews within a one-star error."
]
},
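{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, illustrative sanity check of the claim above (not part of the original pipeline), we can round each predicted score to the nearest star and measure how often the rounded prediction matches the true score exactly, or lands within one star. It reuses `preds` and `y_test` from the regression cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round predictions to the nearest star and clip to the valid 1-5 range.\n",
"rounded_preds = np.clip(np.round(preds), 1, 5)\n",
"\n",
"exact_match = (rounded_preds == y_test.values).mean()\n",
"within_one_star = (np.abs(rounded_preds - y_test.values) <= 1).mean()\n",
"\n",
"print(f\"Exact-star match rate: {exact_match:.2f}\")\n",
"print(f\"Within one star: {within_one_star:.2f}\")\n"
]
},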
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You could also train a classifier to predict the label, or use the embeddings within an existing ML model to encode free text features."
]
}
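{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the classification variant, assuming the `X_train`/`X_test` split and the star-rating labels from above are still in scope; the hyperparameters are illustrative, not tuned."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Treat the 1-5 star score as a categorical label instead of a continuous target.\n",
"rfc = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"rfc.fit(X_train, y_train)\n",
"class_preds = rfc.predict(X_test)\n",
"\n",
"print(f\"Classification accuracy on 1k Amazon reviews: {accuracy_score(y_test, class_preds):.2f}\")\n"
]
}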
],
"metadata": {
"kernelspec": {
"display_name": "openai-cookbook",
"language": "python",
"name": "openai-cookbook"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}