{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Zero-shot classification with embeddings\n",
"\n",
"In this notebook we will classify the sentiment of reviews using embeddings and zero labeled data! The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb).\n",
"\n",
"We'll define positive sentiment to be 4- and 5-star reviews, and negative sentiment to be 1- and 2-star reviews. 3-star reviews are considered neutral and we won't use them for this example.\n",
"\n",
"We will perform zero-shot classification by embedding descriptions of each class and then comparing new samples to those class embeddings."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"from ast import literal_eval\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"from sklearn.metrics import classification_report\n",
"\n",
"# parameters\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"\n",
"# load data\n",
"datafile_path = \"data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"# embeddings are stored as strings; literal_eval parses them without executing arbitrary code\n",
"df[\"embedding\"] = df.embedding.apply(literal_eval).apply(np.array)\n",
"\n",
"# convert 5-star rating to binary sentiment\n",
"df = df[df.Score != 3]\n",
"df[\"sentiment\"] = df.Score.replace({1: \"negative\", 2: \"negative\", 4: \"positive\", 5: \"positive\"})\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Zero-Shot Classification\n",
"To perform zero-shot classification, we want to predict labels for our samples without any training. To do this, we embed a short description of each label, such as positive and negative, and then compute the cosine similarity between each sample embedding and each label embedding.\n",
"\n",
"The label most similar to the sample is the predicted label. We can also define a prediction score as the difference between the cosine similarity to the positive label and the cosine similarity to the negative label. This score can be used to plot a precision-recall curve, and choosing a different threshold on the score selects a different tradeoff between precision and recall."
]
},
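{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the scoring rule described above, the cell below uses made-up 3-dimensional vectors in place of real embeddings. The vectors, the `cosine_sim` helper, and the `threshold` variable are illustrative assumptions; they are not part of the dataset or of the OpenAI library."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cosine_sim(a, b):\n",
"    # cosine similarity: dot product divided by the product of the vector norms\n",
"    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
"\n",
"# hypothetical toy vectors standing in for real label and review embeddings\n",
"negative_label = np.array([1.0, 0.0, 0.0])\n",
"positive_label = np.array([0.0, 1.0, 0.0])\n",
"review = np.array([0.2, 0.9, 0.1])\n",
"\n",
"# prediction score: similarity to the positive label minus similarity to the negative label\n",
"score = cosine_sim(review, positive_label) - cosine_sim(review, negative_label)\n",
"\n",
"# a threshold of 0 picks the more similar label; raising it typically trades\n",
"# recall for precision on the positive class\n",
"threshold = 0.0\n",
"prediction = 'positive' if score > threshold else 'negative'\n",
"print(score, prediction)"
]
},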
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"              precision    recall  f1-score   support\n",
"\n",
"    negative       0.61      0.88      0.72       136\n",
"    positive       0.98      0.90      0.94       789\n",
"\n",
"    accuracy                           0.90       925\n",
"   macro avg       0.79      0.89      0.83       925\n",
"weighted avg       0.92      0.90      0.91       925\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from openai.embeddings_utils import cosine_similarity, get_embedding\n",
"from sklearn.metrics import PrecisionRecallDisplay\n",
"\n",
"def evaluate_embeddings_approach(\n",
"    labels=['negative', 'positive'],\n",
"    model=EMBEDDING_MODEL,\n",
"):\n",
"    # embed each label description; labels[0] is the negative class, labels[1] the positive class\n",
"    label_embeddings = [get_embedding(label, engine=model) for label in labels]\n",
"\n",
"    def label_score(review_embedding, label_embeddings):\n",
"        # similarity to the positive label minus similarity to the negative label\n",
"        return cosine_similarity(review_embedding, label_embeddings[1]) - cosine_similarity(review_embedding, label_embeddings[0])\n",
"\n",
"    # a score above 0 means the review is closer to the positive label\n",
"    probas = df[\"embedding\"].apply(lambda x: label_score(x, label_embeddings))\n",
"    preds = probas.apply(lambda x: 'positive' if x > 0 else 'negative')\n",
"\n",
"    report = classification_report(df.sentiment, preds)\n",
"    print(report)\n",
"\n",
"    display = PrecisionRecallDisplay.from_predictions(df.sentiment, probas, pos_label='positive')\n",
"    _ = display.ax_.set_title(\"2-class Precision-Recall curve\")\n",
"\n",
"evaluate_embeddings_approach(labels=['negative', 'positive'], model=EMBEDDING_MODEL)"
]
},
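{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `openai.embeddings_utils` module imported above was removed in version 1.0 of the `openai` Python package. As a hedged sketch, assuming `openai>=1.0` is installed and the `OPENAI_API_KEY` environment variable is set, the two helpers could be replaced by running something like the cell below in place of that import. These definitions are stand-ins written for this notebook, not the library's own implementations; the keyword is still called `engine` only so that the existing call in `evaluate_embeddings_approach` keeps working."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from openai import OpenAI\n",
"\n",
"client = OpenAI()  # reads the API key from the OPENAI_API_KEY environment variable\n",
"\n",
"def get_embedding(text, engine='text-embedding-ada-002'):\n",
"    # stand-in helper (an assumption, not the library's code): request one embedding vector\n",
"    text = text.replace('\\n', ' ')\n",
"    response = client.embeddings.create(input=[text], model=engine)\n",
"    return response.data[0].embedding\n",
"\n",
"def cosine_similarity(a, b):\n",
"    # stand-in helper: dot product divided by the product of the vector norms\n",
"    a, b = np.asarray(a), np.asarray(b)\n",
"    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))"
]
},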
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that this classifier already performs well, reaching 90% accuracy, although precision on the negative class is only 0.61. We used the simplest possible label names. Let's try to improve on this by using more descriptive label names."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"              precision    recall  f1-score   support\n",
"\n",
"    negative       0.98      0.73      0.84       136\n",
"    positive       0.96      1.00      0.98       789\n",
"\n",
"    accuracy                           0.96       925\n",
"   macro avg       0.97      0.86      0.91       925\n",
"weighted avg       0.96      0.96      0.96       925\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])"
]
},
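{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using descriptive label names improves performance: accuracy rises from 0.90 to 0.96 and precision on the negative class rises from 0.61 to 0.98, although recall on the negative class drops from 0.88 to 0.73. The next cell repeats the same call and reproduces these results."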
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"              precision    recall  f1-score   support\n",
"\n",
"    negative       0.98      0.73      0.84       136\n",
"    positive       0.96      1.00      0.98       789\n",
"\n",
"    accuracy                           0.96       925\n",
"   macro avg       0.97      0.86      0.91       925\n",
"weighted avg       0.96      0.96      0.96       925\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As shown above, zero-shot classification with embeddings can give strong results (96% accuracy on this dataset), especially when the labels are more descriptive than single words."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}