{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Vq31CdSRpgkI" }, "source": [ "# Customizing embeddings\n", "\n", "This notebook demonstrates one way to customize OpenAI embeddings to a particular task.\n", "\n", "The input is training data in the form of [text_1, text_2, label], where label is +1 if the pair is similar and -1 if the pair is dissimilar.\n", "\n", "The output is a matrix that you can use to multiply your embeddings. The product of this multiplication is a 'custom embedding' that will better emphasize aspects of the text relevant to your use case. In binary classification use cases, we've seen error rates drop by as much as 50%.\n", "\n", "In the following example, I use 1,000 sentence pairs picked from the SNLI corpus. Each pair of sentences is logically entailed (i.e., one implies the other). These pairs are our positives (label = 1). We generate synthetic negatives by combining sentences from different pairs, which are presumed not to be logically entailed (label = -1).\n", "\n", "For a clustering use case, you can generate positives by creating pairs from texts in the same cluster and generate negatives by creating pairs from texts in different clusters.\n", "\n", "With other data sets, we have seen decent improvement with as few as ~100 training examples. Of course, performance will be better with more examples." ] }, { "cell_type": "markdown", "metadata": { "id": "arB38jFwpgkK" }, "source": [ "## 0. Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "ifvM7g4apgkK" }, "outputs": [], "source": [ "# imports\n", "from typing import List, Tuple  # for type hints\n", "\n", "import numpy as np  # for manipulating arrays\n", "import pandas as pd  # for manipulating data in dataframes\n", "import pickle  # for saving the embeddings cache\n", "import plotly.express as px  # for plots\n", "import random  # for generating run IDs\n", "from sklearn.model_selection import train_test_split  # for splitting train & test data\n", "import torch  # for matrix optimization\n", "\n", "from openai.embeddings_utils import get_embedding, cosine_similarity  # for embeddings\n" ] }, { "cell_type": "markdown", "metadata": { "id": "DtBbryAapgkL" }, "source": [ "## 1. Inputs\n", "\n", "Most inputs are here. The key things to change are where to load your dataset from, where to save a cache of embeddings to, and which embedding engine you want to use.\n", "\n", "Depending on how your data is formatted, you'll want to rewrite the process_input_data function."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "UzxcWRCkpgkM" }, "outputs": [], "source": [ "# input parameters\n", "embedding_cache_path = \"data/snli_embedding_cache.pkl\"  # embeddings will be saved/loaded here\n", "default_embedding_engine = \"babbage-similarity\"  # text-embedding-ada-002 is recommended\n", "num_pairs_to_embed = 1000  # 1000 is arbitrary\n", "local_dataset_path = \"data/snli_1.0_train_2k.csv\"  # download from: https://nlp.stanford.edu/projects/snli/\n", "\n", "\n", "def process_input_data(df: pd.DataFrame) -> pd.DataFrame:\n", "    # you can customize this to preprocess your own dataset\n", "    # output should be a dataframe with 3 columns: text_1, text_2, label (1 for similar, -1 for dissimilar)\n", "    df[\"label\"] = df[\"gold_label\"]\n", "    # keep only entailment pairs; .copy() avoids pandas' SettingWithCopyWarning on the next assignment\n", "    df = df[df[\"label\"].isin([\"entailment\"])].copy()\n", "    df[\"label\"] = df[\"label\"].apply(lambda x: {\"entailment\": 1, \"contradiction\": -1}[x])\n", "    df = df.rename(columns={\"sentence1\": \"text_1\", \"sentence2\": \"text_2\"})\n", "    df = df[[\"text_1\", \"text_2\", \"label\"]]\n", "    df = df.head(num_pairs_to_embed)\n", "    return df\n" ] }, { "cell_type": "markdown", "metadata": { "id": "aBbH71hEpgkM" }, "source": [ "## 2. Load and process input data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "kAKLjYG6pgkN", "outputId": "dc178688-e97d-4ad0-b26c-dff67b858966" }, "outputs": [ { "data": { "text/html": [ "
\n", "| | text_1 | text_2 | label |\n",
"|---|---|---|---|\n",
"| 2 | A person on a horse jumps over a broken down a... | A person is outdoors, on a horse. | 1 |\n",
"| 4 | Children smiling and waving at camera | There are children present | 1 |\n",
"| 7 | A boy is jumping on skateboard in the middle o... | The boy does a skateboarding trick. | 1 |\n",
"| 14 | Two blond women are hugging one another. | There are women showing affection. | 1 |\n",
"| 17 | A few people in a restaurant setting, one of t... | The diners are at a restaurant. | 1 |\n", "