{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression using the embeddings\n",
"\n",
"Regression means predicting a number rather than one of a fixed set of categories. Here we predict a review's star score from the embedding of its text. We split the dataset into a training set and a test set so that we can evaluate performance realistically on unseen data. The dataset is created in the [Get_embeddings_from_dataset Notebook](Get_embeddings_from_dataset.ipynb).\n",
"\n",
"We're predicting the score of the review, which is a number between 1 and 5 (1-star being negative and 5-star positive)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"text-embedding-3-small performance on 1k Amazon reviews: mse=0.65, mae=0.52\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from ast import literal_eval\n",
"\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
"\n",
"datafile_path = \"data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"# Embeddings are stored as strings in the CSV; parse them back into numpy arrays\n",
"df[\"embedding\"] = df.embedding.apply(literal_eval).apply(np.array)\n",
"\n",
"# Hold out 20% of the reviews as a test set\n",
"X_train, X_test, y_train, y_test = train_test_split(list(df.embedding.values), df.Score, test_size=0.2, random_state=42)\n",
"\n",
"# Fit a random forest on the embedding vectors to predict the 1-5 star score\n",
"rfr = RandomForestRegressor(n_estimators=100)\n",
"rfr.fit(X_train, y_train)\n",
"preds = rfr.predict(X_test)\n",
"\n",
"mse = mean_squared_error(y_test, preds)\n",
"mae = mean_absolute_error(y_test, preds)\n",
"\n",
"print(f\"text-embedding-3-small performance on 1k Amazon reviews: mse={mse:.2f}, mae={mae:.2f}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dummy mean prediction performance on Amazon reviews: mse=1.73, mae=1.03\n"
]
}
],
"source": [
"# Baseline: predict the mean test-set score for every review\n",
"bmse = mean_squared_error(y_test, np.repeat(y_test.mean(), len(y_test)))\n",
"bmae = mean_absolute_error(y_test, np.repeat(y_test.mean(), len(y_test)))\n",
"print(\n",
" f\"Dummy mean prediction performance on Amazon reviews: mse={bmse:.2f}, mae={bmae:.2f}\"\n",
")\n"
]
},
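{
"cell_type": "markdown",
"metadata": {},
"source": [
"The baseline above predicts the mean of `y_test`, so it uses the very labels it is scored on. A more conventional variant, sketched here under the assumption that scikit-learn's `DummyRegressor` is acceptable, fits the mean on the training scores only. The helper name `mean_baseline` is hypothetical and the toy scores are illustrative, not taken from the dataset; in this notebook you would call `mean_baseline(y_train, y_test)` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyRegressor\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
"\n",
"\n",
"def mean_baseline(y_train, y_test):\n",
"    \"\"\"Hypothetical helper: score a predict-the-training-mean baseline on the test set.\"\"\"\n",
"    # DummyRegressor ignores the features, so a single placeholder column is enough\n",
"    dummy = DummyRegressor(strategy=\"mean\")\n",
"    dummy.fit(np.zeros((len(y_train), 1)), y_train)\n",
"    preds = dummy.predict(np.zeros((len(y_test), 1)))\n",
"    return mean_squared_error(y_test, preds), mean_absolute_error(y_test, preds)\n",
"\n",
"\n",
"# Toy scores for illustration only: the training mean here is 3.75\n",
"toy_mse, toy_mae = mean_baseline(np.array([5, 5, 4, 1]), np.array([5, 3]))\n",
"print(f\"Toy baseline: mse={toy_mse:.2f}, mae={toy_mae:.2f}\")\n"
]
},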
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
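"We can see that the embeddings predict the scores with a mean absolute error (MAE) of 0.52 stars, about half the 1.03 of the dummy baseline that always predicts the mean score. An MAE of roughly 0.5 is what you would get if half of the reviews were predicted perfectly and the other half were off by exactly one star."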
]
},
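{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the half-perfect, half-off-by-one intuition concrete, the toy example below scores ten hypothetical predictions, five exact and five off by one star. Both MAE and MSE come out at 0.5, because a one-star error contributes 1 to either metric. It also hints at why the MSE above (0.65) exceeds the MAE (0.52): when every error is at most one star, each squared error is no larger than the absolute error, so MSE cannot exceed MAE. A higher MSE therefore implies that some predictions miss by more than one star."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
"\n",
"# Hypothetical scores: the first five predictions are exact, the last five are off by one star\n",
"y_true_toy = np.array([5, 4, 3, 2, 1, 5, 4, 3, 2, 1])\n",
"y_pred_toy = np.array([5, 4, 3, 2, 1, 4, 5, 2, 3, 2])\n",
"\n",
"half_mae = mean_absolute_error(y_true_toy, y_pred_toy)\n",
"half_mse = mean_squared_error(y_true_toy, y_pred_toy)\n",
"print(f\"Toy example: mae={half_mae:.2f}, mse={half_mse:.2f}\")\n"
]
},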
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"You could also train a classifier to predict the star rating as a categorical label, or use the embeddings as free-text features within an existing ML model."
]
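},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the classifier route, assuming scikit-learn's `RandomForestClassifier` as a stand-in model (this notebook does not prescribe one). The helper `train_star_classifier` is hypothetical, and the random 1536-dimensional vectors (the default size of `text-embedding-3-small` embeddings) and random labels are placeholders, so the printed report is only a shape check, not a meaningful result. With this notebook's data you would call `train_star_classifier(list(df.embedding.values), df.Score)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"def train_star_classifier(embeddings, scores, seed=42):\n",
"    \"\"\"Hypothetical helper: treat each star rating as a class and fit a random forest.\"\"\"\n",
"    X_tr, X_te, y_tr, y_te = train_test_split(embeddings, scores, test_size=0.2, random_state=seed)\n",
"    clf = RandomForestClassifier(n_estimators=100, random_state=seed)\n",
"    clf.fit(X_tr, y_tr)\n",
"    return clf, y_te, clf.predict(X_te)\n",
"\n",
"\n",
"# Placeholder data: random vectors and labels stand in for real embeddings and scores\n",
"rng = np.random.default_rng(0)\n",
"fake_embeddings = rng.normal(size=(200, 1536))\n",
"fake_scores = rng.integers(1, 6, size=200)\n",
"\n",
"clf, y_te, class_preds = train_star_classifier(fake_embeddings, fake_scores)\n",
"print(classification_report(y_te, class_preds, zero_division=0))\n"
]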
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}