mirror of
https://github.com/openai/openai-cookbook
synced 2024-11-09 19:10:56 +00:00
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-shot classification with embeddings\n",
    "\n",
    "In this notebook we will classify the sentiment of reviews using embeddings and zero labeled data! The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb).\n",
    "\n",
    "We'll define positive sentiment to be 4- and 5-star reviews, and negative sentiment to be 1- and 2-star reviews. 3-star reviews are considered neutral and we won't use them for this example.\n",
    "\n",
    "We will perform zero-shot classification by embedding descriptions of each class and then comparing new samples to those class embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from ast import literal_eval\n",
    "\n",
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# parameters\n",
    "EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
    "\n",
    "# load data\n",
    "datafile_path = \"data/fine_food_reviews_with_embeddings_1k.csv\"\n",
    "\n",
    "df = pd.read_csv(datafile_path)\n",
    "# embeddings are stored in the CSV as strings such as \"[0.01, -0.02, ...]\";\n",
    "# literal_eval parses them without executing arbitrary code, unlike eval\n",
    "df[\"embedding\"] = df.embedding.apply(literal_eval).apply(np.array)\n",
    "\n",
    "# convert 5-star rating to binary sentiment, dropping neutral 3-star reviews\n",
    "# (.copy() avoids pandas' SettingWithCopyWarning when the column is added below)\n",
    "df = df[df.Score != 3].copy()\n",
    "df[\"sentiment\"] = df.Score.replace({1: \"negative\", 2: \"negative\", 4: \"positive\", 5: \"positive\"})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zero-Shot Classification\n",
    "To perform zero-shot classification, we want to predict labels for our samples without any training. To do this, we can simply embed short descriptions of each label, such as positive and negative, and then compare the cosine similarity between the embedding of each sample and the embeddings of the label descriptions.\n",
    "\n",
    "The label whose description is most similar to the sample is the predicted label. We can also define a prediction score as the cosine similarity to the positive label minus the cosine similarity to the negative label. This score can be used to plot a precision-recall curve, and choosing a different threshold on the score selects a different trade-off between precision and recall."
   ]
  },
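  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mechanics described above can be illustrated offline, without calling the API. The cell below is a minimal sketch that uses made-up 3-dimensional vectors in place of real embeddings (an assumption for illustration only; `text-embedding-ada-002` returns 1536-dimensional vectors). It scales every vector to unit length so that a single matrix product yields all cosine similarities at once, predicts the most similar label, and computes the positive-minus-negative score. The helper `normalize` and all `toy_` names are hypothetical and not part of `openai.embeddings_utils`; the `toy_` prefix keeps them from overwriting `df` or other notebook variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sketch of zero-shot classification; all vectors are hypothetical stand-ins for embeddings\n",
    "import numpy as np\n",
    "\n",
    "def normalize(vectors):\n",
    "    \"\"\"Scale each row to unit length so that dot products equal cosine similarities.\"\"\"\n",
    "    vectors = np.asarray(vectors, dtype=float)\n",
    "    return vectors / np.linalg.norm(vectors, axis=-1, keepdims=True)\n",
    "\n",
    "toy_labels = [\"negative\", \"positive\"]\n",
    "toy_label_embeddings = normalize([[1.0, 0.1, 0.0],   # stand-in for the \"negative\" label embedding\n",
    "                                  [0.1, 1.0, 0.0]])  # stand-in for the \"positive\" label embedding\n",
    "toy_review_embeddings = normalize([[0.9, 0.2, 0.1],  # stand-in for a review close to \"negative\"\n",
    "                                   [0.2, 0.8, 0.3]]) # stand-in for a review close to \"positive\"\n",
    "\n",
    "# cosine similarity of every review to every label, shape (n_reviews, n_labels)\n",
    "similarities = toy_review_embeddings @ toy_label_embeddings.T\n",
    "\n",
    "# predicted label = label description with the highest similarity (ties go to \"negative\")\n",
    "toy_preds = [toy_labels[i] for i in similarities.argmax(axis=1)]\n",
    "\n",
    "# prediction score = similarity to the positive label minus similarity to the negative label\n",
    "toy_scores = similarities[:, 1] - similarities[:, 0]\n",
    "\n",
    "print(toy_preds)\n",
    "print(toy_scores.round(3))"
   ]
  },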
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative       0.61      0.88      0.72       136\n",
      "    positive       0.98      0.90      0.94       789\n",
      "\n",
      "    accuracy                           0.90       925\n",
      "   macro avg       0.79      0.89      0.83       925\n",
      "weighted avg       0.92      0.90      0.91       925\n",
      "\n"
     ]
    },
    {
     "data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA5S0lEQVR4nO3dd3xW5f3/8dc7A8IGAREZgogioKBGXFWQKg7EXUddWBy0dfy0ddW60FZarf1qtSoWKo6CSq1SRREVXLWVIRsExAgBlT0DIePz++OcxJt4JzkhuXNnfJ6PRx65z3XW57oD9+c+13XOdcnMcM4550pKSXYAzjnnaiZPEM455+LyBOGccy4uTxDOOefi8gThnHMuLk8Qzjnn4vIE4aqUpKGSPk52HFVJ0iWS3omw3VOS7qqOmKqDpCxJJ4Wv75X0QrJjctXLE4RDUkNJoyV9LWmrpNmSTkt2XFGEH2I7JG2T9J2kZyU1rcpzmNmLZjYownbDzez+qjx3EUkmaXtYz1WSHpGUmohzOVfEE4QDSANWAv2BFsBvgZcldUlmUBUwxMyaAocDmQTx70ZSWrVHVfX6hPXsD1wI/CzJ8VSpOvI3qlM8QTjMbLuZ3WtmWWZWaGZvAF8BR5S2j6ROkl6VtFbSekmPl7Ldo5JWStoiaaak42PW9ZM0I1z3naRHwvIMSS+Ex90kabqkdhHqsQp4C+gdHsck/VLSUmBpWHZGeIW0SdJ/JB1aXp1im80U+LOkNWHc8yQVne9ZSQ/EHO9qScskbZA0UdK+MetM0nBJS8NYnpCk8uoY1nMZ8AnQN+Z4e1KvbpLeD8vWSXpRUssoMZQk6azw/FskfSnp1LC8uJkqXC5uqpLUJXwfhklaAbwv6S1J15U49hxJ54ave0iaEr6nX0i6YE/iddF4gnA/EH4YHwgsKGV9KvAG8DXQBegAjC/lcNMJPsj2Av4BvCIpI1z3KPComTUHugEvh+VXEFzJdAJaA8OBHRHi7gScDnweU3w2cBTQU9JhwBjg2vC4TwMTwya2qHUaBJxA8P60AC4A1seJZSDwYLi+fXjcksc7AzgSODTc7pTy6hgeuwdwPLAsXN7TeimMcV/gYIL3+94oMZSIpx/wHHAL0JLg/cmqwCH6h+c/BRgHXBxz7J7AfsCbkpoAUwj+He0NXAT8NdzGJYAnCLcbSenAi8BYM1tcymb9CD5UbgmvPnaaWdyOaTN7wczWm1m+mf0JaAgcFK7OAw6Q1MbMtpnZf2PKWwMHmFmBmc00sy1lhP2apE3Ax8AHwO9j1j1oZhvMbAdwDfC0mf0vPO5YIBc4ugJ1ygOaAT0AmdkiM/smznaXAGPMbJaZ5QJ3AMeUaLYbaWabzGwFMJWYK4JSzJK0HVgETAP+GpbvUb3MbJmZTTGzXDNbCzxC8GFdUcPCuk4Jr0BXlfFvJ557w9h2AP8C+kraL1x3CfBq+B6eAWSZ2d/Df0+fA/8EfrIHMbsIPEG4YpJSgOeBXcB1MeVvKegc3SbpEoJvml+bWX6EY/5a0iJJm8MP8RZAm3D1MIJv4ovDZqQzwvLngcnAeEmrJf0xTFylOdvMWprZfmb2i/CDpsjKmNf7Ab8Km2E2hfF0IvgAjVQnM3sfeBx4AlgjaZSk5nE23ZfgW3vRftsIrjQ6xGzzbczrHKApgKQFMe/38THbHB5ucyHBVVGTytRLUjtJ4xV0em8BXuD7v01FdAK+3IP9ihT/jcxsK/AmwdUBBFcTL4av9wOOKlHPS4B9KnFuVwZPEA4I2taB0UA74DwzyytaZ2anmVnT8OdFgv/QnVVOp2L44XYrQfNJKzNrCWwmaNrAzJaa2cUEzQV/ACZIamJmeWZ2n5n1BI4l+OZ4+R5WLXa44pXA78JkUvTT2MzGRa1TGPdjZnYE0JMgwd0SZ7PVBB9oAITNI62BVRGO3yvm/f6oxDozs5eBT4G7K1mv3xO8P4eEzXyXEv5tKmglQRNhPN
uBxjHL8T7MSw4pPQ64WNIxQAbB1VXReT4oUc+mZvbzPYjZReAJwhV5kqAdeEiJb+DxfAZ8A4yU1ERBp/JxcbZrBuQDa4E0SXcDxd+2JV0qqa2ZFQKbwuJCSSdKOiRsP99C0KxTWJnKhZ4Bhks6SoEmkgZLaha1TpKODPdPJ/jw21lKbOOAKyX1ldSQ4MP4f2aWVQX1ABgJXC1pn0rUqxmwDdgsqQPxE10Uownq+mNJKZI6hP0kALOBiySlS8oEzo9wvEkEyXUE8FL47wOCvpQDJV0WHi89/HscvIdxu3J4gnCE7b3XErSBf1uiOekHzKwAGAIcAKwAsgmaPUqaDLwNLCFobtnJ7k0+pwILJG0j6LC+KExO+wATCJLDIoJ+hecrWU3MbAZwNUET0UaCTt6hFaxTc4IP5I1hndYDD8U517vAXQRt5N8QfMO+qOR2lajLPOBDgr6FPa3XfQTNVpsJmnVe3cNYPgOuBP4cHusDvr96uoug7hvD8/0jwvFyw1hOit0+bH4aRPA+riZoovsDQb+WSwD5hEHOOefi8SsI55xzcXmCcM45F5cnCOecc3F5gnDOORdXnRkcq02bNtalS5dkh+Gcc7XKzJkz15lZ23jr6kyC6NKlCzNmzEh2GM45V6tI+rq0dd7E5JxzLi5PEM455+LyBOGccy4uTxDOOefi8gThnHMuroQlCEljFEzLOL+U9ZL0mIIpGedKOjxm3RUKpmJcKumKRMXonHOudIm8gniWYLTO0pwGdA9/riEYbhpJewH3EEyI0g+4R1KrBMbpnHMujoQ9B2FmH5aYXrGks4DnLBhO9r+SWkpqDwwAppjZBgBJUwgSzbhExJmzK5+nplVmMixXV3Ru3YTzj+iY7DCcqzGS+aBcB3afGyA7LCut/AckXUNw9UHnzp33KIgduwr4y9Rle7SvqzuKRr0f0qc9DdNSkxuMczVErX6S2sxGAaMAMjMz92hii9ZNG/LVg4OrNC5X+/x12jL++PYX+PQozn0vmQliFcFk50U6hmWrCJqZYsunVVtUzlWCmVFQaOQXGrl5heQVFpJfYOQVFJJXUEjOrgIA8goKyS8Myjfn5NEgLYW8AiO/sJCCQmNXfiFrt+XSrGEa+YVGfkFwzILCQlZt2sleTdKD5eLyYN8v12ynbbOGFBQaBWbF8RRYENvX63No17wh+YVGYbhNfoFRGG5XaNCtbVP+dkVmkt9JVxMkM0FMBK6TNJ6gQ3qzmX0jaTLw+5iO6UHAHckK0tUP32zaCcB1//ic1k0akF9o7CooZP22XHJ2FZCRnkJ+QVCWtW47bZo2JK+wkLx849stO0lLERLkF1q1XYU0TEshLUWkpoi01BRSBAWFxqJvt9C+RQYpCtalSKSkiFTBXk0asH77Ljq0bESKRFpK0TqRmiqWfreVdxd9Vz0VcDVewhKEpHEEVwJtJGUT3JmUDmBmTxFMTH46wfy5OQRz2mJmGyTdD0wPDzWiqMPauUSRgt8fL1vLzrxCOrRsRIO0FNJTxbad+bRp1pBmGWk0aZhGy84NyM0roEPLRqSnppCWKjbm7KLTXo1pkJpSXLZlRz77tswgLSVYbpCaQs6uAto2a0haikhLFWkpKUiQnppCRnq4b0pQnpICDdNSSU8NPujTU1NIDT/MU1KUkPfhkSlLWPLd0oQc29U+ibyL6eJy1hvwy1LWjQHGJCIu5+IZcVZvRpzVO9lhOFej+JPUzjnn4vIE4ZxzLi5PEM455+LyBOGccy4uTxDOOefi8gThnHMuLk8Qzjnn4vIE4ZxzLi5PEM455+Kq1aO5Oudqh5xd+UzP2oiFA1Vt2L6LLTvyyCswvlq/nVSJL77bSqP0VLbl5nNyz3YM798tyVE7TxDOuQrbsauAtVtz2Zizi6VrtpFXUMiKDTl8t3knkli2dhvrtuaSniqy1udEOmaHlo3YkVcQjExbUOgJogaIlCDCkVX3BXYAWWZWmNConHNJkVcQ/Nd+ffYqlq
3ZxvJ12yksNFZv2sG23Hy+Wredwgij1XZo2YimDdNomJ7CuYd3oGFaKp32asTR+7cGoLDQaNk4nZaNG9CyUTppqd+3dne9403mZG/mijGf0bxROn++oM9u6131KTVBSGpBMJjexUADYC2QAbST9F/gr2Y2tVqidM5Vi8kLvgXgxvGzdys/tGMLUlPEoJ77kJISJIBOezWmUXpq8evmGek0b5SGVLmRZls1bsCG7bv4YMlaAH47+GDaNc+o1DHdninrCmIC8BxwvJltil0h6QjgMkn7m9noBMbnnKtGE4Yfy7/nrKZvp5a0b5lBmyYNEza0eGlm3XUyAP/43wp+86951Xput7tSE4SZnVzGupnAzIRE5JxLmr2aNOCKY7skOwxXQ5TbB6HgevESYH8zGyGpM7CPmX2W8Oicc/XWttw8AG6ZMJdWjdNZuHoLDdNTaJyexnPD+pGRnprkCOu+KJ3UfwUKgYHACGAr8E/gyATG5Zyr59Zt2wXAh2FfRI99mpG1Lodtufms376LfVtksG7bLtZs3cmqjTvYviufxd9u5au129m+K5+CQuOwzq247dQeFTpvfkEhG3PyWLN1J9kbd7B2ay5bd+azbM02tu7MI2dXAcvWbKNhegqpEuOvPZq9m9XNPpIoCeIoMztc0ucAZrZRUoMEx+Wcq+fuOK0HN/64O40bpBZ3fP9+0iJGfbic40a+X+7+bZs1ZMl327jt1B6YGWu25rJ60w627MznyzXbyFq/nfXbd5GTm8+KDTms3ZqLJDbvyCvzuM0z0jhon2as3rSTrzflsHLDjnqdIPIkpQIGIKktwRVFuSSdCjwKpAJ/M7ORJdbvRzC1aFtgA3CpmWWH6/4IDCZ42nsKcKNZdU0H75xLNkk0abj7R9Tx3dsw6sPlHNa5Jfu2bMS+LTLo1rYp7Vpk0LZpQzq0bETLxulI4ujfv8fa7bkc8JtJ5Jdxb+5B7ZqRmiK6t2tGk4ZpHLh3U5pmpNG6aUO6tG5M6yYN2bt5Q/Zq3GC3DvsPlqzlijF1u6U9SoJ4DPgXsLek3wHnA78tb6cwqTwBnAxkA9MlTTSzhTGbPQw8Z2ZjJQ0EHiS4O+pY4Djg0HC7j4H+wLRItXLO1UnHd29L1sjBkbYd0qc9z3z0FUfs14rOezVmnxYZdNqrMa2bNKBjq8a0b5lBs4aVvy23Lis3QZjZi5JmAj8GBJxtZosiHLsfsMzMlgNIGg+cBcQmiJ7AzeHrqcBrRacleOaiQXjOdOC7COd0zjkA7hzckzsH90x2GLVauY8nSnoM2MvMnjCzxyMmB4AOwMqY5eywLNYc4Nzw9TlAM0mtzexTgoTxTfgzOd55JV0jaYakGWvXro0YlnPOuSiiPL8+E/itpC8lPSwpswrP/2ugf9gB3h9YBRRIOgA4GOhIkFQGSjq+5M5mNsrMMs0ss23btlUYlnPOVUxhobF1Zx4rN+Twny/XMXXxGrbsLLvDu6aL0sQ0FhgraS/gPOAPkjqbWfdydl0FdIpZ7hiWxR57NeEVhKSmwHlmtknS1cB/zWxbuO4t4Bjgo2jVcs65xNq2Mx+A8578D3s1CYYHKen6gQfwq0EHAWBmFBqkVvOT6ZVRkdFcDwB6APsBUZqZpgPdJXUlSAwXAT+N3UBSG2BDOPjfHQR3NAGsAK6W9CBBH0R/4P8qEKtzziVUm6bB3f6pKaJt04Yc3L4ZHVo2Yr/WTejYqhE3jp/NX95fxkvTV7IpJ49dBYU0aZDKtFtOpG2zhkmOPpooT1L/kaB/4EvgJeD+kmMzxWNm+ZKuAyYT3OY6xswWSBoBzDCzicAA4EFJBnxIMDggBONADQTmEXRYv21m/65g3ZxzLmGO2r91mXdUFQ142K1tU3LyCti5q4AvvtvKum25dSdBECSGY8xsXUUPbmaTgEklyu6OeT2BIBmU3K8AuLai53POuZqiZPJ4e/43DH9hVpKi2TNlDffdw8wWEzQVdQ7HYCpmZrWrps455y
qkrCuIm4FrgD/FWWcETUDOOefqqLKG+74mfHmame2MXSepbg484pxzCbJ83XYATnv0IybdcDw9922e5IjKF+U5iP9ELHPOOVeKU3rtU/z6X59nJzGS6EpNEJL2CWeOayTpMEmHhz8DgMbVFaBzztUF3do25a0bg+d9n/noq7jPTdQ0ZfVBnAIMJXjA7ZGY8q3AbxIYk3PO1UkHt2/OEfu1YubXG8nNL0h2OOUqqw+i6Anq88zsn9UYk3PO1Vk/OaIjM7/eSBkjkNcYZd3meqmZvQB0kXRzyfVm9kic3ZxzzpXh8xWbADhu5Pu8e3N/Dti7aXIDKkNZndRNwt9NgWZxfpxzzlXQ0OO6FL+e9fXG5AUSgerKJG2ZmZk2Y8aMZIfhnHPl+mjpWi4b/Rkpgsn/7wRy8wtZuzWXLm2a0LVNk/IPUIUkzTSzuKN0Rx2L6QFgB/A2wSxvN4XNT8455yro+O7B9ASFBif/+cPi8o6tGvHxbTXnGeQoz0EMMrMtwBlAFsGorrckMijnnKvrGqQGH7+P//Qwnrr0cACyN+6gy+1vcvj9U/huy86ydq8WUQbrK9pmMPCKmW32OVydc65ylvzutN2WB/VsxzsLg5mVN2zfxTebd9KueXIHrYiSIN6QtJigiennktoCyU9tzjlXh4y6POgGmLp4DVc+Oz3J0QTKbWIys9uBY4FMM8sDtgNnJTow55xzyRWlkzoduBQ4IWxa+gB4KsFxOeecS7IoTUxPAunAX8Ply8KyqxIVlHPOueSLkiCONLM+McvvS5qTqICcc87VDFFucy2Q1K1oQdL+QKRRpiSdKukLScsk3R5n/X6S3pM0V9I0SR1j1nWW9I6kRZIWSuoS5ZzOOeeqRpQEcQswNfwA/wB4H/hVeTtJSgWeAE4DegIXS+pZYrOHgefM7FBgBPBgzLrngIfM7GCgH7AmQqzOOVerZa0PJhY6+4lP+HxFcofiiHIX03tAd+AG4HrgIDObGuHY/YBlZrbczHYB4/nh3U89CRIOwNSi9WEiSTOzKWEM28wsJ8I5nXOuVhtw0N7Fr5ev3Z7ESCIkiHB60V8C9wL3EDwLEeXpjQ7Aypjl7LAs1hzg3PD1OUAzSa2BA4FNkl6V9Lmkh8IrkpKxXSNphqQZa9eujRCSc87VbF3bNOHDW05MdhhAtCam54BewF+Ax8PXz1fR+X8N9Jf0OdAfWEXQv5EGHB+uPxLYn2Dyot2Y2SgzyzSzzLZt21ZRSM455yDaXUy9zSy272CqpIUR9lsFdIpZ7hiWFTOz1YRXEJKaAueZ2SZJ2cBsM1sernsNOBoYHeG8zjnnqkCUK4hZko4uWpB0FBBlXO3pQHdJXSU1AC4CJsZuIKmNpKIY7gDGxOzbMhzWA2AgECUpOeecqyJREsQRwH8kZUnKAj4FjpQ0T9Lc0nYys3zgOmAysAh42cwWSBoh6cxwswHAF5KWAO2A34X7FhA0L70naR4g4Jk9qaBzztU2RjBPz69emcPc7E1JiyNKE9Ope3pwM5sETCpRdnfM6wnAhFL2nUIw94RzztUrjdK/vyfnsfeW8szlmSRjFO0ot7l+XdZPdQTpnHP1yd7NM8gaORiAdxetYXpWcp6HiNLE5JxzLglOPCjohn1lxspytkyMKE1MzjnnkuD+s3vzoz9M5ZWZ2WRv3AHA4EPbc+nR+1XL+T1BOOdcDdWxVePi158uX0+zjDTyCgqrLUFEamKSNKqsZeecc4kx+opM7j+rF1/+/nQO7diiWs8d9Qri6XKWnXPOJcCPD26XtHNHuoIws5llLTvnnKt7Sr2CkPRvCJ/WiMPMzixtnXPOudqvrCamh6stCuecczVOqQnCzD4oei2pEdDZzL6olqicc84lXZT5IIYAs4G3w+W+kiaWuZNzzrlaL0on9b0Es8NtAjCz2UDXhEXknHOuRoiSIPLMbHOJslI7r51zztUNUZ6DWCDpp0CqpKK5qf+T2LCcc84lW5QriOsJphnNBc
YBW4D/l8CYnHPOxZG1LocZX2/kntfnsylnV8LPF2W47xwzuxP4MXCimd1pZjsTHplzzrndrNoUDNg39tOv+eyrDQk/X5S7mI4MZ3WbC8yTNEfSEQmPzDnn3G4+vWMgf/pJn2o7X5QmptHAL8ysi5l1AX4J/D2hUTnnnPuB9i0a0aVNEwDeW7Qm4eeLkiAKzOyjogUz+xjIj3JwSadK+kLSMkm3x1m/n6T3JM2VNE1SxxLrm0vKlvR4lPM551xd17ZpQwCmfpH4BFHWWEyHhy8/kPQ0QQe1ARcC08o7sKRU4AngZCAbmC5popktjNnsYeA5MxsraSDwIHBZzPr7gQ+jV8c55+q2zq0b02vf5rRvkZHwc5V1m+ufSizfE/M6ynMQ/YBlZrYcQNJ44CwgNkH0BG4OX08FXitaEfZztCN4gjszwvmcc85VobLGYjqxksfuAMROpJoNHFVimznAucCjwDlAM0mtgY0ECepS4KTSTiDpGuAagM6dO1cyXOecc7EiTRgkaTDBsxDF1zRmNqIKzv9r4HFJQwmaklYBBcAvgElmli2p1J3NbBQwCiAzM9Of7nbOuSpUboKQ9BTQGDgR+BtwPvBZhGOvAjrFLHcMy4qZ2WqCKwgkNQXOM7NNko4Bjpf0C6Ap0EDSNjP7QUe3c865xIhyF9OxZnY5sNHM7gOOAQ6MsN90oLukrpIaABcBu40CK6mNpKIY7gDGAJjZJWbWObyt9tcEHdmeHJxzrhpFSRA7wt85kvYF8oD25e1kZvnAdcBkYBHwspktkDRCUtFsdAOALyQtIeiQ/l0F43fOOZcgUfog3pDUEngImEVwB9PfohzczCYBk0qU3R3zegIwoZxjPAs8G+V8zjlXHyxYvYUFq7dgZpTVT1tZ5SYIM7s/fPlPSW8AGXGG/3bOOVfNNmzfRevwwblEKOtBuXPLWIeZvZqYkJxzzpVlxFm9uPv1BQk/T1lXEEPKWGeAJwjnnKvDynpQ7srqDMQ551w0u/ILg98FhQk9T5S7mJxzztUgnyxbB8ArM7ITeh5PEM45V8s8FM4JkZaauDuYwBOEc87VOk0bRholqdKizCjXWNJdkp4Jl7tLOiPxoTnnnCtLXn5ih6CLcgXxdyCXYIgNCMZTeiBhETnnnCuThXnhz+8uIWvd9oSdJ0qC6GZmfyQYYgMzywES2/DlnHOuVI0apBa/XrM1N2HniZIgdklqRDhJkKRuBFcUzjnnkuTFq0pOr1P1ovR03Eswq1snSS8CxwFDExiTc865GiDKWEzvSJoJHE3QtHSjma1LeGTOOeeSKsqEQf8G/gFMNLPE9YY455yrUaL0QTwMHA8slDRB0vmSMsrbyTnnXO1WboIwsw/M7BfA/sDTwAXAmkQH5pxzrnTfbdkJwOuzV5Wz5Z6L9DheeBfTEOBC4HBgbMIics45V66+nVoC8FUCn4OI0gfxMtCP4E6mx4EPzCyxQwg655wr0/5tm7JviwzaNU9ci3+UPojRBA/LDTezqRVJDpJOlfSFpGWSbo+zfj9J70maK2mapI5heV9Jn0paEK67MHqVnHOufli9eSf/+nwVZokZcqOsGeUGmtn7QBPgrJLznpY3o5ykVOAJ4GQgG5guaaKZLYzZ7GHgOTMbK2kg8CBwGZADXG5mSyXtC8yUNNnMNlW4hs45V8cVGiRiYNeyriD6h7+HxPmJMlhfP2CZmS03s13AeOCsEtv0BN4PX08tWm9mS8xsafh6NUGneNsI53TOuXrj5pMPTOjxy5pR7p7w5Qgz+yp2naSuEY7dAVgZs5wNlHw2fA5wLvAocA7QTFJrM1sfc65+QAPgy5InkHQNcA1A586dI4TknHMuqih9EP+MUzahis7/a6C/pM8JrlhWAQVFKyW1B54HrozX92Fmo8ws08wy27b1CwznnKtKZfVB9AB6AS0knRuzqjkQpdt8FdApZrljWFYsbD46NzxfU+C8on4GSc2BN4E7zey/Ec7nnHOuCpV1m+
tBBH0NLQn6HYpsBa6OcOzpQPewOWoVcBHw09gNJLUBNoRXB3cAY8LyBsC/CDqwq+pqxTnnXAWU1QfxOvC6pGPM7NOKHtjM8iVdB0wGUoExZrZA0ghghplNBAYAD0oy4EPgl+HuFwAnAK0lDQ3LhprZ7IrG4ZxzddWked8A8NlXGzimW+sqP75Ku39W0q1m9kdJfyGcCyKWmd1Q5dFUQmZmps2YMSPZYTjnXLUZ/fFX3P/GQh489xAu7rdnN+pImmlmmfHWldXEtCj87Z+6zjlXAw0+pD33v7GQTTl5CTl+WU1M/w5/F4+7JCkFaGpmWxISjXPOucgKwhagP7y9mJ8P6Fblxy/3NldJ/5DUXFITYD7BsN+3VHkkzjnnKqRDy0YAdG3TJCHHj/IcRM/wiuFs4C2gK8FwGM4555LshAPb0qJRekKOHSVBpEtKJ0gQE80sjzid1s455+qWKAniaSCLYNC+DyXtB3gfhHPO1XHlzgdhZo8Bj8UUfS3pxMSF5JxzriaI0kndQtIjkmaEP38iuJpwzjlXh0VpYhpDMLzGBeHPFuDviQzKOedc8kWZk7qbmZ0Xs3yfpNkJisc551wNEeUKYoekHxUtSDoO2JG4kJxzztUEUa4ghgPPSWoRLm8ErkhcSM4552qCMhOEpL7AAQRDda8C8GE2nHOufii1iUnS3cDLwHkEE/dc6MnBOefqj7KuIC4E+ppZjqTWwNvAM9UTlnPOuWQrq5M618xyAMxsfTnbOuecq2PKuoLYX9LE8LWAbjHLmNmZCY3MOedcUpWVIM4qsfxwIgNxzjlXs5Q1YdAHlT24pFOBRwnmpP6bmY0ssX4/gie12wIbgEvNLDtcdwXw23DTB2InLnLOOZd4Zd3F9G9JQ8Khvkuu21/SCEk/K2P/VOAJ4DSgJ3CxpJ4lNnsYeM7MDgVGAA+G++4F3AMcBfQD7pHUqmJVc845VxlldTxfDRwPLJY0XdIkSe9LWk4wBPhMMxtTxv79gGVmttzMdgHj+WGzVU/g/fD11Jj1pwBTzGyDmW0EpgCnVqhmzjnnKqWsJqZvgVuBWyV1AdoTDLGxpOjupnJ0AFbGLGcTXBHEmgOcS9AMdQ7QLLylNt6+HUqeQNI1wDUAnTt3jhCSc865qCLdumpmWWb2qZnNjpgcovo10F/S50B/gqe1C6LubGajzCzTzDLbtm1bhWE555yLMhbTnloFdIpZ7hiWFTOz1QRXEEhqCpxnZpskrQIGlNh3WgJjdc45V0IiH36bDnSX1FVSA4LxnCbGbiCpjaSiGO4guKMJYDIwSFKrsHN6UFjmnHOumiQsQZhZPnAdwQf7IuBlM1sQ3v1U9JDdAOALSUuAdsDvwn03APcTJJnpwIiwzDnnXDUpt4kpnP/hXmC/cHsBZmb7l7evmU0CJpUouzvm9QRgQin7juH7KwrnnHPVLEofxGjgJmAmFehAds45V7tFSRCbzeythEfinHOuRomSIKZKegh4FcgtKjSzWQmLyjnnXNJFSRBFD7dlxpQZMLDqw3HOOVdTlJsgzOzE6gjEOedczVLuba6SWkh6RNKM8OdPklpUR3DOOeeSJ8pzEGOArcAF4c8W4O+JDMo551zyRemD6GZm58Us3ydpdoLicc45V0NEuYLYIelHRQvhg3M7EheSc865miDKFcTPgbFhv4MIZn4bmsignHPOJV+Uu5hmA30kNQ+XtyQ6KOecc8lXaoKQdKmZvSDp5hLlAJjZIwmOzTnnXBKVdQXRJPzdrDoCcc45V7OUNeXo0+Hv+6ovHOecczVFlAfl/iipuaR0Se9JWivp0uoIzjnnXPJEuc11UNgxfQaQBRwA3JLIoJxzziVflARR1Aw1GHjFzDYnMB7nnHM1RJQE8YakxcARwHuS2gI7oxxc0qmSvpC0TNLtcdZ3ljRV0ueS5ko6PSxPlzRW0jxJiyTdUZFKOeecq7xyE4SZ3Q4cC2SaWR6wHTirvP0kpQJPAKcBPYGLJfUssd
lvCeaqPgy4CPhrWP4ToKGZHUKQmK6V1CVSjZxzzlWJsp6DGGhm70s6N6YsdpNXyzl2P2CZmS0P9x1PkFgWxmxjQPPwdQtgdUx5E0lpQCNgF8Eggc4556pJWc9B9AfeB4bEWWeUnyA6ACtjlrP5fvKhIvcC70i6nuC5i5PC8gkEyeQboDFwk5ltKOd8zjnnqlBZz0HcE/6+MoHnvxh41sz+JOkY4HlJvQmuPgqAfYFWwEeS3i26Giki6RrgGoDOnTsnMEznnKt/ojwH8XtJLWOWW0l6IMKxVwGdYpY7hmWxhgEvA5jZp0AG0Ab4KfC2meWZ2RrgE3af8pRwn1FmlmlmmW3bto0QknPOuaii3MV0mpltKlows43A6RH2mw50l9RVUgOCTuiJJbZZAfwYQNLBBAlibVg+MCxvAhwNLI5wTuecc1UkSoJIldSwaEFSI6BhGdsDYGb5wHXAZGARwd1KCySNkHRmuNmvgKslzQHGAUPNzAjufmoqaQFBovm7mc2tSMWcc85VTpT5IF4keP6haJrRK4GxUQ5uZpOASSXK7o55vRA4Ls5+2whudXXOOZckUeaD+EP4Db/oDqP7zWxyYsNyzjmXbFGuICBoIso3s3clNZbUzMy2JjIw55xzyRXlLqarCZ5LeDos6gC8lsCYnHPO1QBROql/SdBPsAXAzJYCeycyKOecc8kXJUHkmtmuooVw+AtLXEjOOedqgigJ4gNJvwEaSToZeAX4d2LDcs45l2xREsRtBA+vzQOuJbht9beJDMo551zylXkXUzhk9wIz6wE8Uz0hOeecqwnKvIIwswLgC0k+Ep5zztUzUZ6DaAUskPQZwWRBAJjZmaXv4pxzrraLkiDuSngUzjnnapyyZpTLAIYDBxB0UI8OB+BzzjlXD5TVBzGWYA6GeQTzSv+pWiJyzjlXI5TVxNTTzA4BkDQa+Kx6QnLOOVcTlHUFkVf0wpuWnHOu/inrCqKPpC3haxE8Sb0lfG1m1jzh0TnnnEuaUhOEmaVWZyDOOedqlihDbTjnnKuHok4YtEcknQo8CqQCfzOzkSXWdya4W6pluM3t4TSlSDqUYA6K5kAhcKSZ7azI+fPy8sjOzmbnzgrt5uq5jIwMOnbsSHp6erJDcS6pEpYgwnGcngBOBrKB6ZImhvNQF/kt8LKZPSmpJ8FAgF3CIcVfAC4zszmSWhPTaR5VdnY2zZo1o0uXLkiqdJ1c3WdmrF+/nuzsbLp27ZrscJxLqkQ2MfUDlpnZ8nA+ifHAWSW2MYIrBIAWwOrw9SBgrpnNATCz9eG4UBWyc+dOWrdu7cnBRSaJ1q1b+1WncyQ2QXQAVsYsZ4dlse4FLpWUTXD1cH1YfiBgkiZLmiXp1ngnkHSNpBmSZqxduzZuEJ4cXEX5vxnnAsnupL4YeNbMOgKnA89LSiFo+voRcEn4+xxJPy65s5mNMrNMM8ts27ZtdcbtnHN1XiITxCqgU8xyx7As1jDgZQAz+xTIANoQXG18aGbrzCyH4Ori8ATGmjDffvstF110Ed26deOII47g9NNPZ8mSJWRlZdG7d+8qO8/dd9/Nu+++C8BHH31Er1696Nu3L6tWreL888+v1LHNjIEDB7Jly5bistdeew1JLF68uLgsKyuLRo0a0bdvX3r27Mnw4cMpLCys1Llzc3O58MILOeCAAzjqqKPIysqKu92jjz5K79696dWrF//3f/9XXD5nzhyOOeYYDjnkEIYMGVJch3nz5jF06NBKxeZcXZfIBDEd6C6pq6QGwEXAxBLbrAB+DCDpYIIEsRaYDBwiqXHYYd0fWEgtY2acc845DBgwgC+//JKZM2fy4IMP8t1331X5uUaMGMFJJ50EwIsvvsgdd9zB7Nmz6dChAxMmTIh8nPz8Hz40P2nSJPr06UPz5t8/Gzlu3Dh+9KMfMW7cuN227datG7Nnz2bu3LksXLiQ1157bc8qFBo9ejStWrVi2bJl3HTTTdx2220/2Gb+/Pk888wzfP
bZZ8yZM4c33niDZcuWAXDVVVcxcuRI5s2bxznnnMNDDz0EwCGHHEJ2djYrVqyoVHzO1WUJu4vJzPIlXUfwYZ8KjDGzBZJGADPMbCLwK+AZSTcRdFgPNTMDNkp6hCDJGDDJzN6sTDz3/XsBC1dvKX/DCui5b3PuGdKr1PVTp04lPT2d4cOHF5f16dMHYLdvwllZWVx22WVs3x5Mt/H4449z7LHH8s0333DhhReyZcsW8vPzefLJJzn22GMZNmwYM2bMQBI/+9nPuOmmmxg6dChnnHEGmzZt4uWXX2by5Mm89dZb/O53v+OMM85g/vz5FBQUcPvttzNt2jRyc3P55S9/ybXXXsu0adO46667aNWqFYsXL2bJkiW71ePFF1/kmmuuKV7etm0bH3/8MVOnTmXIkCHcd999P6h7Wloaxx57bPEH9Z56/fXXuffeewE4//zzue666zCz3foJFi1axFFHHUXjxo0B6N+/P6+++iq33norS5Ys4YQTTgDg5JNP5pRTTuH+++8HYMiQIYwfP55bb43bxeVcvZfQ5yDCZxomlSi7O+b1QuC4UvZ9geBW11pr/vz5HHHEEeVut/feezNlyhQyMjJYunQpF198MTNmzOAf//gHp5xyCnfeeScFBQXk5OQwe/ZsVq1axfz58wHYtGnTbse66qqr+PjjjznjjDM4//zzd0tEo0ePpkWLFkyfPp3c3FyOO+44Bg0aBMCsWbOYP39+3Fs7P/nkE55++uni5ddff51TTz2VAw88kNatWzNz5swf1DMnJ4f33nuPESNG/OB4xx9/PFu3bv1B+cMPP1x8FVRk1apVdOoUtFSmpaXRokUL1q9fT5s2bYq36d27N3feeSfr16+nUaNGTJo0iczMTAB69erF66+/ztlnn80rr7zCypXf3zeRmZnJyJEjPUE4V4qEJoiapKxv+smWl5fHddddx+zZs0lNTS3+Bn/kkUfys5/9jLy8PM4++2z69u3L/vvvz/Lly7n++usZPHhw8Qd8FO+88w5z584tbnLavHkzS5cupUGDBvTr16/U+/43bNhAs2bNipfHjRvHjTfeCMBFF13EuHHjihPEl19+Sd++fZHEWWedxWmnnfaD43300UeRY47i4IMP5rbbbmPQoEE0adKEvn37kpoajBQzZswYbrjhBu6//37OPPNMGjRoULzf3nvvzerVq0s7rHP1Xr1JEMnQq1evSO3/f/7zn2nXrh1z5syhsLCQjIwMAE444QQ+/PBD3nzzTYYOHcrNN9/M5Zdfzpw5c5g8eTJPPfUUL7/8MmPGjIkUj5nxl7/8hVNOOWW38mnTptGkSZNS90tLS6OwsJCUlBQ2bNjA+++/z7x585BEQUEBkorb9ov6IMpSkSuIDh06sHLlSjp27Eh+fj6bN2+mdevWP9h32LBhDBs2DIDf/OY3dOzYEYAePXrwzjvvALBkyRLefPP7lsqdO3fSqFGjMmN1rj5L9m2uddrAgQPJzc1l1KhRxWVz5879wTfozZs30759e1JSUnj++ecpKAieCfz6669p164dV199NVdddRWzZs1i3bp1FBYWct555/HAAw8wa9asyPGccsopPPnkk+TlBQ+lL1mypLjfoywHHXQQy5cvB2DChAlcdtllfP3112RlZbFy5Uq6du1aoauCjz76iNmzZ//gp2RyADjzzDMZO3Zs8bkHDhwY9zmFNWvWALBixQpeffVVfvrTn+5WXlhYyAMPPLBbf9CSJUuq9E4y5+oaTxAJJIl//etfvPvuu3Tr1o1evXpxxx13sM8+++y23S9+8QvGjh1Lnz59WLx4cfG3+WnTptGnTx8OO+wwXnrpJW688UZWrVrFgAED6Nu3L5deeikPPvhg5HiuuuoqevbsyeGHH07v3r259tpr4961VNLgwYOZNm0aEDQvnXPOObutP++8835wN1NVGTZsGOvXr+eAAw7gkUceYeTIYDiv1atXc/rpp+8WQ8+ePRkyZAhPPPEELVu2LI
73wAMPpEePHuy7775ceeWVxftMnTqVwYMHJyRu5+oCBTcN1X6ZmZk2Y8aM3coWLVrEwQcfnKSI6o5vvvmGyy+/nClTpiQ7lCqTm5tL//79+fjjj0lL+2FLq//bcbXF4+8vJWdXAbee2mOP9pc008wy463zPghXrvbt23P11VezZcuW3Z6FqM1WrFjByJEj4yYH52qT6wZ2T9ix/X+Hi+SCCy5IdghVqnv37nTvnrj/WM7VBXW+D6KuNKG56uP/ZpwL1OkEkZGRwfr16/0/vIusaD6IoluNnavP6nQTU8eOHcnOzqa0ocCdi6doRjnn6rs6nSDS09N9VjDnnNtDdbqJyTnn3J7zBOGccy4uTxDOOefiqjNPUktaC3xdiUO0AdZVUTi1RX2rc32rL3id64vK1Hk/M4s7Z3OdSRCVJWlGaY+b11X1rc71rb7gda4vElVnb2JyzjkXlycI55xzcXmC+N6o8jepc+pbnetbfcHrXF8kpM7eB+Gccy4uv4JwzjkXlycI55xzcdWrBCHpVElfSFom6fY46xtKeilc/z9JXZIQZpWKUOebJS2UNFfSe5L2S0acVam8Osdsd54kk1Trb4mMUmdJF4R/6wWS/lHdMVa1CP+2O0uaKunz8N/36fGOU1tIGiNpjaT5payXpMfC92OupMMrfVIzqxc/QCrwJbA/0ACYA/Qssc0vgKfC1xcBLyU77mqo84lA4/D1z+tDncPtmgEfAv8FMpMddzX8nbsDnwOtwuW9kx13NdR5FPDz8HVPICvZcVeyzicAhwPzS1l/OvAWIOBo4H+VPWd9uoLoBywzs+VmtgsYD5xVYpuzgLHh6wnAjyWpGmOsauXW2cymmllOuPhfoLaPcx3l7wxwP/AHYGd1BpcgUep8NfCEmW0EMLM11RxjVYtSZwOK5shtAayuxviqnJl9CGwoY5OzgOcs8F+gpaT2lTlnfUoQHYCVMcvZYVncbcwsH9gMtK6W6BIjSp1jDSP4BlKblVvn8NK7k5m9WZ2BJVCUv/OBwIGSPpH0X0mnVlt0iRGlzvcCl0rKBiYB11dPaElT0f/v5arT80G46CRdCmQC/ZMdSyJJSgEeAYYmOZTqlkbQzDSA4CrxQ0mHmNmmZAaVYBcDz5rZnyQdAzwvqbeZFSY7sNqiPl1BrAI6xSx3DMvibiMpjeCydH21RJcYUeqMpJOAO4EzzSy3mmJLlPLq3AzoDUyTlEXQVjuxlndUR/k7ZwMTzSzPzL4ClhAkjNoqSp2HAS8DmNmnQAbBoHZ1VaT/7xVRnxLEdKC7pK6SGhB0Qk8ssc1E4Irw9fnA+xb2/tRS5dZZ0mHA0wTJoba3S0M5dTazzWbWxsy6mFkXgn6XM81sRnLCrRJR/m2/RnD1gKQ2BE1Oy6sxxqoWpc4rgB8DSDqYIEHU5fmHJwKXh3czHQ1sNrNvKnPAetPEZGb5kq4DJhPcATHGzBZIGgHMMLOJwGiCy9BlBJ1BFyUv4sqLWOeHgKbAK2F//AozOzNpQVdSxDrXKRHrPBkYJGkhUADcYma19uo4Yp1/BTwj6SaCDuuhtfkLn6RxBEm+Tdivcg+QDmBmTxH0s5wOLANygCsrfc5a/H4555xLoPrUxOScc64CPEE455yLyxOEc865uDxBOOeci8sThHPOubg8Qbg9JqlA0mxJ8yX9W1LLKj5+VnjPPpK2lbJNI0kfSEqV1EXSjjCmhZKeCp+crsg5MyU9Fr4eIOnYmHXDJV1emTqFx7lX0q/L2eZZSedX4JhdShvls6pJOrNo9FRJZ0vqGbNuRPjg5Z4cd7yk2vzwXp3jCcJVxg4z62tmvQmeG/llEmL4GfCqmRWEy1+aWV/gUIIRPM+uyMHMbIaZ3RAuDgCOjVn3lJk9V9mAazszm2hmI8PFswne56J1d5vZu3t46CeBWysZnqtCniBcVfmUcGAwSd0kvS1ppqSPJPUIy9tJ+pekOeHPsWH5a+G2CyRdU8HzXgK8XrIwHGzxP8
AB4bfr9/X9nBedw/P+JLz6mSPpw7BsgKQ3FMwFMhy4KbwiOb7om7+kHpI+KzpXePx54esjwiuamZImq5zRNCVdLWl6GMM/JTWOWX2SpBmSlkg6I9w+VdJD4T5zJV1bkTdL0jZJfw7f6/cktQ3L+yoYxG9u+DdqFZbfoO/nCxkflg2V9Hj49zsTeCh8j7oVXfkomKvhlZjzDpD0Rvh6kKRPJc2S9IqkpuFmH4V1rjcP8NZ0niBcpUlKJRjSoOgp5VHA9WZ2BPBr4K9h+WPAB2bWh2Bc+wVh+c/CbTOBGyRFGkFXwRAL+5tZVpx1jcOY5gF/Acaa2aHAi2EcAHcDp4Tx7Pb0eHjMp4A/h1dJH8WsWww0kNQ1LLoQeElSeniu88P6jAF+V041XjWzI8MYFhGMH1SkC8Gw1oOBpyRlhOs3m9mRwJHA1TFxFNV9X0mTSjlfE4InjXsBHxA8jQvwHHBb+B7Niym/HTgsLB9e4j36D8Hf/JbwPfoyZvW7wFGSmoTLFwLjFTQZ/hY4ycwOB2YAN4fHKyR4CrhP6W+Xq06eIFxlNJI0G/gWaAdMCb8NHkswdMdsgnGeir5FDyRoRsDMCsxsc1h+g6Q5BOMidSL6IHJtgE0lyrqF5/0EeNPM3gKOAYpmUHse+FH4+hPgWUlXEwzXUBEvE3zoEf5+CTiIYCDAKWEMv6X8+TV6h1dZ8wiuhnrFnsPMCs1sKcG4ST2AQQTj7cwG/kcwHP1u75eZrTaz0mZPKwxjBXgB+JGkFkBLM/sgLB9LMDkNwFzgRQWj/eaXU5fYGPKBt4Eh4RXBYIIrvaMJmqQ+CetwBRA7i+EaYN+o53GJ5ZdyrjJ2mFnf8Nv6ZII+iGeBTWE/QLkkDQBOAo4xsxxJ0wgGVYt0/jjbfhn13GY2XNJRBB9eMyUdEfG8EHzIviLp1eBQtlTSIcACMzumAsd5FjjbzOZIGko4oF5RiCVDJpgt7Hozmxy7Qns+PW55Y+0MJkgWQ4A7wzpGNR64jqB/aoaZbZUkYIqZXVzKPhkEf1dXA/gVhKu0cEa6GwgGR8sBvpL0EyieJ7eoyeA9gmlNi9rSWxAMqb4xTA49CL5hRj3vRiA1bHopy3/4fuDFSwjaupHUzcz+Z2Z3E4zy2anEflsJhgePd+4vCQa9u4vvv5F/AbRVMPcAktIl9Yq3f4xmwDdh89QlJdb9RFKKpG4EU2t+QZCIfx5uj6QDY5pxokghGKkY4KfAx+GV3EZJx4fllwEfKLgDrJOZTQVuI/hbNS1xvFLfI4ImrMMJZrMbH5b9FzhO0gFh/E0kHRizz4FAtdyN5crnCcJVCTP7nKA54mKCD7phYbPRAr6fCvJG4MSwOWUmQVPD20CapEXASIIPkIp4h++bjEpzPXClpLkEH343huUPSZqn4PbQ/xDMaxzr38A5YQfs8fzQS8ClfD/nwC6CD98/hHWfTcxdUKW4i6Cp6BNgcYl1K4DPCGb5G25mO4G/AQuBWWHcT1OiJaCcPojtQL9w34HAiLD8CoL3Yy7QNyxPBV4I/16fA4/FmWBoPHCLpM/DRFYsvLPsDeC08DdmtpZgsqZx4bk+JWg6Q1I7gqvSb0uJ3VUzH83V1WoKpg+9ycwuS3YstYGkbWZW8iqgRlAwLPcWMxud7FhcwK8gXK1mZrOAqeGdVK5220TQQe5qCL+CcM45F5dfQTjnnIvLE4Rzzrm4PEE455yLyxOEc865uDxBOOeci+v/A7qNNh/O/Qb4AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# openai.embeddings_utils exists only in the pre-1.0 openai package; it was removed in openai 1.0\n",
    "from openai.embeddings_utils import cosine_similarity, get_embedding\n",
    "from sklearn.metrics import PrecisionRecallDisplay\n",
    "\n",
    "def evaluate_embeddings_approach(\n",
    "    labels=('negative', 'positive'),\n",
    "    model=EMBEDDING_MODEL,\n",
    "):\n",
    "    # embed the label descriptions: labels[0] describes the negative class, labels[1] the positive class\n",
    "    label_embeddings = [get_embedding(label, engine=model) for label in labels]\n",
    "\n",
    "    def label_score(review_embedding, label_embeddings):\n",
    "        # > 0 means the review is closer to the positive label, < 0 closer to the negative label\n",
    "        return cosine_similarity(review_embedding, label_embeddings[1]) - cosine_similarity(review_embedding, label_embeddings[0])\n",
    "\n",
    "    # similarity-difference scores (not probabilities), one per review\n",
    "    scores = df[\"embedding\"].apply(lambda x: label_score(x, label_embeddings))\n",
    "    preds = scores.apply(lambda x: 'positive' if x > 0 else 'negative')\n",
    "\n",
    "    report = classification_report(df.sentiment, preds)\n",
    "    print(report)\n",
    "\n",
    "    # the continuous scores, not the hard predictions, drive the precision-recall curve\n",
    "    display = PrecisionRecallDisplay.from_predictions(df.sentiment, scores, pos_label='positive')\n",
    "    _ = display.ax_.set_title(\"2-class Precision-Recall curve\")\n",
    "\n",
    "evaluate_embeddings_approach(labels=['negative', 'positive'], model=EMBEDDING_MODEL)"
   ]
  },
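  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The precision-recall curve shows which trade-offs are available, but the evaluation code always thresholds the score at 0. As a hedged sketch of how a different operating point could be picked, the cell below applies scikit-learn's `precision_recall_curve` to made-up labels and scores (toy values standing in for `df.sentiment` and the similarity-difference scores) and selects the lowest threshold whose precision for the positive class reaches a target. The target of 0.9 is an arbitrary example, the `toy_` names are hypothetical, and the sketch predicts positive for `score >= threshold`, matching how `precision_recall_curve` defines its points. A threshold chosen this way should be checked on held-out data rather than on the data it was tuned on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: choosing a score threshold for a target precision (toy data, not the review dataset)\n",
    "import numpy as np\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "\n",
    "toy_y_true = np.array([\"negative\", \"negative\", \"positive\", \"negative\", \"positive\", \"positive\"])\n",
    "toy_scores = np.array([-0.08, -0.02, 0.01, 0.03, 0.05, 0.09])\n",
    "\n",
    "precision, recall, thresholds = precision_recall_curve(toy_y_true, toy_scores, pos_label=\"positive\")\n",
    "\n",
    "# thresholds are increasing; precision and recall carry one extra final point (recall 0) with no\n",
    "# threshold, so precision[:-1] lines up with thresholds. Recall never increases as the threshold\n",
    "# rises, so the lowest qualifying threshold keeps the most recall.\n",
    "target_precision = 0.9\n",
    "meets_target = precision[:-1] >= target_precision\n",
    "if meets_target.any():\n",
    "    best_threshold = thresholds[meets_target][0]\n",
    "    toy_preds = np.where(toy_scores >= best_threshold, \"positive\", \"negative\")\n",
    "    print(best_threshold, recall[:-1][meets_target][0])\n",
    "else:\n",
    "    print(\"no threshold reaches the target precision\")"
   ]
  },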
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that this classifier already performs well, reaching 0.90 accuracy, although its precision on the smaller negative class is only 0.61. We used the simplest possible label names, a single word per class. Let's try to improve on this by using more descriptive label names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative       0.98      0.73      0.84       136\n",
      "    positive       0.96      1.00      0.98       789\n",
      "\n",
      "    accuracy                           0.96       925\n",
      "   macro avg       0.97      0.86      0.91       925\n",
      "weighted avg       0.96      0.96      0.96       925\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])"
   ]
  },
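  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using descriptive label names leads to a clear improvement in performance: accuracy rises from 0.90 to 0.96, macro-averaged F1 from 0.83 to 0.91, and precision on the negative class from 0.61 to 0.98, at the cost of lower negative recall (0.88 to 0.73)."
   ]
  },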
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As shown above, zero-shot classification with embeddings can achieve strong results without any labeled training data, especially when the labels are descriptive phrases rather than single words."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}