{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## User and product embeddings\n",
"\n",
"We calculate user and product embeddings based on the training set, and evaluate the results on the unseen test set. We will evaluate the results by plotting the user and product similarity versus the review score. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Calculate user and product embeddings\n",
"\n",
"We calculate these embeddings simply by averaging all the reviews about the same product or written by the same user within the training set."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(24502, 19035)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"df = pd.read_csv('output/embedded_babbage_similarity_50k.csv', index_col=0) # note that you will need to generate this file to run the code below\n",
"df['babbage_similarity'] = df.babbage_similarity.apply(eval).apply(np.array)\n",
"X_train, X_test, y_train, y_test = train_test_split(df, df.Score, test_size = 0.2, random_state=42)\n",
"\n",
"user_embeddings = X_train.groupby('UserId').babbage_similarity.apply(np.mean)\n",
"prod_embeddings = X_train.groupby('ProductId').babbage_similarity.apply(np.mean)\n",
"len(user_embeddings), len(prod_embeddings)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that most of the users and products appear within the 50k examples only once."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Evaluate the embeddings\n",
"\n",
"To evaluate the recommendations, we look at the similarity of the user and product embeddings amongst the reviews in the unseen test set. We calculate the cosine distance between the user and product embeddings, which gives us a similarity score between 0 and 1. We then normalize the scores to be evenly split between 0 and 1, by calculating the percentile of the similarity score amongst all predicted scores."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from openai.embeddings_utils import cosine_similarity\n",
"\n",
"# evaluate embeddings as recommendations on X_test\n",
"def evaluate_single_match(row):\n",
" user_id = row.UserId\n",
" product_id = row.ProductId\n",
" try:\n",
" user_embedding = user_embeddings[user_id]\n",
" product_embedding = prod_embeddings[product_id]\n",
" similarity = cosine_similarity(user_embedding, product_embedding)\n",
" return similarity\n",
" except Exception as e:\n",
" return np.nan\n",
"\n",
"X_test['cosine_similarity'] = X_test.apply(evaluate_single_match, axis=1)\n",
"X_test['percentile_cosine_similarity'] = X_test.cosine_similarity.rank(pct=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2.1 Visualize cosine similarity by review score\n",
"\n",
"We group the cosine similarity scores by the review score, and plot the distribution of cosine similarity scores for each review score."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Correlation between user&vector similarity percentile metric and review number of stars (score): 22.11%\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEcCAYAAADA5t+tAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaO0lEQVR4nO3de5Qc5X3m8e+j0QUMWEhgTwAJDY5lBa2IAQlYDt5YWhEsbEciCUkgNjGLYMxiJfJhFySIFzA+2YPjS7LGJFnFloEQJBx8OQqIiw3TdnzBIK42wmR1QLJGsCtuQgwgpBG//aNKuJnumWnoalXPvM/nnDp0Vb1d9etXwzPvvF1drYjAzMzSMabsAszMbO9y8JuZJcbBb2aWGAe/mVliHPxmZolx8JuZJcbBb3uVpJD03rLrKJOkuZJ6h9iffB9Zazn4EyVpo6RXJfVJekHSrZKmll3XHpLOlvSjsusYySSNl/QlSb35v/NGSX9bdl1WPgd/2n4vIvYHDgH+H3B1yfW0jKSxZddQgkuAOcDxwAHAXOCBIk+QaL+OeA5+IyJ2ADcDM/dskzRR0vWSnpG0SdJnJI2RNDkfQf5e3m5/SRsk/Vm+fq2kf5D0PUkvSfqBpGn1zjvEOY4E/gE4MR+pbhvk+UdI+mF+nu9LukbSDfm+rnzKZLGkXwF358f+TH6urfm5J+bta6Zf8hHyyfnjKyTdLOmm/HwPSHp/VdtDJX0rfy1PSvqLqn375v3ygqT1wHEN/LN8WNITkp6V9IW89vGSnpd0VNWx3y3pFUnvqnOM44DvRMRTkdkYEddXPXeqpG/nNT8n6av59qH6qaZf8+3nSHosf413DPZvbu3BwW9IegfwJ8A9VZuvBiYC7wE+CPwZ8F8i4nngHOAfJb0b+BvgoepAAT4GfA44GHgI+OdBTj3YOR4Dzgd+GhH7R8SBgzz/RuBe4CDgCuCsOm0+CBwJfAg4O1/m5efcH/jqIMeuZxHwL8Dk/NzflTRO0hjgX4GHgcOA+cCnJX0of97lwG/my4eATzRwrt8nG60fm5/3nIjYCawGPl7V7kzgroh4ps4x7gEulHSBpKMkac8OSR3ALcAmoCuve3W++2yG76c3+lXSIuBS4A+AdwH/Bqxq4DVaWSLCS4ILsBHoA7YBu4CngKPyfR3ATmBmVftPApWq9auBnwNbgIOqtl8LrK5a3x/YDUzN1wN473DnIAueHw1R/+FAP/COqm03ADfkj7vyc72nav9dwAVV6zPy1z6WbBqkt04fnZw/vgK4p2rfGOBp4D8BJwC/GvDcS4Bv5I+fABZU7eseeK4Bz40B7S8gC3f2nAtQvr4O+ONBjtMBfAr4MfBa/m/8iXzficAzwNg6zxuqn+r1623A4gF98wowreyfcy/1F4/403ZaZKPpfYAlwA8k/QbZSH0c2Whwj01ko8I9VgCzgGsj4rkBx92850FE9AHPA4cOaNPIOYZyKPB8RLxS77yDbDu0zvnGAp0NnrP6db0O9ObHnAYcKmnbnoVsBLznuIcOqKO6hmHPlbc/ND/vz8hCda6k3yL7Jbqm3gEiYndEXBMRJwEHAn8FrMyn0qYCmyKiv85TG+mn6vqmAf+r6rU/D4jG/y1tL3Pw256A+DbZyPwDwLNkI7zqedrDyUb3e6YJVgDXAxeo9tLDN64OkrQ/2dTIUwPaDHkOslHlUJ4GJufTVDXnrX55VY+fqnO+frI3tl8G3jhW/hoHzptXv64xwJT8mJuBJyPiwKrlgIj4cFWt1bUdPsxrG/haDufN/Xcd2XTPWcDNkb1HM6SIeDUirgFeIHsvZzNw+CBvzg7VT28csurxZuCTA17/vhHxk+HqspKU/SeHl3IW3jyNIbJ55H7gP+TbbgC+Q3Y1yDTgl8C5+b7/AfyEbCrh0j2P833XAtvJfoGMJ3sP4MdV5w3gvQ2cY0Fe4/ghXsM9wF/n5zkReJHaqZ6xVe3PBf4PcATZFNTNVe
0nko2kP0L2l8jleX9UT/XsIpvHHgtcmNc3Lu+HB4BlwL75+izguPy5nwd+AEwi+2XxCMNP9dyVt5+a90t31f6pZKPqTcDvDHGcT5NNYe2b1/wJsimf9+Q1Pgx8EdiP7K++kxrop3r9+vvAL/j1z85E4I/K/hn3MsT//2UX4KWkf/gstF4lm+d/Kf8f92NV+yflwfwM2YjuMrK/EGeTjRr3hHcH2RzyX+br15JdkfO9/Ng/BI6oOm518Nc9R75vPHBrHnDPDvIafpPsjcSX8qBcAXw931cvoMbk59icn/MGYFLV/rPJRudbgf9O7Rz/zcBN+fkeBI6teu6hZG9o/t+8f+6peu47yP462gasBy5i+OD/C7L3Bp4DvkT+i7Wqzffz+jTEcbqB+8l+IW4jeyP8o1X7Dwe+m5/jWeArw/VTvX7Nt59F9p7P9vx5K8v+Gfcy+LLnDSKzQki6lizUPlPCuW8CfhkRl7fg2FeQ/cL6+HBt9wZJK4GnyuhnG/n84QsbsSQdR/YXwZPAKWTTVVeVWtReIKmLbMrpmJJLsRHKb+7aSPYbQIVsSukrwH+NiAdLrajFJH2ObFruCxHxZNn12MjkqR4zs8R4xG9mlhgHv5lZYhz8ZmaJcfCbmSXGwW9mlhgHv5lZYhz8ZmaJcfCbmSXGwW9mlpjS7tVz8MEHR1dXV1mnf5OXX36Z/fbbr+wy2or7pJb7pJb7pFY79cn999//bETUfB9zacHf1dXFunXryjr9m1QqFebOnVt2GW3FfVLLfVLLfVKrnfpEUt1ve/NUj5lZYhz8ZmaJcfCbmSXGwW9mlphhg1/SSklbJf1ikP2S9BVJGyQ9IunY4ss0M7OiNDLivxZYMMT+U4Hp+dIN/H3zZZmZWasMG/wR8UOy7zUdzCLg+sjcAxwo6ZCiCjQzs2IVMcd/GLC5ar0332ZmZm1or36AS1I32XQQnZ2dVCqVvXn6QfX19bVNLe0itT6ZN29eYcfq6ekp7Fhlcp/UGi19UkTwbwGmVq1PybfViIgVwAqAOXPmRLt8uq2dPmnXLlLrk4gYtk3X8lvZeNVH9kI17cF9Umu09EkRwb8GWCJpNXAC8GJEPF3Aca1FJBVynEb+JzCz9tPI5ZyrgJ8CMyT1Slos6XxJ5+dN1gJPABuAfwQuaFm1VoiIGHaZtuyWYduY2cg07Ig/Is4cZn8AnyqsIjMzayl/ctfMLDFJB/+qVauYNWsW8+fPZ9asWaxatarskszMWi7Z4F+1ahVLly7l5ZdfBrIvT1i6dKnD38xGvWSD/+KLL2bXrl1v2rZr1y4uvvjikioyM9s7kg3+3t7emitTIoLe3t6SKjIz2ztK++rFdtDR0cHKlSvZvXs3HR0dnH766WWXZGbWcsmO+KH2A0i+Nt3MUpD0iH/37t2cc845bNq0iWnTprF79+6ySzIza7lkR/xTpkyhv7+fLVu2EBFs2bKF/v5+pkyZUnZpZmYtlWzwn3baaezYsYPJkycjicmTJ7Njxw5OO+20skszM2upZIO/p6eHhQsXsm3bNiKCbdu2sXDhwlFz+1gzs8EkO8e/fv16XnnlFW677bY3rupZvHgxGzduLLs0M7OWSnbEP378eJYsWcK8efMYO3Ys8+bNY8mSJYwfP77s0szMWirZEf/OnTu54oorWL58Obt27WLcuHHss88+7Ny5s+zSzMxaKtkR/6RJk+jr6+Oggw5izJgxHHTQQfT19TFp0qSySzMza6lkR/zbt29n0qRJ3HjjjW/65O727dvLLs3MrKWSDf7+/n6OPvpo5s+fT0QgiXnz5nH33XeXXZqZWUslG/wdHR1UKhW++MUvMnPmTNavX89FF11ER0dH2aWZmbVUsnP8g92Xx/frMbPRLtkR/+uvv053dzeXXnopr732GhMmTODcc89lxYoVZZdmZtZSyY74J0yYwIwZM9ixYwc9PT3s2LGDGTNmMGHChLJLMzNrqWRH/Oeddx7Lli0DYObMmXz5y1
9m2bJlnH/++SVXZmbWWskG/9VXXw3wpqme888//43tZmajVbJTPZCFf/VUj0PfzFIw6kf8kgo5jq/2MbPRYtSP+CNi2GXasluGbWNmNlqM+uA3M7M3c/CbmSXGwW9mlhgHv5lZYhoKfkkLJD0uaYOk5XX2Hy6pR9KDkh6R9OHiSzUzsyIMG/ySOoBrgFOBmcCZkmYOaPYZ4JsRcQxwBvB3RRdqZmbFaGTEfzywISKeiIidwGpg0YA2AbwzfzwReKq4Es3MrEiNfIDrMGBz1XovcMKANlcAd0r6c2A/4ORCqjMzs8IV9cndM4FrI+JLkk4E/knSrIh4vbqRpG6gG6Czs5NKpVLQ6ZvXTrW0C/dJLfdJLfdJrXbvk0aCfwswtWp9Sr6t2mJgAUBE/FTSPsDBwNbqRhGxAlgBMGfOnJg7d+7bq7pot99K29TSLtwntdwntUZZn7z/s3fy4qu7mj7O2be/3NTzJ+47jocvP6XpOgbTSPDfB0yXdARZ4J8B/OmANr8C5gPXSjoS2Ad4pshCzcxa7cVXd7Hxqo80dYxKpdL0L8Ou5bc29fzhDPvmbkT0A0uAO4DHyK7eeVTSlZIW5s3+G3CepIeBVcDZ4RvcmJm1pYbm+CNiLbB2wLbLqh6vB04qtjQzM2sFf3LXzCwxDn4zs8Q4+M3MEuPgNzNLjIPfzCwxDn4zs8Q4+M3MEuPgNzNLTFE3aTOzEaao+9I0e3uBVt+Xxmo5+M0Slcp9aayWp3rMzBLj4DczS4yD38wsMQ5+M7PEOPjNzBLj4DczS4yD38wsMQ5+M7PEOPjNzBLjT+5aEnx7ArNfc/BbEnx7ArNf81SPmVliHPxmZolx8JuZJcbBb2aWGAe/mVliHPxmZolx8JuZJcbBb2aWGAe/mVliHPxmZolpKPglLZD0uKQNkpYP0uaPJa2X9KikG4st08zMijLsvXokdQDXAL8L9AL3SVoTEeur2kwHLgFOiogXJL27VQWbmVlzGhnxHw9siIgnImInsBpYNKDNecA1EfECQERsLbZMMzMrSiN35zwM2Fy13gucMKDN+wAk/RjoAK6IiNsHHkhSN9AN0NnZSaVSeRslt0Y71dIuRlufNPt6+vr6CumTdupX90mtJPokIoZcgNOBr1WtnwV8dUCbW4DvAOOAI8h+URw41HFnz54d7WLaslvKLqHtjLY+KeL19PT0tEUdRXGf1BptfQKsizr528hUzxZgatX6lHxbtV5gTUTsiogngX8Hpr/dX0ZmZtY6jUz13AdMl3QEWeCfAfzpgDbfBc4EviHpYLKpnycKrNMaVNQ3TYG/bcpstBo2+COiX9IS4A6y+fuVEfGopCvJ/oxYk+87RdJ6YDdwUUQ818rCrb4ivmkK/G1TZqNZQ1+9GBFrgbUDtl1W9TiAC/PFzMzamD+5a2aWGAe/mVliHPxmZolpaI7fzCwFBxy5nKOuq3s7srfmumbrAGj+Io3BOPjNzHIvPXZV01fFjYQr4jzVY2aWGAe/mVliHPxmZolx8JuZJcbBb2aWGAe/mVliHPxmZolx8JuZJcbBb2aWGAe/mVliHPxmZonxvXrMEpXKDcmsloPfLFGp3JDManmqx8wsMQ5+M7PEOPjNzBLj4DczS4yD38wsMQ5+M7PEjOjLOd//2Tt58dVdhRyr2UvKJu47jocvP6WQWszMWmlEB/+Lr+5q+jpk8LXIZpYWT/WYmSXGwW9mlhgHv5lZYhz8ZmaJaSj4JS2Q9LikDZIGvZ2fpD+UFJLmFFeimZkVadjgl9QBXAOcCswEzpQ0s067A4ClwM+KLtLMzIrTyIj/eGBDRDwRETuB1cCiOu0+B3we2FFgfWZmVrBGruM/DNhctd4LnFDdQNKxwNSIuFXSRYMdSFI30A3Q2dlJpVJ5ywUPVMQx+vr62qaWIrhP6mu2FvdJLfdJrRHRJxEx5AKcDnytav0s4KtV62OACtCVr1eAOcMdd/bs2dGsactuafoYERE9PT1NH6
OoWprlPqmviFrcJ7XcJ7XaqU+AdVEnfxuZ6tkCTK1an5Jv2+MAYBZQkbQR+I/AGr/Ba2bWnhoJ/vuA6ZKOkDQeOANYs2dnRLwYEQdHRFdEdAH3AAsjYl1LKjYzs6YMG/wR0Q8sAe4AHgO+GRGPSrpS0sJWF2hmZsVq6CZtEbEWWDtg22WDtJ3bfFlmxTrgyOUcdd2gH0Fp3HXN1gHQ/I0FzZoxou/Oadaolx67quk7ufourjZa+JYNZmaJcfCbmSXGwW9mlhjP8Y8yhb2JCX4j02yUcvCPMkW8iQl+I9NsNPNUj5lZYhz8ZmaJcfCbmSXGwW9mlhgHv5lZYhz8ZmaJcfCbmSVmRF/H7w8rmZm9dSM6+P1hJTOzt25EB7+ZWdEKGcTd3twxJu47rvkahuDgNzPLFTGD0LX81kKO00p+c9fMLDEOfjOzxDj4zcwS4+A3M0uMg9/MLDEOfjOzxDj4zcwS4+A3M0uMP8BllrAUPqVqtRz8ZolK5VOqVstTPWZmiXHwm5klxsFvZpaYhoJf0gJJj0vaIKnmm08kXShpvaRHJN0laVrxpZqZWRGGDX5JHcA1wKnATOBMSTMHNHsQmBMRvw3cDPx10YWamVkxGhnxHw9siIgnImInsBpYVN0gInoi4pV89R5gSrFlmplZURoJ/sOAzVXrvfm2wSwGbmumKDMza51Cr+OX9HFgDvDBQfZ3A90AnZ2dVCqVps9ZxDH6+vrappYiFPb9v01+MGe/ce3TJ9B8LaPt56Qoo+31FKHt+yQihlyAE4E7qtYvAS6p0+5k4DHg3cMdMyKYPXt2NGvasluaPkZERE9PT9PHKKqWduHXU8s/J7VG2+spQjv1CbAu6uRvI1M99wHTJR0haTxwBrCmuoGkY4D/DSyMiK0F/U4yM7MWGDb4I6IfWALcQTai/2ZEPCrpSkkL82ZfAPYH/kXSQ5LWDHI4MzMrWUNz/BGxFlg7YNtlVY9PLrguMzNrkRF/k7Z2eSPTdxg0s5FiRAd/UXcF9B0GzSwlvlePmVliHPxmZokZ0VM9Zm+Fv23KLOPgtyT426bMfs1TPWZmiXHwm5klxsFvZpYYB7+ZWWIc/GZmiXHwm5klxsFvZpYYB7+ZWWIc/GZmiXHwm5klxsFvZpYYB7+ZWWIc/GZmiXHwm5klxsFvZpYYB7+ZWWIc/GZmiXHwm5klxsFvZpYYB7+ZWWIc/GZmiXHwm5klxsFvZpYYB7+ZWWIaCn5JCyQ9LmmDpOV19k+QdFO+/2eSugqv1MzMCjFs8EvqAK4BTgVmAmdKmjmg2WLghYh4L/A3wOeLLtTMzIrRyIj/eGBDRDwRETuB1cCiAW0WAdflj28G5ktScWWamVlRGgn+w4DNVeu9+ba6bSKiH3gROKiIAs3MrFhj9+bJJHUD3QCdnZ1UKpWWn3PevHkNtdMwk1M9PT0FVNMe3Ce1iuoTGD394j6pNVr6pJHg3wJMrVqfkm+r16ZX0lhgIvDcwANFxApgBcCcOXNi7ty5b6PktyYihm1TqVTYG7W0C/dJLfdJLfdJrdHSJ41M9dwHTJd0hKTxwBnAmgFt1gCfyB+fDtwdjfSQmZntdcOO+COiX9IS4A6gA1gZEY9KuhJYFxFrgK8D/yRpA/A82S8HMzNrQw3N8UfEWmDtgG2XVT3eAfxRsaWZmVkr+JO7ZmaJcfCbmSXGwW9mlhgHv5lZYhz8ZmaJUVmX20t6BthUyslrHQw8W3YRbcZ9Ust9Ust9Uqud+mRaRLxr4MbSgr+dSFoXEXPKrqOduE9quU9quU9qjYQ+8VSPmVliHPxmZolx8GdWlF1AG3Kf1HKf1HKf1Gr7PvEcv5lZYjziNzNLTNLBL2mlpK2SflF2Le1A0lRJPZLWS3pU0tKyayqbpH0k3Svp4bxPPlt2Te1CUoekByXdUnYt7ULSRkk/l/SQpHVl1zOYpKd6JP0O0AdcHxGzyq6nbJIOAQ6JiAckHQDcD5wWEetLLq00+XdH7x
cRfZLGAT8ClkbEPSWXVjpJFwJzgHdGxEfLrqcdSNoIzImIdrmOv66kR/wR8UOy7w8wICKejogH8scvAY9R+/3KSYlMX746Ll/SHS3lJE0BPgJ8rexa7K1LOvhtcJK6gGOAn5VcSunyKY2HgK3A9yIi+T4B/ha4GHi95DraTQB3Sro//47xtuTgtxqS9ge+BXw6IraXXU/ZImJ3RBxN9n3Tx0tKelpQ0keBrRFxf9m1tKEPRMSxwKnAp/Lp5Lbj4Lc3yeexvwX8c0R8u+x62klEbAN6gAUll1K2k4CF+Xz2auA/S7qh3JLaQ0Rsyf+7FfgOcHy5FdXn4Lc35G9kfh14LCK+XHY97UDSuyQdmD/eF/hd4JelFlWyiLgkIqZERBfZ92vfHREfL7ms0knaL78oAkn7AacAbXnFYNLBL2kV8FNghqReSYvLrqlkJwFnkY3gHsqXD5ddVMkOAXokPQLcRzbH78sXrZ5O4EeSHgbuBW6NiNtLrqmupC/nNDNLUdIjfjOzFDn4zcwS4+A3M0uMg9/MLDEOfjOzxDj4LVmS/jK/4+Yj+aWrJ5Rdk9neMLbsAszKIOlE4KPAsRHxmqSDgfFNHG9sRPQXVqBZC3nEb6k6BHg2Il4DiIhnI+IpScdJ+kl+//17JR2Q35P/G/l91h+UNA9A0tmS1ki6G7gr/+Tmyvx5D0paVOYLNBuMR/yWqjuByyT9O/B94CayT3HfBPxJRNwn6Z3Aq8BSsjs0HyXpt8juvvi+/DjHAr8dEc9L+p9kty84J7/Nw72Svh8RL+/l12Y2JI/4LUn5PfZnA93AM2SB/0ng6Yi4L2+zPZ+++QBwQ77tl8AmYE/wfy8i9nynwynA8vwWzhVgH+DwvfF6zN4Kj/gtWRGxmyygK5J+DnzqbRymejQv4A8j4vECyjNrGY/4LUmSZkiaXrXpaLJvHDtE0nF5mwMkjQX+DfhYvu19ZKP4euF+B/Dn+V1OkXRM616B2dvnEb+lan/g6nwuvh/YQDbt8418+75k8/snA38H/H3+V0E/cHZ+JdDAY36O7JupHpE0BniS7Mohs7biu3OamSXGUz1mZolx8JuZJcbBb2aWGAe/mVliHPxmZolx8JuZJcbBb2aWGAe/mVli/j93vt+d34eekwAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import statsmodels.api as sm\n",
"\n",
"\n",
"correlation = X_test[['percentile_cosine_similarity', 'Score']].corr().values[0,1]\n",
"print('Correlation between user & vector similarity percentile metric and review number of stars (score): %.2f%%' % (100*correlation))\n",
"\n",
"# boxplot of cosine similarity for each score\n",
"X_test.boxplot(column='percentile_cosine_similarity', by='Score')\n",
"plt.title('')\n",
"plt.show()\n",
"plt.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can observe a weak trend, showing that the higher the similarity score between the user and the product embedding, the higher the review score. Therefore, the user and product embeddings can weakly predict the review score - even before the user receives the product!\n",
"\n",
"Because this signal works in a different way than the more commonly used collaborative filtering, it can act as an additional feature to slightly improve the performance on existing problems."
]
}
],
"metadata": {
"interpreter": {
"hash": "be4b5d5b73a21c599de40d6deb1129796d12dc1cc33a738f7bac13269cfcafe8"
},
"kernelspec": {
"display_name": "Python 3.7.3 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}