{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "Vq31CdSRpgkI"
},
"source": [
"# Customizing embeddings\n",
"\n",
"This notebook demonstrates one way to customize OpenAI embeddings to a particular task.\n",
"\n",
"The input is training data in the form of [text_1, text_2, label], where label is +1 if the two texts are similar and -1 if they are dissimilar.\n",
"\n",
"The output is a matrix that you can multiply your embeddings by. The product of this multiplication is a 'custom embedding' that better emphasizes the aspects of the text relevant to your use case. In binary classification use cases, we've seen error rates drop by as much as 50%.\n",
"\n",
"In the following example, I use up to 1,000 sentence pairs picked from the SNLI corpus. Each pair is logically entailed (i.e., the first sentence implies the second). These pairs are our positives (label = 1). We generate synthetic negatives by combining sentences from different pairs, which are presumed not to be logically entailed (label = -1).\n",
"\n",
"For a clustering use case, you can generate positives by pairing texts from the same cluster and negatives by pairing texts from different clusters.\n",
"\n",
"With other datasets, we have seen decent improvement with as few as ~100 training examples. Of course, performance will be better with more examples."
]
},
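{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The multiplication described above can be sketched with NumPy. This is a minimal illustration under stated assumptions, not the training code in this notebook: the matrix is random and stands in for an optimized one, and the toy sizes (8 dimensions in, 4 out) and the helper names `customize` and `cosine` are hypothetical. It also shows why cosine similarity should be recomputed after customization: unit-length embeddings generally do not stay unit length once multiplied by the matrix, so a plain dot product is no longer equivalent."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch: applying a learned matrix to embeddings.\n",
"# The random matrix is a placeholder for a trained one; the sizes are toy values.\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"embedding_dim, custom_dim = 8, 4  # real embeddings have far more dimensions\n",
"toy_matrix = rng.normal(size=(embedding_dim, custom_dim))  # placeholder for the optimized matrix\n",
"\n",
"\n",
"def customize(embedding: np.ndarray, matrix: np.ndarray) -> np.ndarray:\n",
"    # the custom embedding is the original embedding multiplied by the matrix\n",
"    return embedding @ matrix\n",
"\n",
"\n",
"def cosine(a: np.ndarray, b: np.ndarray) -> float:\n",
"    # full cosine similarity, valid whether or not the vectors are unit length\n",
"    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
"\n",
"\n",
"# two random unit-length vectors standing in for API embeddings\n",
"e1 = rng.normal(size=embedding_dim)\n",
"e2 = rng.normal(size=embedding_dim)\n",
"e1, e2 = e1 / np.linalg.norm(e1), e2 / np.linalg.norm(e2)\n",
"\n",
"c1, c2 = customize(e1, toy_matrix), customize(e2, toy_matrix)\n",
"print(c1.shape)  # (4,)\n",
"print(cosine(e1, e2), cosine(c1, c2))  # similarity before vs. after customization\n"
]
},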
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "arB38jFwpgkK"
},
"source": [
"## 0. Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "ifvM7g4apgkK"
},
"outputs": [],
"source": [
"# imports\n",
"from typing import List, Tuple # for type hints\n",
"\n",
"import numpy as np # for manipulating arrays\n",
"import pandas as pd # for manipulating data in dataframes\n",
"import pickle # for saving the embeddings cache\n",
"import plotly.express as px # for plots\n",
"import random # for generating run IDs\n",
"from sklearn.model_selection import train_test_split # for splitting train & test data\n",
"import torch # for matrix optimization\n",
"\n",
"from utils.embeddings_utils import get_embedding, cosine_similarity # for embeddings\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "DtBbryAapgkL"
},
"source": [
"## 1. Inputs\n",
"\n",
"Most inputs are here. The key things to change are where to load your dataset from, where to save a cache of embeddings to, and which embedding engine you want to use.\n",
"\n",
"Depending on how your data is formatted, you'll want to rewrite the `process_input_data` function."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "UzxcWRCkpgkM"
},
"outputs": [],
"source": [
"# input parameters\n",
"embedding_cache_path = \"data/snli_embedding_cache.pkl\" # embeddings will be saved/loaded here\n",
"default_embedding_engine = \"babbage-similarity\" # text-embedding-ada-002 is recommended; the cache is keyed by (text, engine), so changing this means cached embeddings won't be reused\n",
"num_pairs_to_embed = 1000 # 1000 is arbitrary\n",
"local_dataset_path = \"data/snli_1.0_train_2k.csv\" # download from: https://nlp.stanford.edu/projects/snli/\n",
"\n",
"\n",
"def process_input_data(df: pd.DataFrame) -> pd.DataFrame:\n",
" # you can customize this to preprocess your own dataset\n",
" # output should be a dataframe with 3 columns: text_1, text_2, label (1 for similar, -1 for dissimilar)\n",
" df[\"label\"] = df[\"gold_label\"]\n",
" df = df[df[\"label\"].isin([\"entailment\"])].copy() # copy avoids pandas' SettingWithCopyWarning\n",
" df[\"label\"] = df[\"label\"].apply(lambda x: {\"entailment\": 1, \"contradiction\": -1}[x])\n",
" df = df.rename(columns={\"sentence1\": \"text_1\", \"sentence2\": \"text_2\"})\n",
" df = df[[\"text_1\", \"text_2\", \"label\"]]\n",
" df = df.head(num_pairs_to_embed)\n",
" return df\n"
]
},
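{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"If your data already contains both positives and negatives, a variant of `process_input_data` might keep both labels instead of filtering to entailment only. The sketch below is an assumption, not part of the original pipeline: `process_labeled_pairs` and its `max_pairs` parameter are hypothetical names, and it drops SNLI's neutral rows. Calling `.copy()` after filtering avoids the `SettingWithCopyWarning` that pandas emits when you assign to a filtered slice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical variant of process_input_data that keeps both positives and negatives.\n",
"# Assumes SNLI-style columns: sentence1, sentence2, gold_label.\n",
"import pandas as pd\n",
"\n",
"\n",
"def process_labeled_pairs(df: pd.DataFrame, max_pairs: int = 1000) -> pd.DataFrame:\n",
"    label_map = {\"entailment\": 1, \"contradiction\": -1}  # neutral rows are dropped\n",
"    df = df[df[\"gold_label\"].isin(list(label_map))].copy()  # copy avoids SettingWithCopyWarning\n",
"    df[\"label\"] = df[\"gold_label\"].map(label_map)\n",
"    df = df.rename(columns={\"sentence1\": \"text_1\", \"sentence2\": \"text_2\"})\n",
"    return df[[\"text_1\", \"text_2\", \"label\"]].head(max_pairs)\n",
"\n",
"\n",
"toy_snli = pd.DataFrame(\n",
"    {\n",
"        \"sentence1\": [\"A dog runs.\", \"A dog runs.\", \"A cat sleeps.\"],\n",
"        \"sentence2\": [\"An animal moves.\", \"The dog is asleep.\", \"The cat is black.\"],\n",
"        \"gold_label\": [\"entailment\", \"contradiction\", \"neutral\"],\n",
"    }\n",
")\n",
"process_labeled_pairs(toy_snli)\n"
]
},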
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "aBbH71hEpgkM"
},
"source": [
"## 2. Load and process input data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "kAKLjYG6pgkN",
"outputId": "dc178688-e97d-4ad0-b26c-dff67b858966"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/r4/x3kdvs816995fnnph2gdpwp40000gn/T/ipykernel_17509/1977422881.py:13: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" df[\"label\"] = df[\"label\"].apply(lambda x: {\"entailment\": 1, \"contradiction\": -1}[x])\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text_1</th>\n",
" <th>text_2</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>A person on a horse jumps over a broken down a...</td>\n",
" <td>A person is outdoors, on a horse.</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Children smiling and waving at camera</td>\n",
" <td>There are children present</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>A boy is jumping on skateboard in the middle o...</td>\n",
" <td>The boy does a skateboarding trick.</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>Two blond women are hugging one another.</td>\n",
" <td>There are women showing affection.</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>A few people in a restaurant setting, one of t...</td>\n",
" <td>The diners are at a restaurant.</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" text_1 \\\n",
"2 A person on a horse jumps over a broken down a... \n",
"4 Children smiling and waving at camera \n",
"7 A boy is jumping on skateboard in the middle o... \n",
"14 Two blond women are hugging one another. \n",
"17 A few people in a restaurant setting, one of t... \n",
"\n",
" text_2 label \n",
"2 A person is outdoors, on a horse. 1 \n",
"4 There are children present 1 \n",
"7 The boy does a skateboarding trick. 1 \n",
"14 There are women showing affection. 1 \n",
"17 The diners are at a restaurant. 1 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# load data\n",
"df = pd.read_csv(local_dataset_path)\n",
"\n",
"# process input data\n",
"df = process_input_data(df) # this demonstrates training data containing only positives\n",
"\n",
"# view data\n",
"df.head()\n"
]
},
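{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick schema check after processing can catch formatting mistakes before any embeddings are requested. This is a hypothetical helper (`validate_pairs` is not part of the original notebook) that enforces the format described in the comment inside `process_input_data`: columns `text_1`, `text_2`, and `label`, with labels of 1 or -1 and no missing text. In this notebook you could call it as `validate_pairs(df)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical schema check for the processed pairs dataframe.\n",
"import pandas as pd\n",
"\n",
"\n",
"def validate_pairs(pairs: pd.DataFrame) -> None:\n",
"    expected = [\"text_1\", \"text_2\", \"label\"]\n",
"    missing = [column for column in expected if column not in pairs.columns]\n",
"    if missing:\n",
"        raise ValueError(f\"missing columns: {missing}\")\n",
"    unexpected_labels = set(pairs[\"label\"].unique()) - {1, -1}\n",
"    if unexpected_labels:\n",
"        raise ValueError(f\"unexpected labels: {unexpected_labels}\")\n",
"    if pairs[[\"text_1\", \"text_2\"]].isna().any().any():\n",
"        raise ValueError(\"text columns contain missing values\")\n",
"\n",
"\n",
"validate_pairs(pd.DataFrame({\"text_1\": [\"a\"], \"text_2\": [\"b\"], \"label\": [1]}))  # passes silently\n",
"# in this notebook: validate_pairs(df)\n"
]
},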
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "z2F1cCoYpgkO"
},
"source": [
"## 3. Split data into training and test sets\n",
"\n",
"Note that it's important to split data into training and test sets *before* generating synthetic negatives or positives. You don't want any text strings in the training data to show up in the test data. If there's contamination, the test metrics will look better than they'll actually be in production."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "50QmnH2qpgkO",
"outputId": "6144029b-eb29-439e-9990-7aeb28168e56"
},
"outputs": [],
"source": [
"# split data into train and test sets\n",
"test_fraction = 0.5 # 0.5 is fairly arbitrary\n",
"random_seed = 123 # random seed is arbitrary, but is helpful in reproducibility\n",
"train_df, test_df = train_test_split(\n",
" df, test_size=test_fraction, stratify=df[\"label\"], random_state=random_seed\n",
")\n",
"train_df.loc[:, \"dataset\"] = \"train\"\n",
"test_df.loc[:, \"dataset\"] = \"test\"\n"
]
},
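{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the split above is over pairs rather than individual sentences, a sentence that appears in several pairs (SNLI can reuse one premise with multiple hypotheses) could still land in both splits. A minimal check, using a hypothetical helper `shared_texts`, is sketched below; we haven't verified whether this particular sample has any overlap. If the result is non-empty, one option is to split by text instead, for example with scikit-learn's `GroupShuffleSplit` grouped on `text_1`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical leakage check: texts that occur in both the train and test splits.\n",
"import pandas as pd\n",
"\n",
"\n",
"def shared_texts(train: pd.DataFrame, test: pd.DataFrame) -> set:\n",
"    train_texts = set(train[\"text_1\"]) | set(train[\"text_2\"])\n",
"    test_texts = set(test[\"text_1\"]) | set(test[\"text_2\"])\n",
"    return train_texts & test_texts\n",
"\n",
"\n",
"toy_train = pd.DataFrame(\n",
"    {\"text_1\": [\"A dog runs.\", \"Kids play.\"], \"text_2\": [\"An animal moves.\", \"Children are present.\"]}\n",
")\n",
"toy_test = pd.DataFrame({\"text_1\": [\"A dog runs.\"], \"text_2\": [\"A dog is outside.\"]})\n",
"print(shared_texts(toy_train, toy_test))  # {'A dog runs.'}\n",
"# in this notebook: shared_texts(train_df, test_df)\n"
]
},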
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "MzAFkA2opgkP"
},
"source": [
"## 4. Generate synthetic negatives\n",
"\n",
"This is another piece of the code that you will need to modify to match your use case.\n",
"\n",
"If you have data with positives and negatives, you can skip this section.\n",
"\n",
"If you have data with only positives, you can mostly keep it as is; it generates negatives only.\n",
"\n",
"If you have multiclass data, you will want to generate both positives and negatives. The positives can be pairs of text that share labels, and the negatives can be pairs of text that do not share labels.\n",
"\n",
"The final output should be a dataframe with text pairs, where each pair is labeled -1 or 1."
]
},
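{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"For multiclass or clustered data, the pairing rule described above can be sketched as follows. This is an illustrative assumption, not code from the original notebook: `pairs_from_labels` is a hypothetical helper that expects `text` and `label` columns. It enumerates every pair, which grows quadratically with the number of texts, so for large datasets you would sample pairs instead. As noted in section 3, run it separately on the train and test splits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical pair generator for labeled (multiclass or clustered) texts.\n",
"from itertools import combinations\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"def pairs_from_labels(texts_df: pd.DataFrame) -> pd.DataFrame:\n",
"    rows = []\n",
"    for (text_a, label_a), (text_b, label_b) in combinations(\n",
"        zip(texts_df[\"text\"], texts_df[\"label\"]), 2\n",
"    ):\n",
"        # same label -> positive pair (1); different labels -> negative pair (-1)\n",
"        rows.append({\"text_1\": text_a, \"text_2\": text_b, \"label\": 1 if label_a == label_b else -1})\n",
"    return pd.DataFrame(rows, columns=[\"text_1\", \"text_2\", \"label\"])\n",
"\n",
"\n",
"toy_texts = pd.DataFrame(\n",
"    {\"text\": [\"cheap flights\", \"airfare deals\", \"pasta recipe\"], \"label\": [\"travel\", \"travel\", \"food\"]}\n",
")\n",
"pairs_from_labels(toy_texts)\n"
]
},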
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "rUYd9V0zpgkP"
},
"outputs": [],
"source": [
"# generate negatives\n",
"def dataframe_of_negatives(dataframe_of_positives: pd.DataFrame) -> pd.DataFrame:\n",
" \"\"\"Return dataframe of negative pairs made by combining elements of positive pairs.\"\"\"\n",
" texts = set(dataframe_of_positives[\"text_1\"].values) | set(\n",
" dataframe_of_positives[\"text_2\"].values\n",
" )\n",
" all_pairs = {(t1, t2) for t1 in texts for t2 in texts if t1 < t2}\n",
" positive_pairs = set(\n",
" tuple(sorted(text_pair)) # sort so positives match the t1 < t2 ordering used in all_pairs\n",
" for text_pair in dataframe_of_positives[[\"text_1\", \"text_2\"]].values\n",
" )\n",
" negative_pairs = all_pairs - positive_pairs\n",
" df_of_negatives = pd.DataFrame(list(negative_pairs), columns=[\"text_1\", \"text_2\"])\n",
" df_of_negatives[\"label\"] = -1\n",
" return df_of_negatives\n"
]
},
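{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"`dataframe_of_negatives` builds every possible pair of texts before sampling, which is quadratic in the number of texts and can get slow or memory-hungry as your dataset grows. A hypothetical alternative, `sample_negatives`, draws random pairs until it has enough. It sorts each positive pair before comparing, so a positive stored as (b, a) is still excluded, and it raises `ValueError` if `n` exceeds the number of available negative pairs. Unlike the original, it returns exactly `n` negatives, so it would replace both the `dataframe_of_negatives` call and the `.sample(...)` step (for example, `n=len(train_df) * negatives_per_positive, seed=random_seed`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical alternative to dataframe_of_negatives: sample random pairs instead of enumerating all of them.\n",
"import random\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"def sample_negatives(positives: pd.DataFrame, n: int, seed: int = 123) -> pd.DataFrame:\n",
"    rng = random.Random(seed)\n",
"    texts = sorted(set(positives[\"text_1\"]) | set(positives[\"text_2\"]))\n",
"    # sort each pair so (a, b) and (b, a) count as the same unordered pair\n",
"    positive_pairs = {tuple(sorted(pair)) for pair in zip(positives[\"text_1\"], positives[\"text_2\"])}\n",
"    distinct_positives = {pair for pair in positive_pairs if pair[0] != pair[1]}\n",
"    available = len(texts) * (len(texts) - 1) // 2 - len(distinct_positives)\n",
"    if n > available:\n",
"        raise ValueError(f\"requested {n} negatives but only {available} exist\")\n",
"    negatives = set()\n",
"    while len(negatives) < n:\n",
"        pair = tuple(sorted(rng.sample(texts, 2)))  # two distinct texts, stored in sorted order\n",
"        if pair not in positive_pairs:\n",
"            negatives.add(pair)\n",
"    negatives_df = pd.DataFrame(sorted(negatives), columns=[\"text_1\", \"text_2\"])\n",
"    negatives_df[\"label\"] = -1\n",
"    return negatives_df\n",
"\n",
"\n",
"toy_positives = pd.DataFrame({\"text_1\": [\"b\", \"c\"], \"text_2\": [\"a\", \"d\"]})\n",
"sample_negatives(toy_positives, n=3)\n"
]
},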
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "Rkh8-J89pgkP"
},
"outputs": [],
"source": [
"negatives_per_positive = (\n",
" 1 # it will work at higher values too, but more data will be slower\n",
")\n",
"# generate negatives for training dataset\n",
"train_df_negatives = dataframe_of_negatives(train_df)\n",
"train_df_negatives[\"dataset\"] = \"train\"\n",
"# generate negatives for test dataset\n",
"test_df_negatives = dataframe_of_negatives(test_df)\n",
"test_df_negatives[\"dataset\"] = \"test\"\n",
"# sample negatives and combine with positives\n",
"train_df = pd.concat(\n",
" [\n",
" train_df,\n",
" train_df_negatives.sample(\n",
" n=len(train_df) * negatives_per_positive, random_state=random_seed\n",
" ),\n",
" ]\n",
")\n",
"test_df = pd.concat(\n",
" [\n",
" test_df,\n",
" test_df_negatives.sample(\n",
" n=len(test_df) * negatives_per_positive, random_state=random_seed\n",
" ),\n",
" ]\n",
")\n",
"\n",
"df = pd.concat([train_df, test_df])\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "8MVSLMSrpgkQ"
},
"source": [
"## 5. Calculate embeddings and cosine similarities\n",
"\n",
"Here, I create a cache to save the embeddings. This is handy because you don't have to pay for the same embeddings again if you rerun the code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "R6tWgS_ApgkQ"
},
"outputs": [],
"source": [
"# establish a cache of embeddings to avoid recomputing\n",
"# cache is a dict of tuples (text, engine) -> embedding\n",
"try:\n",
" with open(embedding_cache_path, \"rb\") as f:\n",
" embedding_cache = pickle.load(f)\n",
"except FileNotFoundError:\n",
" precomputed_embedding_cache_path = \"https://cdn.openai.com/API/examples/data/snli_embedding_cache.pkl\"\n",
" embedding_cache = pd.read_pickle(precomputed_embedding_cache_path)\n",
"\n",
"\n",
"# this function will get embeddings from the cache and save them there afterward\n",
"def get_embedding_with_cache(\n",
" text: str,\n",
" engine: str = default_embedding_engine,\n",
" embedding_cache: dict = embedding_cache,\n",
" embedding_cache_path: str = embedding_cache_path,\n",
") -> list:\n",
" if (text, engine) not in embedding_cache.keys():\n",
" # if not in cache, call API to get embedding\n",
" embedding_cache[(text, engine)] = get_embedding(text, engine)\n",
" # save embeddings cache to disk after each update\n",
" with open(embedding_cache_path, \"wb\") as embedding_cache_file:\n",
" pickle.dump(embedding_cache, embedding_cache_file)\n",
" return embedding_cache[(text, engine)]\n",
"\n",
"\n",
"# create column of embeddings\n",
"for column in [\"text_1\", \"text_2\"]:\n",
" df[f\"{column}_embedding\"] = df[column].apply(get_embedding_with_cache)\n",
"\n",
"# create column of cosine similarity between embeddings\n",
"df[\"cosine_similarity\"] = df.apply(\n",
" lambda row: cosine_similarity(row[\"text_1_embedding\"], row[\"text_2_embedding\"]),\n",
" axis=1,\n",
")\n"
]
},
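{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_embedding_with_cache` saves the whole cache to disk each time it fetches a new embedding, which is simple and robust to interruptions but costs one full pickle write per uncached text. A hypothetical batched variant is sketched below: it computes all missing embeddings first and then saves once. `embed_all_with_cache` and `fake_embed` are illustrative names; `fake_embed` stands in for `get_embedding` so the example makes no API calls. The trade-off is that an interruption mid-batch loses that batch's unsaved embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical batched cache: compute missing embeddings first, then write the cache file once.\n",
"import os\n",
"import pickle\n",
"import tempfile\n",
"from typing import Callable, Dict, List, Tuple\n",
"\n",
"\n",
"def embed_all_with_cache(\n",
"    texts: List[str],\n",
"    engine: str,\n",
"    cache: Dict[Tuple[str, str], List[float]],\n",
"    cache_path: str,\n",
"    embed_fn: Callable[[str, str], List[float]],\n",
") -> List[List[float]]:\n",
"    missing = {text for text in texts if (text, engine) not in cache}\n",
"    for text in missing:\n",
"        cache[(text, engine)] = embed_fn(text, engine)\n",
"    if missing:  # only write to disk when something new was computed\n",
"        with open(cache_path, \"wb\") as cache_file:\n",
"            pickle.dump(cache, cache_file)\n",
"    return [cache[(text, engine)] for text in texts]\n",
"\n",
"\n",
"fake_calls = []\n",
"\n",
"\n",
"def fake_embed(text: str, engine: str) -> List[float]:\n",
"    # stand-in for a real embedding call; records which texts were embedded\n",
"    fake_calls.append(text)\n",
"    return [float(len(text)), 1.0]\n",
"\n",
"\n",
"toy_cache_path = os.path.join(tempfile.mkdtemp(), \"toy_cache.pkl\")\n",
"toy_cache = {}\n",
"embed_all_with_cache([\"a\", \"bb\", \"a\"], \"toy-engine\", toy_cache, toy_cache_path, fake_embed)\n",
"embed_all_with_cache([\"a\", \"bb\"], \"toy-engine\", toy_cache, toy_cache_path, fake_embed)  # all cached\n",
"print(len(fake_calls))  # 2\n"
]
},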
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "4pwn608LpgkQ"
},
"source": [
"## 6. Plot distribution of cosine similarity\n",
"\n",
"Here we measure similarity of text using cosine similarity. In our experience, most distance functions (L1, L2, cosine similarity) work about the same. Note that our embeddings are already normalized to length 1, so cosine similarity is equivalent to dot product.\n",
"\n",
"The graphs show how much overlap there is between the distributions of cosine similarities for similar and dissimilar pairs. A high amount of overlap means that some dissimilar pairs have greater cosine similarity than some similar pairs.\n",
"\n",
"The accuracy I compute is the accuracy of a simple rule that predicts 'similar (1)' if the cosine similarity is above some threshold X and otherwise predicts 'dissimilar (-1)'."
]
},
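{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The threshold rule described above can be sketched as a brute-force search over candidate thresholds. This is a hypothetical helper, not necessarily how the notebook computes its accuracy: `best_threshold_accuracy` tries every observed similarity (plus negative infinity) as the threshold X and keeps the one with the highest accuracy. Picking X on the same data you report accuracy on is optimistic, so choosing it on the training split and evaluating on the test split gives a fairer estimate."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical threshold search: predict 1 if similarity > X, else -1, and keep the most accurate X.\n",
"import numpy as np\n",
"\n",
"\n",
"def best_threshold_accuracy(similarities, labels):\n",
"    sims = np.asarray(similarities, dtype=float)\n",
"    labs = np.asarray(labels)\n",
"    best_accuracy, best_threshold = -1.0, None\n",
"    # -inf makes the rule predict 1 for every pair; the largest similarity makes it predict -1 for every pair\n",
"    for threshold in np.concatenate(([-np.inf], np.unique(sims))):\n",
"        predictions = np.where(sims > threshold, 1, -1)\n",
"        accuracy = float(np.mean(predictions == labs))\n",
"        if accuracy > best_accuracy:\n",
"            best_accuracy, best_threshold = accuracy, float(threshold)\n",
"    return best_accuracy, best_threshold\n",
"\n",
"\n",
"print(best_threshold_accuracy([0.9, 0.85, 0.7, 0.65], [1, 1, -1, -1]))  # (1.0, 0.7)\n",
"# in this notebook: best_threshold_accuracy(df[\"cosine_similarity\"], df[\"label\"])\n"
]
},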
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "SoeDF8vqpgkQ",
"outputId": "17db817e-1702-4089-c4e8-8ca32d294930"
},
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=1<br>dataset=train<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "1",
"marker": {
"color": "#636efa",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "1",
"offsetgroup": "1",
"orientation": "v",
"showlegend": true,
"type": "histogram",
"x": [
0.9267355919345323,
0.8959824209230313,
0.9119725922265434,
0.854066984766886,
0.892887342475088,
0.9197115510183956,
0.8645429616364463,
0.8314164166177409,
0.7243313891695883,
0.8819971504547897,
0.7956215051893748,
0.7959481856578454,
0.8682525487547492,
0.8973704561532523,
0.8648042605746787,
0.9236698981903941,
0.9834804742323746,
0.8152447410515193,
0.8251720025778885,
0.8138195587168373,
0.8041885875094662,
0.9329690881953521,
0.9560346908585096,
0.9727875559800526,
0.8739787488907723,
0.8208200937845748,
0.7246913159387628,
0.9324916305400826,
0.8285737172791849,
0.8797008558439083,
0.820333215489803,
0.9370111008629193,
0.8983827475968527,
0.8312111255338568,
0.8164052516323846,
0.8908148647559723,
0.7466264012908705,
0.749651964301876,
0.87375582683117,
0.7849398161817545,
0.8309506403579568,
0.9307212180139316,
0.8281747408531835,
0.9529528469109777,
0.7828662080109449,
0.887100957550481,
0.9000278356226864,
0.8805448739521785,
0.930239131235984,
0.8801954897699975,
0.8529206900756547,
0.9467797362617967,
0.950367690540818,
0.7030531843080301,
0.8643992401506337,
0.8536886651933216,
0.9619331104636477,
0.9798279220663803,
0.8545739729953584,
0.8957115039259408,
0.8241137169318955,
0.8234984837204449,
0.8936706246811023,
0.8987178417246647,
0.9081806514389928,
0.9208852073112466,
0.8961858080980447,
0.8831329491347061,
0.9282623092166472,
0.8990849228018752,
0.8284548395672304,
0.8202091311790463,
0.8647762700375631,
0.840136958242196,
0.9887387562448309,
0.833342655983456,
0.828533112298018,
0.9117757392019592,
0.8706628948983062,
0.9279786042901599,
0.7389559895894061,
0.8433932035736679,
0.9240307526722153,
0.950769936374633,
0.8586024431762636,
0.8685107120156628,
0.875535036209879,
0.9894909159640456,
0.8279650063634755,
0.9108703739124493,
0.9090161898700972,
0.8603952576151226,
0.7791958170479588,
0.8800175078297319,
0.8442387843179446,
0.7672266106337577,
0.9379753906877134,
0.8637536227406404,
0.9190295684769377,
0.813748744480001,
0.9134886375270357,
0.8043760086358624,
0.8792300492020616,
0.8716796296133559,
0.8669146731224541,
0.7736224659067634,
0.9437235456545872,
0.905686329875817,
0.9534823418409002,
0.9150626356433706,
0.9409873570135594,
0.8111212499450311,
0.9171209901271343,
0.9126582215550586,
0.8337042981493767,
0.7317859043960847,
0.8444929460069593,
0.8561920422842992,
0.7765276739806221,
0.8526116779814181,
0.9178549171995068,
0.9238337665547537,
0.7218806795387105,
0.8180162420894919,
0.9687438998025378,
0.8354559162714565,
0.9146160667909478,
0.8082103463995212,
0.9563444955040953,
0.9066029892990775,
0.8485102485675593,
0.8154210965438722,
0.8860338155558548,
0.9280705420457523,
0.9835182434488767,
0.9653797793986199,
0.7815047672209323,
0.7156150747063292,
0.9256075945804699,
0.8135073898727201,
0.965501517353668,
0.8222681603043882,
0.907212186549752,
0.8611990749004483,
0.9075083701591861,
0.9452697087352507,
0.8792221638449876,
0.9261547230875113,
0.8628843081873457,
0.7825774684929468,
0.826587869384389,
0.7399697967202743,
0.7855475169419611,
0.9111614040328055,
0.8871057104665023,
0.8824403169032524,
0.8618250237448342,
0.9787037901590712,
0.8066230600465234,
0.82769109211187,
0.9246432066709112,
0.8840853142719706,
0.786484352300887,
0.9106863314452291,
0.9342800776988389,
0.8573335079738363,
0.7780871062391552,
0.7913314665963198,
0.8574377397709955,
0.9078366114351649,
0.7529739278117176,
0.8630106564789936,
0.9051526759332171,
0.7715467442267975,
0.894146587858075,
0.8095881901541373,
0.7733578316514722,
0.7600408374028372,
0.7819972017869313,
0.9003461714649055,
0.7428048011790054,
0.8645936498599726,
0.8158769881622202,
0.8338827592994027,
0.8272653967731632,
0.9017517376316297,
0.8480852025321463,
0.7970818331146111,
0.84837067058732,
0.9272909953469497,
0.9511439765702141,
0.8796630924442385,
0.8297595339070903,
0.8132311695727563,
0.846096509352579,
0.8787645385610889,
0.8591367322994465,
0.8452813265723063,
0.708120854745005,
0.8769677220320785,
0.957621649201,
0.7463356296247886,
0.8618039385828121,
0.9560112462834625,
0.8478374736767776,
0.769289015775232,
0.8456866935685449,
0.9014601942765245,
0.8816990623836058,
0.8836365012367311,
0.8078009762348856,
0.898471670333294,
0.9064470715463968,
0.8762712610709371,
0.9178852315161201,
0.7896235958446132,
0.8939345739482307,
0.9534018415944343,
0.8358882135213556,
0.9488657107840998,
0.9046799883600772,
0.7583576517359092,
0.9080459939022663,
0.7709722685822517,
0.9635512477648485,
0.9792712672362888,
0.8526700765974843,
0.8278133097990042,
0.9735858611728696,
0.721230194834802,
0.8257425311085005,
0.9243205490037274,
0.9183796453898277,
0.9029146937248601,
0.9410246048181872,
0.9609604036664618,
0.7467407974920351,
0.8831901217813377,
0.8173287200784538,
0.8067347032404598,
0.7921957436175584,
0.9110994793014074,
0.8678737504217436,
0.9117743220816726,
0.781256498099708,
0.8553931170417658,
0.8798565764815073,
0.8485358179238083,
0.7748765495278522,
0.9432062986428599,
0.8328320716129907,
0.7983629771463491,
0.9345589964322458,
0.780034700168418,
0.9894680324844076,
0.8239308905190431,
0.8236003498823307,
0.8346101074957768,
0.8273793488443836,
0.7872103673165679,
0.9502897884724039,
0.83306630344504,
0.9346568240975981,
0.8082083566882774,
0.8920672687911747,
0.8566523137465843,
0.7636170858064102,
0.8271498146927989,
0.8450776675259235,
0.9045266249905214,
0.8578964053464326,
0.8673866119197355,
0.8804224181917263,
0.8199459545727489,
0.932410033971128,
0.9096821262750205,
0.8658255622247675,
0.9386720385226357,
0.8517830114204678,
0.889433736353776,
0.978847593529176,
0.8369738178686744,
0.8438616304896431,
0.9457050132083015,
0.8699723446274389,
0.7795221418941651,
0.9136284834408402,
0.839461038218975,
0.9453279809479284,
0.7899532071809374,
0.9078373589637019,
0.8434980548175782,
0.8112068688921613,
0.9466506411385276,
0.931413666561485,
0.7932453730948084,
0.8205411414836395,
0.9243834389669592,
0.719616209472987,
0.7552985082978257,
0.9593440978701407,
0.917557937781671,
0.8643861907339583,
0.8315201130646898,
0.7608819746989568,
0.9704324558544556,
0.8037085307118571,
0.7785270629127378,
0.8044961889638995,
0.8313307525316546,
0.8064106357955508,
0.9291149169587132,
0.8412940941318643,
0.6917091086045903,
0.895204432897799,
0.8182250729145697,
0.8645847241904496,
0.8532020284403578,
0.8143634597102418,
0.8825747762544934,
0.7764652529619696,
0.850099367461302,
0.8616919096196407,
0.9257293995599788,
0.935772204341753,
0.7742657205171264,
0.7898710076766889,
0.8590438495382458,
0.9317809680608741,
0.9087109944408146,
0.949297999227784,
0.8813316524294974,
0.7372081408133433,
0.8838176418404169
],
"xaxis": "x2",
"yaxis": "y2"
},
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=1<br>dataset=test<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "1",
"marker": {
"color": "#636efa",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "1",
"offsetgroup": "1",
"orientation": "v",
"showlegend": false,
"type": "histogram",
"x": [
0.9424796847531719,
0.9078956620721231,
0.8334324872928349,
0.9352180099421901,
0.9055462980371385,
0.8981939710469806,
0.8310153256803275,
0.8504676050313356,
0.8456281884302819,
0.8845204616348977,
0.9575409743129409,
0.8867362111499236,
0.8268148049156373,
0.9197424487445726,
0.7868932880014918,
0.7584994087629051,
0.9184151106611779,
0.8634069821364065,
0.8347803687511831,
0.8293627315810226,
0.9290633380383377,
0.8385821691321573,
0.9389267221382397,
0.890818442649049,
0.86634760562554,
0.8406483291061461,
0.8084243397427517,
0.8909500804242533,
0.9262896019507018,
0.8955541231110083,
0.8055268512037249,
0.758607678624549,
0.9609058491722305,
0.91495905751538,
0.8670137150812314,
0.8813831602682358,
0.8602255150075953,
0.9239960998195261,
0.91732217769794,
0.8037285395908439,
0.9196344979660896,
0.8179495018382927,
0.9015423003762686,
0.9054394610623084,
0.9309412941515826,
0.9421722892808463,
0.7632823176731237,
0.8622055676377434,
0.9855273108710355,
0.9144155566597573,
0.916057392559463,
0.8027504546689012,
0.7131090055200247,
0.86174194818737,
0.982873171115144,
0.8100227539053508,
0.8923878603823862,
0.8096643432711604,
0.8707613707684018,
0.8786740148818871,
0.8274639911687568,
0.8927098767487393,
0.9565597075186062,
0.9060728095743347,
0.738307517030644,
0.9645943657917937,
0.8755564011033787,
0.879644342330288,
0.8679709669180851,
0.9304235148160597,
0.8902804960376421,
0.8748369557157689,
0.9999999999999999,
0.7979398167195422,
0.8182553472910762,
0.7782108671759803,
0.8427610550014142,
0.8696408842646327,
0.8747903026537825,
0.914973368517656,
0.9651568968718233,
0.9775547988384379,
0.8964005885715726,
0.8689760349348122,
0.8501707274497268,
0.9069421089316828,
0.7682621587597694,
0.9658683147667629,
0.8946443487011166,
0.7855154267442119,
0.8963791544804274,
0.8062904921192978,
0.8165205978948239,
0.8392522254625653,
0.9456080863923884,
0.7904904126133966,
0.833126792097794,
0.7852156051245299,
0.7859162398886373,
0.9097674992318959,
0.8868692153350769,
0.9391826646753169,
0.9428151202937937,
0.7923603877885725,
0.9018727189911658,
0.9723161942344292,
0.7820369113228325,
0.9667234201162873,
0.978769627389609,
0.9155729781931277,
0.8273013970664075,
0.9603319621485501,
0.9298975009081959,
0.8775117467919693,
0.8614509568162568,
0.9144155657624043,
0.7783710792186642,
0.9701880190033496,
0.7858944693777298,
0.9278353490304795,
0.9472367444800102,
0.7834809788888359,
0.7997358978342995,
0.8459052935830644,
0.8612076995259889,
0.8470901722260982,
0.824037271004761,
0.865608650068826,
0.8023193246081538,
0.7836788857151791,
0.880404135119559,
0.8491559249509729,
0.788345270395367,
0.9461393747813323,
0.835123385080455,
0.8158174048388718,
0.8604581300972295,
0.9623616555004862,
0.8564688397784819,
0.8576867681204893,
0.8973905356978807,
0.8634447095761868,
0.8149528594606367,
0.873171253801101,
0.8653347676818378,
0.929525558735665,
0.8358267202818717,
0.971888682386554,
0.8500189240129448,
0.6201715858790215,
0.8982737437906866,
0.8919523978597029,
0.7327218620224804,
0.8329671225228632,
0.9265589851585685,
0.8976605724257473,
0.8865148832512937,
0.7893917264917192,
0.7303107673595587,
0.8428958487387793,
0.8712646524042454,
0.9726111208329614,
0.9368020234351375,
0.9270010843622818,
0.8900608737128451,
0.7975173141451626,
0.9403308743666237,
0.8484005148509558,
0.9285585494125882,
0.8461714640960527,
0.9301612552439464,
0.9840391344904358,
0.8305503032909712,
0.8985536902480058,
0.9476344055766088,
0.9342892661789283,
0.8849523247638441,
0.7736620851030366,
0.8083290901088768,
0.9510007696957726,
0.8677438102591386,
0.8324233959261911,
0.7379868650067231,
0.9049462205283083,
0.9044068964151009,
0.7810399099399672,
0.9040280419401343,
0.7720832557964016,
0.7168259249496903,
0.8657076231674912,
0.9689982290529879,
0.9330371348632155,
0.7014093149115691,
0.9056081832768474,
0.8483474394912361,
0.8729108893663646,
0.8494252835407102,
0.8702668029663239,
0.8703072652243842,
0.9279473628052111,
0.8615930026010632,
0.7590822846968434,
0.8435232136599701,
0.8264379743713927,
0.8793126202650794,
0.847452301256391,
0.7546334370590053,
0.8870818568791698,
0.8349553719854912,
0.9232007589383142,
0.7924421895458662,
0.8556103146657943,
0.8397958720336148,
0.9358165878534705,
0.904577353436614,
0.9022537114781464,
0.775603917181226,
0.946091618927627,
0.8264119461162666,
0.8261258105029201,
0.8605336598142989,
0.7518422489499549,
0.849587557879856,
0.9922578397415717,
0.7499254104864422,
0.8845204616348977,
0.8361936554693772,
0.9172228808488442,
0.8068135566092208,
0.795739929008372,
0.8632611464779958,
0.7612462576141025,
0.9589125421073597,
0.9555759038945358,
0.8822980105025566,
0.9663740139243834,
0.9071760951442449,
0.9335338894977118,
0.8042262160785201,
0.9399607295068667,
0.8318513711395181,
0.8697471272873054,
0.9103391819002411,
0.8272582065159091,
0.7868989539137853,
0.7416168920325092,
0.8828593510646834,
0.9141342991325323,
0.7259887492833588,
0.9478299721997572,
0.843766518385904,
0.919830425538435,
0.9069062941852448,
0.9036466185261808,
0.9817542893707696,
0.8833292621745382,
0.832556614533968,
0.8135910443631924,
0.9628932969024508,
0.9450804655496136,
0.9226384091529121,
0.8401818103049188,
0.723691406760321,
0.6828741135249211,
0.834410523069395,
0.9959256404542386,
0.952870396433041,
0.969514692554192,
0.9220387806044666,
0.9511950116111251,
0.87442203104197,
0.8399026046246612,
0.9029483760650204,
0.9097073428917352,
0.8651925582004045,
0.9178332691819033,
0.7556713752294848,
0.8601740894614596,
0.8250804250840363,
0.799473306929639,
0.8911389639861809,
0.915913776235107,
0.7867422041389165,
0.8035116695233039,
0.7702882636946234,
0.9060460430333088,
0.7214029229730072,
0.8607904806397634,
0.8228468643082103,
0.8900020169140401,
0.9343567736626528,
0.9305049279291139,
0.9664193138643489,
0.9008537853184969,
0.7625840742620333,
0.815302054727336,
0.9215061720798934,
0.7192673801865671,
0.8949994067748966,
0.9367566547265034,
0.7602684166275758,
0.8184439767612992,
0.8361983856596491,
0.7761725471827079,
0.7724780968772909,
0.9249211346782868,
0.8718843131924867,
0.8522890335712519,
0.9015475867709434,
0.8720699810318118,
0.8937599387455695,
0.8721713573852221,
0.8100783166142076,
1.0000000000000002,
0.8213222537973748,
0.8361185401136565,
0.8371907459006128,
0.9065697385076582,
0.752240671472798,
0.8283078905766531,
0.8499886819287953,
0.9097932369637356,
0.9529813104528191,
0.8449289750214674,
1,
0.8302949362084788,
0.7741532046500113,
0.8743828037305432,
0.8201855611163867,
0.8194689758101458,
0.7925076796225758,
0.8748126117575765,
0.8299510305557958,
0.9619426561868236,
0.8627070029199212
],
"xaxis": "x",
"yaxis": "y"
},
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=-1<br>dataset=train<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "-1",
"marker": {
"color": "#EF553B",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "-1",
"offsetgroup": "-1",
"orientation": "v",
"showlegend": true,
"type": "histogram",
"x": [
0.7299945446332757,
0.761829365033925,
0.676235270353631,
0.7023016593603112,
0.7350156032869306,
0.7735236691667362,
0.7187241641280292,
0.8063818744486255,
0.7115273138749911,
0.7402039545462252,
0.7372226375565185,
0.719606336264161,
0.8344052997670786,
0.6881482918877712,
0.650513335277676,
0.7182967436895931,
0.7659970793010594,
0.6361476177437694,
0.7957002158983555,
0.7395402073375513,
0.7614384854511241,
0.6835300790002988,
0.6209194558826169,
0.7907726220786139,
0.7215750502680501,
0.7271947538454276,
0.6962333614625641,
0.7517476038206041,
0.7135529871009506,
0.7522919529124952,
0.7120639628754573,
0.7623014780353815,
0.7939574492876662,
0.6998873138766412,
0.7700594720774577,
0.7766161530342343,
0.7285080945118152,
0.7562511166739386,
0.8086622737220109,
0.7565297004511734,
0.7315242427462245,
0.791225232078265,
0.7281092467134649,
0.7675685917886341,
0.7122436997016454,
0.7600255476588682,
0.769465992253443,
0.7552507233110458,
0.7373719614604455,
0.7681449827665134,
0.7760194768221379,
0.8035677492769473,
0.7771752906080505,
0.6514739683312004,
0.744787832649546,
0.7700232252518491,
0.6901772464238046,
0.721128796446804,
0.686814078974691,
0.796786900996621,
0.8242176320872988,
0.7806742901384496,
0.7697696656361902,
0.7868668497422822,
0.7304548331410279,
0.7296767182508448,
0.699401040331297,
0.8332660282838579,
0.6513224421864793,
0.6927364371405096,
0.7491793002300279,
0.7909034218879171,
0.7754176150761527,
0.8004827006684735,
0.6659923526948868,
0.8129618884980553,
0.7496476113488948,
0.6519665061955423,
0.7319506042291191,
0.7367099004846358,
0.8590817275589286,
0.6684490623528697,
0.7469140136676841,
0.7269016830127544,
0.6834261236240542,
0.6390380393123325,
0.7129209978685723,
0.6879274953497815,
0.6720427402498903,
0.7343923261152341,
0.5977941468082364,
0.7574521074337901,
0.7302129263006347,
0.7380209108764002,
0.7280177437983031,
0.6870880300968067,
0.6928024686173956,
0.6403900073451696,
0.7209247906698092,
0.666799070142464,
0.7233569397428488,
0.7555267048881131,
0.7275931437975757,
0.7562785127332186,
0.736649021802634,
0.762569980781232,
0.7741923768467133,
0.6669773662198453,
0.7499779113599104,
0.7783371410262362,
0.7798574909618832,
0.7068277719647446,
0.7718066533969561,
0.7078874155127759,
0.7238814624322412,
0.8729026036201937,
0.7106759057210158,
0.7585767440008555,
0.6923565683937588,
0.6693898561996529,
0.7219944003835542,
0.7188064794656569,
0.7491451951131076,
0.6750197249394477,
0.7009868838756784,
0.6519589230722103,
0.775109860666869,
0.6682014813660948,
0.6618358724923147,
0.6218362070243146,
0.7518134154555007,
0.7571307427362168,
0.7546823360234822,
0.7416435744760188,
0.7676464608286054,
0.5799855321635428,
0.796191792361925,
0.6845373357552841,
0.7667984177770908,
0.7393609881148844,
0.7544580057165962,
0.7397495323890664,
0.7658298386950907,
0.6611125314078552,
0.7977728876184589,
0.831090991824215,
0.6982347705041605,
0.6312226759947304,
0.6482907496231076,
0.7658362595299695,
0.7518400125995974,
0.8025480920087656,
0.7461501494015976,
0.7239149641844659,
0.7090055516721123,
0.6846622430587006,
0.7112212118684095,
0.6794460508450644,
0.8344462829585355,
0.7638782230495382,
0.6558255308015893,
0.6799836520005059,
0.7148088830660722,
0.7658007439571484,
0.6581857665503666,
0.648734015773256,
0.8156725527309459,
0.6929202590284855,
0.7490919589112273,
0.7090723359022927,
0.7105572886194162,
0.7461374422467866,
0.7084742440808511,
0.6889378818898818,
0.7551565238112727,
0.7198789279593328,
0.7270482774940944,
0.6971190427893721,
0.7391610904601401,
0.6344734499604212,
0.6719507302378157,
0.6861159059531602,
0.6516118314317232,
0.7199095105818035,
0.6817881072485698,
0.7207373940156253,
0.7745467156569463,
0.6258059783246434,
0.7481056513643776,
0.7183327989261993,
0.7624705969064652,
0.703717229048042,
0.7487365873620485,
0.7495555007485976,
0.6409624588579408,
0.7176245775200687,
0.7537100520717849,
0.7569868075002099,
0.7392135510982174,
0.7188471960732984,
0.697413550280765,
0.7138614496887067,
0.7057641350396069,
0.7675079665142936,
0.7310427541091186,
0.7808018818735418,
0.7567255895826445,
0.7035962262766099,
0.7040750813384501,
0.8183159919038065,
0.7953911933398697,
0.7464891038038547,
0.6751591827598264,
0.7849943676377955,
0.7155963442284841,
0.7428993249606122,
0.7131100645054201,
0.7227595311429733,
0.6519548345531954,
0.7201522118536183,
0.86540799716664,
0.8128819241371503,
0.7278912446692862,
0.7305867175950502,
0.7171875192153516,
0.6755179003509543,
0.7256221359402913,
0.7003814947129137,
0.7486334697158199,
0.7232489166529666,
0.7347697330652992,
0.6493837702986368,
0.6454310256268904,
0.7085305859966166,
0.7709963397003181,
0.7628122486461532,
0.7260869667056613,
0.7656074896314918,
0.7309944223135776,
0.7575162043117213,
0.7425954755181217,
0.7978452334414571,
0.7414129597626153,
0.7369987033441427,
0.7249664966482501,
0.663939118162477,
0.7490232329485363,
0.7532303509685747,
0.6505502824396713,
0.6820403873171862,
0.7458589415356044,
0.65106761846338,
0.8190794585362886,
0.6404595320063431,
0.7620011212588133,
0.6793344580417779,
0.7470455239529016,
0.6254743025101126,
0.714021296346165,
0.7624376857959331,
0.6124088443110871,
0.7190909082546953,
0.7667977482791689,
0.7390282391537781,
0.731411776312689,
0.6910654011387792,
0.7669731150300952,
0.7473599833172766,
0.7757742306046117,
0.7524096654164605,
0.668078105815945,
0.6797298899209671,
0.734572374544983,
0.7851097143298676,
0.7342031538272321,
0.7372538892798539,
0.7335209046034747,
0.6838366965805461,
0.6892908218001774,
0.7368799079039347,
0.667817778038092,
0.7436021083467266,
0.643687287978847,
0.6459012534104801,
0.717961524900317,
0.651811280493696,
0.774663511332172,
0.6481450536649574,
0.7135017154296592,
0.718243043672562,
0.6840974559559992,
0.6237039667424505,
0.7511301324631847,
0.6731554748777204,
0.7311433676647258,
0.7407740108803432,
0.7219713524118947,
0.6945284597258313,
0.7403964086997485,
0.7281416758951161,
0.767616849591671,
0.7623013461443361,
0.794278871148869,
0.7094595086377886,
0.6363511838820268,
0.6955392554405656,
0.7448982185809608,
0.7328610107493276,
0.7208624134565648,
0.6762963031433926,
0.7755406771546928,
0.7045876515082797,
0.6244682388287953,
0.6742169930443296,
0.8182351707309709,
0.7329206562104693,
0.7750198586773644,
0.7686712995024062,
0.7382706738683358,
0.6670365140953741,
0.7122098843228497,
0.720363344417827,
0.7260325476617868,
0.7455849615803621,
0.7135971488970759,
0.7597698332957011,
0.7261113201174606,
0.7802411718313292,
0.6937507195359754,
0.7842882648198792,
0.6900501446041246,
0.760860218311663,
0.7134088358525988,
0.7053629634417926
],
"xaxis": "x2",
"yaxis": "y2"
},
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=-1<br>dataset=test<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "-1",
"marker": {
"color": "#EF553B",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "-1",
"offsetgroup": "-1",
"orientation": "v",
"showlegend": false,
"type": "histogram",
"x": [
0.698216609902851,
0.7175312484264973,
0.7186004945024,
0.8000305631283213,
0.6982596730429885,
0.7632305498724974,
0.7138465594512551,
0.6788567152998355,
0.7373203069430523,
0.6873036529619025,
0.6503465541274354,
0.7365009792588864,
0.741093678604772,
0.735107851927508,
0.7789882788330387,
0.6997359150722298,
0.7996799028712657,
0.7467428244049107,
0.6687862526227949,
0.743064589979991,
0.8567601871608912,
0.7518934598338161,
0.7026922402796603,
0.7211365080382613,
0.706354045904417,
0.779427003097291,
0.7362962393365906,
0.7291751132805836,
0.7122378769301289,
0.7140477948952889,
0.7173405960056081,
0.7634875929311733,
0.7977581390590794,
0.7301463738402032,
0.7615332057968679,
0.682227985030492,
0.7334635989895589,
0.7386028577038499,
0.659084291835728,
0.7899820114980755,
0.7247172517704278,
0.7438155210487466,
0.7005346860296618,
0.6648553956533266,
0.7701566644966253,
0.7514961574415904,
0.7587991656983686,
0.7001882273133521,
0.6910707516646375,
0.7394355361240693,
0.7276824899179835,
0.6759744362016779,
0.8302185470787163,
0.6928641094502374,
0.7120538723839331,
0.7224960785372221,
0.7198045816069757,
0.6813558762031737,
0.7165801628045559,
0.7120832723046919,
0.8420619371167495,
0.946387336435492,
0.7554566849916029,
0.6539169025880401,
0.7809846972343978,
0.7724403150471626,
0.7005276086814838,
0.6393757830582598,
0.8206678516495857,
0.7220887623700465,
0.6457268309147112,
0.7355556783165726,
0.7154485704709247,
0.736485552747626,
0.6279868336962336,
0.6826307018159569,
0.6893086023402188,
0.6662701662224271,
0.7867533724923507,
0.7767747518359883,
0.8265509786609969,
0.7191298707058883,
0.7022356632617615,
0.7327905404619434,
0.744068997548466,
0.7610196098150734,
0.7115456073990181,
0.7432332956176703,
0.7893785433034154,
0.7290401089931655,
0.6577253300110991,
0.7003577024217508,
0.6656632326716906,
0.777774068734494,
0.7332058736487923,
0.6832255453323396,
0.7959026283113317,
0.7447573588360193,
0.6641622569109379,
0.6536973982676278,
0.6665423422110455,
0.6839542281076658,
0.7423443203765382,
0.7548386221527327,
0.6027110439271006,
0.6910860362919979,
0.6562817652314046,
0.716178983425107,
0.724403115068019,
0.7259909803954812,
0.7571530348849562,
0.733442357026098,
0.7422263713152575,
0.7194490356714829,
0.8324274302968815,
0.7326854438693811,
0.6965534375434462,
0.6112368704742529,
0.7304384905441154,
0.7720040972080942,
0.6935032899149339,
0.6744196671055273,
0.7147585872166342,
0.7752945283086001,
0.7247252468203336,
0.7487128692844189,
0.7073695766484023,
0.7770002807438351,
0.6765039364068596,
0.6337555851438214,
0.7894145395726685,
0.808127657824718,
0.6806212157836168,
0.6631556694778886,
0.7026864982467919,
0.6616453779258454,
0.7884566439105933,
0.7066759939380831,
0.6337579382878964,
0.6487574469000645,
0.772145520185077,
0.7232462166936153,
0.7901145921506159,
0.6972515066616096,
0.7122390390469722,
0.7192738985202007,
0.6592857746430886,
0.7609468966795828,
0.6603956483288737,
0.7457664593808938,
0.6826889303879785,
0.7934081087742163,
0.6763383894378445,
0.6839731303810913,
0.6687692061586001,
0.6991702352595418,
0.7797875942740597,
0.7663428993137384,
0.7674609664886143,
0.7252797254797089,
0.6976652869077468,
0.7158816756428659,
0.718241456712601,
0.7632800029357677,
0.7220393239650442,
0.7734594213360495,
0.6831354341128338,
0.8050740585092945,
0.8304016729411579,
0.6942297180429666,
0.7777210539532041,
0.7971492766742209,
0.7551377672686813,
0.756111346085797,
0.7377139073001521,
0.7292914327831885,
0.6619852966467218,
0.6633387313486273,
0.7753943622795527,
0.6931366760284077,
0.6879938635549627,
0.7934277451864097,
0.7095961561996172,
0.6721647904665006,
0.6639348167332335,
0.7190874751718003,
0.6487682662731885,
0.7712237228183192,
0.7195541308325332,
0.7624245070018526,
0.7066568592895756,
0.6955819267956074,
0.762644689565747,
0.696550099953965,
0.72160309057877,
0.6589853583691466,
0.7781076723886485,
0.7844353288099747,
0.6499942545071061,
0.7586643115109818,
0.7851245713510661,
0.6825431110733673,
0.7920473550917971,
0.7505505200683399,
0.7112992195413225,
0.6872424297540068,
0.6629403755138552,
0.7754417757832819,
0.7445419843378304,
0.7064660255551196,
0.7102764764345906,
0.7166584523700176,
0.745706231193076,
0.7628022956052315,
0.6960882904963708,
0.6837468719098787,
0.7520487263229376,
0.7604129787986132,
0.7522054011542562,
0.6898973175025894,
0.712053903439596,
0.7339122116041967,
0.6789458999380491,
0.763449576758855,
0.7406926320689465,
0.6976029213628422,
0.7734670110437211,
0.7670042286239444,
0.7194794729796166,
0.8039337600845294,
0.822887238017454,
0.7546355748553022,
0.712175872919759,
0.6283408438534046,
0.7277722488820856,
0.7848289730316366,
0.6470175548703326,
0.7175970646556328,
0.6982508276209234,
0.6931585425755658,
0.7105706262503181,
0.6541269887120206,
0.8404400660965669,
0.7278163563567496,
0.8022018950050529,
0.7449136144274442,
0.7254549036563473,
0.7708530061010334,
0.7226568414320923,
0.7174427406313559,
0.7242867487973796,
0.7465548093208114,
0.6481248100208876,
0.715420475795481,
0.7818995378198694,
0.710140386310007,
0.779070581516581,
0.7022765286826387,
0.7920014977874291,
0.7028249663680173,
0.7464011128583546,
0.6874796697586977,
0.7834259382246206,
0.7487992683674399,
0.6256280566008573,
0.7248416704075088,
0.6787298488859722,
0.7604099689852636,
0.7563309422531382,
0.6536692254847523,
0.7277402265583088,
0.7961681595257355,
0.7183023940359037,
0.8578324194628058,
0.6890352710148969,
0.711616174722821,
0.6560825239577808,
0.7235723022023411,
0.6649236563739442,
0.6800589793852263,
0.7785177771732936,
0.8277144895457349,
0.7047472917613488,
0.6981993581133783,
0.69214194925211,
0.7175225477364914,
0.6821700037384272,
0.6934882153010823,
0.6671459724713757,
0.7577056235720239,
0.7452347481085622,
0.7540647847248615,
0.7623045528688165,
0.8419961516314336,
0.7607631404114222,
0.7104592749279437,
0.7907219223857475,
0.6626370861321504,
0.7293323972767943,
0.6747790790615374,
0.7205393564241827,
0.7182141925188444,
0.6698620455142402,
0.7774615433926413,
0.68441903533162,
0.7195176784475413,
0.7765542578440102,
0.7653003235671018,
0.6588957811563716,
0.7049538466814345,
0.6767019827253833,
0.6852115350974048,
0.7159808946533304,
0.6275008181698795,
0.6641598464064111,
0.7653064009797307,
0.7846062245731015,
0.7131195190890488,
0.7388407888274415,
0.7078575690975646,
0.7922969773693673,
0.6399205949452071,
0.7522331808600956,
0.756127025845852,
0.7527950868321375,
0.7791392496558245,
0.7388745760013306,
0.6739605493576779,
0.6432673241279167,
0.7124181751534668,
0.669456709871883,
0.7067049471522552,
0.6685698115209102,
0.7430777123010961,
0.7510627360284545
],
"xaxis": "x",
"yaxis": "y"
}
],
"layout": {
"annotations": [
{
"font": {},
"showarrow": false,
"text": "dataset=test",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.2425,
"yanchor": "middle",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "dataset=train",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.7575000000000001,
"yanchor": "middle",
"yref": "paper"
}
],
"barmode": "overlay",
"legend": {
"title": {
"text": "label"
},
"tracegroupgap": 0
},
"margin": {
"t": 60
},
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"width": 500,
"xaxis": {
"anchor": "y",
"domain": [
0,
0.98
],
"title": {
"text": "cosine_similarity"
}
},
"xaxis2": {
"anchor": "y2",
"domain": [
0,
0.98
],
"matches": "x",
"showticklabels": false
},
"yaxis": {
"anchor": "x",
"domain": [
0,
0.485
],
"title": {
"text": "count"
}
},
"yaxis2": {
"anchor": "x2",
"domain": [
0.515,
1
],
"matches": "y",
"title": {
"text": "count"
}
}
}
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train accuracy: 89.1% ± 2.4%\n",
"test accuracy: 88.8% ± 2.4%\n"
]
}
],
"source": [
"# calculate accuracy (and its standard error) of predicting label=1 if similarity>x\n",
"# x is optimized by sweeping from -1 to 1 in steps of 0.001\n",
"def accuracy_and_se(cosine_similarity: List[float], labeled_similarity: List[int]) -> Tuple[float, float]:\n",
" accuracies = []\n",
" for threshold_thousandths in range(-1000, 1000, 1):\n",
" threshold = threshold_thousandths / 1000\n",
" total = 0\n",
" correct = 0\n",
" for cs, ls in zip(cosine_similarity, labeled_similarity):\n",
" total += 1\n",
" if cs > threshold:\n",
" prediction = 1\n",
" else:\n",
" prediction = -1\n",
" if prediction == ls:\n",
" correct += 1\n",
" accuracy = correct / total\n",
" accuracies.append(accuracy)\n",
" a = max(accuracies)\n",
" n = len(cosine_similarity)\n",
" standard_error = (a * (1 - a) / n) ** 0.5 # standard error of binomial\n",
" return a, standard_error\n",
"\n",
"\n",
"# check that training and test sets are balanced\n",
"px.histogram(\n",
" df,\n",
" x=\"cosine_similarity\",\n",
" color=\"label\",\n",
" barmode=\"overlay\",\n",
" width=500,\n",
" facet_row=\"dataset\",\n",
").show()\n",
"\n",
"for dataset in [\"train\", \"test\"]:\n",
" data = df[df[\"dataset\"] == dataset]\n",
" a, se = accuracy_and_se(data[\"cosine_similarity\"], data[\"label\"])\n",
" print(f\"{dataset} accuracy: {a:0.1%} ± {1.96 * se:0.1%}\")\n"
]
},
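{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The threshold sweep in `accuracy_and_se` loops in Python over every pair for each of the 2,000 thresholds. A minimal sketch of the same sweep vectorized with NumPy broadcasting, assuming the similarities and labels fit in memory as 1-D arrays; `accuracy_and_se_vectorized` is a hypothetical name, not part of the notebook:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def accuracy_and_se_vectorized(cosine_similarity, labeled_similarity):\n",
"    # hypothetical helper: same sweep as accuracy_and_se, without Python loops\n",
"    cs = np.asarray(cosine_similarity, dtype=float)\n",
"    labels = np.asarray(labeled_similarity)\n",
"    # same threshold grid as the loop: -1.000, -0.999, ..., 0.999\n",
"    thresholds = np.arange(-1000, 1000) / 1000\n",
"    # predictions[i, j] is +1 if pair j has similarity above threshold i, else -1\n",
"    predictions = np.where(cs[None, :] > thresholds[:, None], 1, -1)\n",
"    # fraction of correct predictions at each threshold\n",
"    accuracies = (predictions == labels[None, :]).mean(axis=1)\n",
"    a = accuracies.max()\n",
"    n = len(cs)\n",
"    standard_error = (a * (1 - a) / n) ** 0.5  # standard error of binomial\n",
"    return a, standard_error\n",
"```\n",
"\n",
"The broadcast builds a 2,000 x n array of predictions, which is cheap for a few thousand pairs; for much larger datasets it may be worth processing the thresholds in chunks."
]
},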
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "zHLxlnsApgkR"
},
"source": [
"## 7. Optimize the matrix using the training data provided"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "z52V0x8IpgkR"
},
"outputs": [],
"source": [
"def embedding_multiplied_by_matrix(\n",
"    embedding: List[float], matrix: torch.Tensor\n",
") -> np.ndarray:\n",
" embedding_tensor = torch.tensor(embedding).float()\n",
" modified_embedding = embedding_tensor @ matrix\n",
" modified_embedding = modified_embedding.detach().numpy()\n",
" return modified_embedding\n",
"\n",
"\n",
"# compute custom embeddings and new cosine similarities\n",
"def apply_matrix_to_embeddings_dataframe(matrix: torch.Tensor, df: pd.DataFrame):\n",
" for column in [\"text_1_embedding\", \"text_2_embedding\"]:\n",
" df[f\"{column}_custom\"] = df[column].apply(\n",
" lambda x: embedding_multiplied_by_matrix(x, matrix)\n",
" )\n",
" df[\"cosine_similarity_custom\"] = df.apply(\n",
" lambda row: cosine_similarity(\n",
" row[\"text_1_embedding_custom\"], row[\"text_2_embedding_custom\"]\n",
" ),\n",
" axis=1,\n",
" )\n"
]
},
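{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-row `apply` in `apply_matrix_to_embeddings_dataframe` multiplies one embedding at a time. A minimal sketch of the same projection and cosine similarity done in a single batch with NumPy, assuming `matrix` has been converted to a NumPy array (for example with `matrix.detach().numpy()`) and that both embedding columns hold equal-length lists; `cosine_similarity_custom_vectorized` is a hypothetical name, not part of the notebook:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"\n",
"def cosine_similarity_custom_vectorized(matrix: np.ndarray, df: pd.DataFrame) -> pd.Series:\n",
"    # stack the per-row embeddings into 2-D arrays (rows x embedding_length) and project them\n",
"    e1 = np.stack(df['text_1_embedding'].values) @ matrix\n",
"    e2 = np.stack(df['text_2_embedding'].values) @ matrix\n",
"    # row-wise cosine similarity of the projected embeddings\n",
"    dot = (e1 * e2).sum(axis=1)\n",
"    norms = np.linalg.norm(e1, axis=1) * np.linalg.norm(e2, axis=1)\n",
"    return pd.Series(dot / norms, index=df.index)\n",
"```\n",
"\n",
"The notebook's version multiplies in float32, so its similarities may differ from this float64 sketch in the last few digits."
]
},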
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "p2ZSXu6spgkR"
},
"outputs": [],
"source": [
"def optimize_matrix(\n",
"    modified_embedding_length: int = 2048,  # in my brief experimentation, bigger was better (2048 is the length of babbage embeddings)\n",
" batch_size: int = 100,\n",
" max_epochs: int = 100,\n",
" learning_rate: float = 100.0, # seemed to work best when similar to batch size - feel free to try a range of values\n",
" dropout_fraction: float = 0.0, # in my testing, dropout helped by a couple percentage points (definitely not necessary)\n",
" df: pd.DataFrame = df,\n",
" print_progress: bool = True,\n",
" save_results: bool = True,\n",
") -> pd.DataFrame:\n",
"    \"\"\"Optimize a projection matrix on the training data; return a DataFrame of per-epoch losses, accuracies, and matrices.\"\"\"\n",
" run_id = random.randint(0, 2 ** 31 - 1) # (range is arbitrary)\n",
" # convert from dataframe to torch tensors\n",
" # e is for embedding, s for similarity label\n",
" def tensors_from_dataframe(\n",
" df: pd.DataFrame,\n",
" embedding_column_1: str,\n",
" embedding_column_2: str,\n",
" similarity_label_column: str,\n",
"    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
" e1 = np.stack(np.array(df[embedding_column_1].values))\n",
" e2 = np.stack(np.array(df[embedding_column_2].values))\n",
" s = np.stack(np.array(df[similarity_label_column].astype(\"float\").values))\n",
"\n",
" e1 = torch.from_numpy(e1).float()\n",
" e2 = torch.from_numpy(e2).float()\n",
" s = torch.from_numpy(s).float()\n",
"\n",
" return e1, e2, s\n",
"\n",
" e1_train, e2_train, s_train = tensors_from_dataframe(\n",
" df[df[\"dataset\"] == \"train\"], \"text_1_embedding\", \"text_2_embedding\", \"label\"\n",
" )\n",
" e1_test, e2_test, s_test = tensors_from_dataframe(\n",
" df[df[\"dataset\"] == \"test\"], \"text_1_embedding\", \"text_2_embedding\", \"label\"\n",
" )\n",
"\n",
" # create dataset and loader\n",
" dataset = torch.utils.data.TensorDataset(e1_train, e2_train, s_train)\n",
" train_loader = torch.utils.data.DataLoader(\n",
" dataset, batch_size=batch_size, shuffle=True\n",
" )\n",
"\n",
" # define model (similarity of projected embeddings)\n",
" def model(embedding_1, embedding_2, matrix, dropout_fraction=dropout_fraction):\n",
" e1 = torch.nn.functional.dropout(embedding_1, p=dropout_fraction)\n",
" e2 = torch.nn.functional.dropout(embedding_2, p=dropout_fraction)\n",
" modified_embedding_1 = e1 @ matrix # @ is matrix multiplication\n",
" modified_embedding_2 = e2 @ matrix\n",
" similarity = torch.nn.functional.cosine_similarity(\n",
" modified_embedding_1, modified_embedding_2\n",
" )\n",
" return similarity\n",
"\n",
" # define loss function to minimize\n",
" def mse_loss(predictions, targets):\n",
" difference = predictions - targets\n",
" return torch.sum(difference * difference) / difference.numel()\n",
"\n",
" # initialize projection matrix\n",
" embedding_length = len(df[\"text_1_embedding\"].values[0])\n",
" matrix = torch.randn(\n",
" embedding_length, modified_embedding_length, requires_grad=True\n",
" )\n",
"\n",
" epochs, types, losses, accuracies, matrices = [], [], [], [], []\n",
" for epoch in range(1, 1 + max_epochs):\n",
" # iterate through training dataloader\n",
" for a, b, actual_similarity in train_loader:\n",
" # generate prediction\n",
" predicted_similarity = model(a, b, matrix)\n",
" # get loss and perform backpropagation\n",
" loss = mse_loss(predicted_similarity, actual_similarity)\n",
" loss.backward()\n",
" # update the weights\n",
" with torch.no_grad():\n",
" matrix -= matrix.grad * learning_rate\n",
" # set gradients to zero\n",
" matrix.grad.zero_()\n",
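"        # calculate test loss once per epoch, without dropout or gradient tracking\n",
"        with torch.no_grad():\n",
"            test_predictions = model(e1_test, e2_test, matrix, dropout_fraction=0.0)\n",
"            test_loss = mse_loss(test_predictions, s_test)\n",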
"\n",
" # compute custom embeddings and new cosine similarities\n",
" apply_matrix_to_embeddings_dataframe(matrix, df)\n",
"\n",
"        # calculate train and test accuracy\n",
" for dataset in [\"train\", \"test\"]:\n",
" data = df[df[\"dataset\"] == dataset]\n",
" a, se = accuracy_and_se(data[\"cosine_similarity_custom\"], data[\"label\"])\n",
"\n",
" # record results of each epoch\n",
" epochs.append(epoch)\n",
" types.append(dataset)\n",
" losses.append(loss.item() if dataset == \"train\" else test_loss.item())\n",
" accuracies.append(a)\n",
"            matrices.append(matrix.detach().numpy().copy())  # copy: .numpy() shares memory with the matrix updated in place\n",
"\n",
" # optionally print accuracies\n",
" if print_progress is True:\n",
" print(\n",
" f\"Epoch {epoch}/{max_epochs}: {dataset} accuracy: {a:0.1%} ± {1.96 * se:0.1%}\"\n",
" )\n",
"\n",
" data = pd.DataFrame(\n",
" {\"epoch\": epochs, \"type\": types, \"loss\": losses, \"accuracy\": accuracies}\n",
" )\n",
" data[\"run_id\"] = run_id\n",
" data[\"modified_embedding_length\"] = modified_embedding_length\n",
" data[\"batch_size\"] = batch_size\n",
" data[\"max_epochs\"] = max_epochs\n",
" data[\"learning_rate\"] = learning_rate\n",
" data[\"dropout_fraction\"] = dropout_fraction\n",
" data[\n",
" \"matrix\"\n",
" ] = matrices # saving every single matrix can get big; feel free to delete/change\n",
" if save_results is True:\n",
" data.to_csv(f\"{run_id}_optimization_results.csv\", index=False)\n",
"\n",
" return data\n"
]
},
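{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"`optimize_matrix` returns one row per epoch and dataset, each with the matrix recorded at that point. A minimal sketch of picking the matrix with the highest test accuracy and using it to compare a new pair of embeddings, assuming `results` is the DataFrame returned by `optimize_matrix`; `best_matrix` and `custom_cosine_similarity` are hypothetical helper names, not part of the notebook:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"\n",
"def best_matrix(results: pd.DataFrame) -> np.ndarray:\n",
"    # keep only the rows evaluated on the test split\n",
"    test_rows = results[results['type'] == 'test']\n",
"    # first row with the highest test accuracy\n",
"    best_row = test_rows.loc[test_rows['accuracy'].idxmax()]\n",
"    return best_row['matrix']\n",
"\n",
"\n",
"def custom_cosine_similarity(embedding_1, embedding_2, matrix: np.ndarray) -> float:\n",
"    # project both embeddings with the learned matrix, then take their cosine similarity\n",
"    v1 = np.asarray(embedding_1, dtype=float) @ matrix\n",
"    v2 = np.asarray(embedding_2, dtype=float) @ matrix\n",
"    return float(v1 @ v2 / (np.linalg.norm(v1) * np.linalg.norm(v2)))\n",
"```\n",
"\n",
"Because `matrix.detach().numpy()` shares memory with a tensor that is updated in place, the saved matrices are distinct snapshots only if each one is copied when it is recorded. Selecting by test accuracy also uses the test set for model selection, so a separate held-out set gives a less optimistic estimate."
]
},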
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "nlcUW-zEpgkS",
"outputId": "4bd4bdff-628a-406f-fffe-aedbfad66446"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30: train accuracy: 89.1% ± 2.4%\n",
"Epoch 1/30: test accuracy: 88.4% ± 2.4%\n",
"Epoch 2/30: train accuracy: 89.5% ± 2.3%\n",
"Epoch 2/30: test accuracy: 88.8% ± 2.4%\n",
"Epoch 3/30: train accuracy: 90.6% ± 2.2%\n",
"Epoch 3/30: test accuracy: 89.3% ± 2.3%\n",
"Epoch 4/30: train accuracy: 91.2% ± 2.2%\n",
"Epoch 4/30: test accuracy: 89.7% ± 2.3%\n",
"Epoch 5/30: train accuracy: 91.5% ± 2.1%\n",
"Epoch 5/30: test accuracy: 90.0% ± 2.3%\n",
"Epoch 6/30: train accuracy: 91.9% ± 2.1%\n",
"Epoch 6/30: test accuracy: 90.4% ± 2.2%\n",
"Epoch 7/30: train accuracy: 92.2% ± 2.0%\n",
"Epoch 7/30: test accuracy: 90.7% ± 2.2%\n",
"Epoch 8/30: train accuracy: 92.7% ± 2.0%\n",
"Epoch 8/30: test accuracy: 90.9% ± 2.2%\n",
"Epoch 9/30: train accuracy: 92.7% ± 2.0%\n",
"Epoch 9/30: test accuracy: 91.0% ± 2.2%\n",
"Epoch 10/30: train accuracy: 93.0% ± 1.9%\n",
"Epoch 10/30: test accuracy: 91.6% ± 2.1%\n",
"Epoch 11/30: train accuracy: 93.1% ± 1.9%\n",
"Epoch 11/30: test accuracy: 91.8% ± 2.1%\n",
"Epoch 12/30: train accuracy: 93.4% ± 1.9%\n",
"Epoch 12/30: test accuracy: 92.1% ± 2.0%\n",
"Epoch 13/30: train accuracy: 93.6% ± 1.9%\n",
"Epoch 13/30: test accuracy: 92.4% ± 2.0%\n",
"Epoch 14/30: train accuracy: 93.7% ± 1.8%\n",
"Epoch 14/30: test accuracy: 92.7% ± 2.0%\n",
"Epoch 15/30: train accuracy: 93.7% ± 1.8%\n",
"Epoch 15/30: test accuracy: 92.7% ± 2.0%\n",
"Epoch 16/30: train accuracy: 94.0% ± 1.8%\n",
"Epoch 16/30: test accuracy: 93.0% ± 1.9%\n",
"Epoch 17/30: train accuracy: 94.0% ± 1.8%\n",
"Epoch 17/30: test accuracy: 93.0% ± 1.9%\n",
"Epoch 18/30: train accuracy: 94.2% ± 1.8%\n",
"Epoch 18/30: test accuracy: 93.1% ± 1.9%\n",
"Epoch 19/30: train accuracy: 94.2% ± 1.8%\n",
"Epoch 19/30: test accuracy: 93.1% ± 1.9%\n",
"Epoch 20/30: train accuracy: 94.3% ± 1.8%\n",
"Epoch 20/30: test accuracy: 93.0% ± 1.9%\n",
"Epoch 21/30: train accuracy: 94.5% ± 1.7%\n",
"Epoch 21/30: test accuracy: 93.1% ± 1.9%\n",
"Epoch 22/30: train accuracy: 94.5% ± 1.7%\n",
"Epoch 22/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 23/30: train accuracy: 94.6% ± 1.7%\n",
"Epoch 23/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 24/30: train accuracy: 94.6% ± 1.7%\n",
"Epoch 24/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 25/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 25/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 26/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 26/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 27/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 27/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 28/30: train accuracy: 94.9% ± 1.7%\n",
"Epoch 28/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 29/30: train accuracy: 94.9% ± 1.7%\n",
"Epoch 29/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 30/30: train accuracy: 94.9% ± 1.7%\n",
"Epoch 30/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 1/30: train accuracy: 89.7% ± 2.3%\n",
"Epoch 1/30: test accuracy: 89.1% ± 2.4%\n",
"Epoch 2/30: train accuracy: 89.8% ± 2.3%\n",
"Epoch 2/30: test accuracy: 89.9% ± 2.3%\n",
"Epoch 3/30: train accuracy: 90.3% ± 2.2%\n",
"Epoch 3/30: test accuracy: 90.0% ± 2.3%\n",
"Epoch 4/30: train accuracy: 91.0% ± 2.2%\n",
"Epoch 4/30: test accuracy: 90.3% ± 2.2%\n",
"Epoch 5/30: train accuracy: 91.3% ± 2.1%\n",
"Epoch 5/30: test accuracy: 90.3% ± 2.2%\n",
"Epoch 6/30: train accuracy: 91.8% ± 2.1%\n",
"Epoch 6/30: test accuracy: 90.4% ± 2.2%\n",
"Epoch 7/30: train accuracy: 92.4% ± 2.0%\n",
"Epoch 7/30: test accuracy: 91.0% ± 2.2%\n",
"Epoch 8/30: train accuracy: 92.8% ± 2.0%\n",
"Epoch 8/30: test accuracy: 91.3% ± 2.1%\n",
"Epoch 9/30: train accuracy: 93.1% ± 1.9%\n",
"Epoch 9/30: test accuracy: 91.6% ± 2.1%\n",
"Epoch 10/30: train accuracy: 93.4% ± 1.9%\n",
"Epoch 10/30: test accuracy: 91.9% ± 2.1%\n",
"Epoch 11/30: train accuracy: 93.4% ± 1.9%\n",
"Epoch 11/30: test accuracy: 91.8% ± 2.1%\n",
"Epoch 12/30: train accuracy: 93.6% ± 1.9%\n",
"Epoch 12/30: test accuracy: 92.1% ± 2.0%\n",
"Epoch 13/30: train accuracy: 93.7% ± 1.8%\n",
"Epoch 13/30: test accuracy: 92.4% ± 2.0%\n",
"Epoch 14/30: train accuracy: 93.7% ± 1.8%\n",
"Epoch 14/30: test accuracy: 92.5% ± 2.0%\n",
"Epoch 15/30: train accuracy: 93.9% ± 1.8%\n",
"Epoch 15/30: test accuracy: 92.8% ± 2.0%\n",
"Epoch 16/30: train accuracy: 94.0% ± 1.8%\n",
"Epoch 16/30: test accuracy: 92.8% ± 2.0%\n",
"Epoch 17/30: train accuracy: 94.0% ± 1.8%\n",
"Epoch 17/30: test accuracy: 92.8% ± 2.0%\n",
"Epoch 18/30: train accuracy: 94.2% ± 1.8%\n",
"Epoch 18/30: test accuracy: 92.8% ± 2.0%\n",
"Epoch 19/30: train accuracy: 94.2% ± 1.8%\n",
"Epoch 19/30: test accuracy: 92.8% ± 2.0%\n",
"Epoch 20/30: train accuracy: 94.2% ± 1.8%\n",
"Epoch 20/30: test accuracy: 93.1% ± 1.9%\n",
"Epoch 21/30: train accuracy: 94.3% ± 1.8%\n",
"Epoch 21/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 22/30: train accuracy: 94.3% ± 1.8%\n",
"Epoch 22/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 23/30: train accuracy: 94.5% ± 1.7%\n",
"Epoch 23/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 24/30: train accuracy: 94.5% ± 1.7%\n",
"Epoch 24/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 25/30: train accuracy: 94.6% ± 1.7%\n",
"Epoch 25/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 26/30: train accuracy: 94.6% ± 1.7%\n",
"Epoch 26/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 27/30: train accuracy: 94.6% ± 1.7%\n",
"Epoch 27/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 28/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 28/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 29/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 29/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 30/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 30/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 1/30: train accuracy: 90.7% ± 2.2%\n",
"Epoch 1/30: test accuracy: 89.9% ± 2.3%\n",
"Epoch 2/30: train accuracy: 90.9% ± 2.2%\n",
"Epoch 2/30: test accuracy: 90.3% ± 2.2%\n",
"Epoch 3/30: train accuracy: 91.6% ± 2.1%\n",
"Epoch 3/30: test accuracy: 90.3% ± 2.2%\n",
"Epoch 4/30: train accuracy: 92.2% ± 2.0%\n",
"Epoch 4/30: test accuracy: 90.7% ± 2.2%\n",
"Epoch 5/30: train accuracy: 92.4% ± 2.0%\n",
"Epoch 5/30: test accuracy: 91.3% ± 2.1%\n",
"Epoch 6/30: train accuracy: 92.5% ± 2.0%\n",
"Epoch 6/30: test accuracy: 91.8% ± 2.1%\n",
"Epoch 7/30: train accuracy: 93.0% ± 1.9%\n",
"Epoch 7/30: test accuracy: 92.2% ± 2.0%\n",
"Epoch 8/30: train accuracy: 93.1% ± 1.9%\n",
"Epoch 8/30: test accuracy: 92.7% ± 2.0%\n",
"Epoch 9/30: train accuracy: 93.3% ± 1.9%\n",
"Epoch 9/30: test accuracy: 92.5% ± 2.0%\n",
"Epoch 10/30: train accuracy: 93.4% ± 1.9%\n",
"Epoch 10/30: test accuracy: 92.7% ± 2.0%\n",
"Epoch 11/30: train accuracy: 93.6% ± 1.9%\n",
"Epoch 11/30: test accuracy: 92.8% ± 2.0%\n",
"Epoch 12/30: train accuracy: 93.7% ± 1.8%\n",
"Epoch 12/30: test accuracy: 92.8% ± 2.0%\n",
"Epoch 13/30: train accuracy: 94.0% ± 1.8%\n",
"Epoch 13/30: test accuracy: 93.0% ± 1.9%\n",
"Epoch 14/30: train accuracy: 93.9% ± 1.8%\n",
"Epoch 14/30: test accuracy: 93.0% ± 1.9%\n",
"Epoch 15/30: train accuracy: 94.2% ± 1.8%\n",
"Epoch 15/30: test accuracy: 93.0% ± 1.9%\n",
"Epoch 16/30: train accuracy: 94.2% ± 1.8%\n",
"Epoch 16/30: test accuracy: 93.0% ± 1.9%\n",
"Epoch 17/30: train accuracy: 94.3% ± 1.8%\n",
"Epoch 17/30: test accuracy: 93.0% ± 1.9%\n",
"Epoch 18/30: train accuracy: 94.5% ± 1.7%\n",
"Epoch 18/30: test accuracy: 93.1% ± 1.9%\n",
"Epoch 19/30: train accuracy: 94.5% ± 1.7%\n",
"Epoch 19/30: test accuracy: 93.1% ± 1.9%\n",
"Epoch 20/30: train accuracy: 94.6% ± 1.7%\n",
"Epoch 20/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 21/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 21/30: test accuracy: 93.3% ± 1.9%\n",
"Epoch 22/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 22/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 23/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 23/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 24/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 24/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 25/30: train accuracy: 94.8% ± 1.7%\n",
"Epoch 25/30: test accuracy: 93.4% ± 1.9%\n",
"Epoch 26/30: train accuracy: 94.9% ± 1.7%\n",
"Epoch 26/30: test accuracy: 93.6% ± 1.9%\n",
"Epoch 27/30: train accuracy: 94.9% ± 1.7%\n",
"Epoch 27/30: test accuracy: 93.6% ± 1.9%\n",
"Epoch 28/30: train accuracy: 94.9% ± 1.7%\n",
"Epoch 28/30: test accuracy: 93.6% ± 1.9%\n",
"Epoch 29/30: train accuracy: 95.1% ± 1.6%\n",
"Epoch 29/30: test accuracy: 93.6% ± 1.9%\n",
"Epoch 30/30: train accuracy: 95.1% ± 1.6%\n",
"Epoch 30/30: test accuracy: 93.6% ± 1.9%\n"
]
}
],
"source": [
"# example hyperparameter search\n",
"# I recommend starting with max_epochs=10 while initially exploring\n",
"results = []\n",
"max_epochs = 30\n",
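"dropout_fraction = 0.2\n",
"# each tuple is (batch_size, learning_rate); in this search the learning rate grows with the batch size\n",
"for batch_size, learning_rate in [(10, 10), (100, 100), (1000, 1000)]:\n",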
" result = optimize_matrix(\n",
" batch_size=batch_size,\n",
" learning_rate=learning_rate,\n",
" max_epochs=max_epochs,\n",
" dropout_fraction=dropout_fraction,\n",
" save_results=False,\n",
" )\n",
" results.append(result)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "PoTZWC1SpgkS",
"outputId": "207360e5-fd07-4180-a143-0ec5dd27ffe1"
},
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"customdata": [
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
]
],
"hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=1449308123<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "train",
"line": {
"color": "#636efa",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "train",
"orientation": "v",
"showlegend": true,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x7",
"y": [
1.1652076244354248,
0.8190054297447205,
0.9203463792800903,
0.9908844232559204,
0.8408004641532898,
0.7119297981262207,
0.6632728576660156,
0.8289280533790588,
0.8687358498573303,
0.8235021829605103,
0.895695149898529,
0.6677632331848145,
0.7526643872261047,
0.7708764672279358,
0.68276047706604,
0.6613249778747559,
0.5960850119590759,
0.8617165684700012,
0.724422037601471,
0.9765143394470215,
0.5958823561668396,
0.7277706265449524,
0.7929649353027344,
0.8311190009117126,
0.484933465719223,
0.6846191883087158,
0.6711297035217285,
0.738968551158905,
0.5267000198364258,
0.9111422300338745
],
"yaxis": "y7"
},
{
"customdata": [
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
]
],
"hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=676326879<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "train",
"line": {
"color": "#636efa",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "train",
"orientation": "v",
"showlegend": false,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x5",
"y": [
1.06173574924469,
0.9195982813835144,
0.8605211973190308,
0.7780925631523132,
0.7680334448814392,
0.7896735072135925,
0.7652512788772583,
0.7015480399131775,
0.8019503951072693,
0.7844551801681519,
0.823682427406311,
0.711807131767273,
0.7855805158615112,
0.7014225125312805,
0.7862630486488342,
0.6663534045219421,
0.7388879060745239,
0.6876973509788513,
0.7274147272109985,
0.7191041111946106,
0.8075127601623535,
0.7195712924003601,
0.746185839176178,
0.7220138311386108,
0.7456589341163635,
0.6642791032791138,
0.7399784326553345,
0.7393214106559753,
0.6680636405944824,
0.6562733054161072
],
"yaxis": "y5"
},
{
"customdata": [
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
]
],
"hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=881033356<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "train",
"line": {
"color": "#636efa",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "train",
"orientation": "v",
"showlegend": false,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x3",
"y": [
1.299358606338501,
1.0686416625976562,
0.8467883467674255,
0.7818496823310852,
0.7742239236831665,
0.7678887844085693,
0.7696077227592468,
0.766943097114563,
0.7572135925292969,
0.7593567967414856,
0.7546665668487549,
0.7499144077301025,
0.7492073178291321,
0.7394933700561523,
0.743760883808136,
0.7366983294487,
0.7340478301048279,
0.7297782897949219,
0.7292298674583435,
0.7229472994804382,
0.7246285080909729,
0.721783459186554,
0.7177888751029968,
0.7198930978775024,
0.7123011946678162,
0.7132685780525208,
0.7121831178665161,
0.7118210196495056,
0.7035670280456543,
0.7066351771354675
],
"yaxis": "y3"
},
{
"customdata": [
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
]
],
"hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=1449308123<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "test",
"line": {
"color": "#EF553B",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "test",
"orientation": "v",
"showlegend": true,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x7",
"y": [
1.1514005661010742,
0.9815413355827332,
0.8687632083892822,
0.8124286532402039,
0.7918410301208496,
0.7831570506095886,
0.7777765393257141,
0.7735418677330017,
0.7701976895332336,
0.7720419764518738,
0.7726959586143494,
0.7650123834609985,
0.7650409936904907,
0.765683114528656,
0.7626248598098755,
0.7623012065887451,
0.7609940767288208,
0.7587862610816956,
0.7559080123901367,
0.7571383118629456,
0.7588285803794861,
0.7556950449943542,
0.7562108039855957,
0.7484216690063477,
0.7531804442405701,
0.7502257823944092,
0.7496891617774963,
0.7472137808799744,
0.748519241809845,
0.7483490705490112
],
"yaxis": "y7"
},
{
"customdata": [
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
]
],
"hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=676326879<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "test",
"line": {
"color": "#EF553B",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "test",
"orientation": "v",
"showlegend": false,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x5",
"y": [
1.1180120706558228,
0.9514004588127136,
0.8425014019012451,
0.8078141212463379,
0.7853298187255859,
0.7790449857711792,
0.7733959555625916,
0.7792260646820068,
0.7700104117393494,
0.7700384855270386,
0.764447033405304,
0.7696305513381958,
0.7642591595649719,
0.7629777193069458,
0.7620665431022644,
0.7622931003570557,
0.7602745890617371,
0.7564830780029297,
0.761269748210907,
0.7550154328346252,
0.7560049295425415,
0.7538376450538635,
0.7503026127815247,
0.7528620958328247,
0.7485130429267883,
0.7481465339660645,
0.7483287453651428,
0.742965817451477,
0.7445206642150879,
0.7476803064346313
],
"yaxis": "y5"
},
{
"customdata": [
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
]
],
"hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=881033356<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "test",
"line": {
"color": "#EF553B",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "test",
"orientation": "v",
"showlegend": false,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x3",
"y": [
1.0644111633300781,
0.8446540236473083,
0.783637285232544,
0.7773287892341614,
0.7771273851394653,
0.7754599452018738,
0.7732424736022949,
0.7710903882980347,
0.772106945514679,
0.7645512223243713,
0.7639281153678894,
0.7609775066375732,
0.7558433413505554,
0.7579177618026733,
0.7527595162391663,
0.7520467042922974,
0.7534357309341431,
0.7487133145332336,
0.7478086352348328,
0.747072160243988,
0.7411499619483948,
0.7459169030189514,
0.7451942563056946,
0.7394304275512695,
0.7337468862533569,
0.734693169593811,
0.7376500368118286,
0.7371401190757751,
0.731780469417572,
0.727291464805603
],
"yaxis": "y3"
}
],
"layout": {
"annotations": [
{
"font": {},
"showarrow": false,
"text": "batch_size=10",
"x": 0.15666666666666665,
"xanchor": "center",
"xref": "paper",
"y": 0.9999999999999998,
"yanchor": "bottom",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "batch_size=100",
"x": 0.49,
"xanchor": "center",
"xref": "paper",
"y": 0.9999999999999998,
"yanchor": "bottom",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "batch_size=1000",
"x": 0.8233333333333333,
"xanchor": "center",
"xref": "paper",
"y": 0.9999999999999998,
"yanchor": "bottom",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "learning_rate=1000",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.15666666666666665,
"yanchor": "middle",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "learning_rate=100",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.4999999999999999,
"yanchor": "middle",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "learning_rate=10",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.8433333333333332,
"yanchor": "middle",
"yref": "paper"
}
],
"legend": {
"title": {
"text": "type"
},
"tracegroupgap": 0
},
"margin": {
"t": 60
},
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"width": 500,
"xaxis": {
"anchor": "y",
"domain": [
0,
0.3133333333333333
],
"title": {
"text": "epoch"
}
},
"xaxis2": {
"anchor": "y2",
"domain": [
0.3333333333333333,
0.6466666666666666
],
"matches": "x",
"title": {
"text": "epoch"
}
},
"xaxis3": {
"anchor": "y3",
"domain": [
0.6666666666666666,
0.98
],
"matches": "x",
"title": {
"text": "epoch"
}
},
"xaxis4": {
"anchor": "y4",
"domain": [
0,
0.3133333333333333
],
"matches": "x",
"showticklabels": false
},
"xaxis5": {
"anchor": "y5",
"domain": [
0.3333333333333333,
0.6466666666666666
],
"matches": "x",
"showticklabels": false
},
"xaxis6": {
"anchor": "y6",
"domain": [
0.6666666666666666,
0.98
],
"matches": "x",
"showticklabels": false
},
"xaxis7": {
"anchor": "y7",
"domain": [
0,
0.3133333333333333
],
"matches": "x",
"showticklabels": false
},
"xaxis8": {
"anchor": "y8",
"domain": [
0.3333333333333333,
0.6466666666666666
],
"matches": "x",
"showticklabels": false
},
"xaxis9": {
"anchor": "y9",
"domain": [
0.6666666666666666,
0.98
],
"matches": "x",
"showticklabels": false
},
"yaxis": {
"anchor": "x",
"domain": [
0,
0.3133333333333333
],
"title": {
"text": "loss"
}
},
"yaxis2": {
"anchor": "x2",
"domain": [
0,
0.3133333333333333
],
"matches": "y",
"showticklabels": false
},
"yaxis3": {
"anchor": "x3",
"domain": [
0,
0.3133333333333333
],
"matches": "y",
"showticklabels": false
},
"yaxis4": {
"anchor": "x4",
"domain": [
0.34333333333333327,
0.6566666666666665
],
"matches": "y",
"title": {
"text": "loss"
}
},
"yaxis5": {
"anchor": "x5",
"domain": [
0.34333333333333327,
0.6566666666666665
],
"matches": "y",
"showticklabels": false
},
"yaxis6": {
"anchor": "x6",
"domain": [
0.34333333333333327,
0.6566666666666665
],
"matches": "y",
"showticklabels": false
},
"yaxis7": {
"anchor": "x7",
"domain": [
0.6866666666666665,
0.9999999999999998
],
"matches": "y",
"title": {
"text": "loss"
}
},
"yaxis8": {
"anchor": "x8",
"domain": [
0.6866666666666665,
0.9999999999999998
],
"matches": "y",
"showticklabels": false
},
"yaxis9": {
"anchor": "x9",
"domain": [
0.6866666666666665,
0.9999999999999998
],
"matches": "y",
"showticklabels": false
}
}
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"customdata": [
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
]
],
"hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=1449308123<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "train",
"line": {
"color": "#636efa",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "train",
"orientation": "v",
"showlegend": true,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x7",
"y": [
0.8907185628742516,
0.8952095808383234,
0.905688622754491,
0.9116766467065869,
0.9146706586826348,
0.9191616766467066,
0.9221556886227545,
0.9266467065868264,
0.9266467065868264,
0.9296407185628742,
0.9311377245508982,
0.9341317365269461,
0.9356287425149701,
0.937125748502994,
0.937125748502994,
0.9401197604790419,
0.9401197604790419,
0.9416167664670658,
0.9416167664670658,
0.9431137724550899,
0.9446107784431138,
0.9446107784431138,
0.9461077844311377,
0.9461077844311377,
0.9476047904191617,
0.9476047904191617,
0.9476047904191617,
0.9491017964071856,
0.9491017964071856,
0.9491017964071856
],
"yaxis": "y7"
},
{
"customdata": [
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
]
],
"hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=676326879<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "train",
"line": {
"color": "#636efa",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "train",
"orientation": "v",
"showlegend": false,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x5",
"y": [
0.8967065868263473,
0.8982035928143712,
0.9026946107784432,
0.9101796407185628,
0.9131736526946108,
0.9176646706586826,
0.9236526946107785,
0.9281437125748503,
0.9311377245508982,
0.9341317365269461,
0.9341317365269461,
0.9356287425149701,
0.937125748502994,
0.937125748502994,
0.938622754491018,
0.9401197604790419,
0.9401197604790419,
0.9416167664670658,
0.9416167664670658,
0.9416167664670658,
0.9431137724550899,
0.9431137724550899,
0.9446107784431138,
0.9446107784431138,
0.9461077844311377,
0.9461077844311377,
0.9461077844311377,
0.9476047904191617,
0.9476047904191617,
0.9476047904191617
],
"yaxis": "y5"
},
{
"customdata": [
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
]
],
"hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=881033356<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "train",
"line": {
"color": "#636efa",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "train",
"orientation": "v",
"showlegend": false,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x3",
"y": [
0.907185628742515,
0.9086826347305389,
0.9161676646706587,
0.9221556886227545,
0.9236526946107785,
0.9251497005988024,
0.9296407185628742,
0.9311377245508982,
0.9326347305389222,
0.9341317365269461,
0.9356287425149701,
0.937125748502994,
0.9401197604790419,
0.938622754491018,
0.9416167664670658,
0.9416167664670658,
0.9431137724550899,
0.9446107784431138,
0.9446107784431138,
0.9461077844311377,
0.9476047904191617,
0.9476047904191617,
0.9476047904191617,
0.9476047904191617,
0.9476047904191617,
0.9491017964071856,
0.9491017964071856,
0.9491017964071856,
0.9505988023952096,
0.9505988023952096
],
"yaxis": "y3"
},
{
"customdata": [
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
],
[
10,
10,
0.2
]
],
"hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=1449308123<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "test",
"line": {
"color": "#EF553B",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "test",
"orientation": "v",
"showlegend": true,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x7",
"y": [
0.8835820895522388,
0.8880597014925373,
0.8925373134328358,
0.8970149253731343,
0.9,
0.9044776119402985,
0.9074626865671642,
0.908955223880597,
0.9104477611940298,
0.9164179104477612,
0.917910447761194,
0.9208955223880597,
0.9238805970149254,
0.926865671641791,
0.926865671641791,
0.9298507462686567,
0.9298507462686567,
0.9313432835820895,
0.9313432835820895,
0.9298507462686567,
0.9313432835820895,
0.9328358208955224,
0.9328358208955224,
0.9328358208955224,
0.9328358208955224,
0.9343283582089552,
0.9343283582089552,
0.9343283582089552,
0.9343283582089552,
0.9328358208955224
],
"yaxis": "y7"
},
{
"customdata": [
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
],
[
100,
100,
0.2
]
],
"hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=676326879<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "test",
"line": {
"color": "#EF553B",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "test",
"orientation": "v",
"showlegend": false,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x5",
"y": [
0.891044776119403,
0.8985074626865671,
0.9,
0.9029850746268657,
0.9029850746268657,
0.9044776119402985,
0.9104477611940298,
0.9134328358208955,
0.9164179104477612,
0.9194029850746268,
0.917910447761194,
0.9208955223880597,
0.9238805970149254,
0.9253731343283582,
0.9283582089552239,
0.9283582089552239,
0.9283582089552239,
0.9283582089552239,
0.9283582089552239,
0.9313432835820895,
0.9328358208955224,
0.9328358208955224,
0.9328358208955224,
0.9328358208955224,
0.9343283582089552,
0.9328358208955224,
0.9343283582089552,
0.9343283582089552,
0.9328358208955224,
0.9343283582089552
],
"yaxis": "y5"
},
{
"customdata": [
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
],
[
1000,
1000,
0.2
]
],
"hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=881033356<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
"legendgroup": "test",
"line": {
"color": "#EF553B",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "test",
"orientation": "v",
"showlegend": false,
"type": "scatter",
"x": [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30
],
"xaxis": "x3",
"y": [
0.8985074626865671,
0.9029850746268657,
0.9029850746268657,
0.9074626865671642,
0.9134328358208955,
0.917910447761194,
0.9223880597014925,
0.926865671641791,
0.9253731343283582,
0.926865671641791,
0.9283582089552239,
0.9283582089552239,
0.9298507462686567,
0.9298507462686567,
0.9298507462686567,
0.9298507462686567,
0.9298507462686567,
0.9313432835820895,
0.9313432835820895,
0.9328358208955224,
0.9328358208955224,
0.9343283582089552,
0.9343283582089552,
0.9343283582089552,
0.9343283582089552,
0.935820895522388,
0.935820895522388,
0.935820895522388,
0.935820895522388,
0.935820895522388
],
"yaxis": "y3"
}
],
"layout": {
"annotations": [
{
"font": {},
"showarrow": false,
"text": "batch_size=10",
"x": 0.15666666666666665,
"xanchor": "center",
"xref": "paper",
"y": 0.9999999999999998,
"yanchor": "bottom",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "batch_size=100",
"x": 0.49,
"xanchor": "center",
"xref": "paper",
"y": 0.9999999999999998,
"yanchor": "bottom",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "batch_size=1000",
"x": 0.8233333333333333,
"xanchor": "center",
"xref": "paper",
"y": 0.9999999999999998,
"yanchor": "bottom",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "learning_rate=1000",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.15666666666666665,
"yanchor": "middle",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "learning_rate=100",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.4999999999999999,
"yanchor": "middle",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "learning_rate=10",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.8433333333333332,
"yanchor": "middle",
"yref": "paper"
}
],
"legend": {
"title": {
"text": "type"
},
"tracegroupgap": 0
},
"margin": {
"t": 60
},
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"width": 500,
"xaxis": {
"anchor": "y",
"domain": [
0,
0.3133333333333333
],
"title": {
"text": "epoch"
}
},
"xaxis2": {
"anchor": "y2",
"domain": [
0.3333333333333333,
0.6466666666666666
],
"matches": "x",
"title": {
"text": "epoch"
}
},
"xaxis3": {
"anchor": "y3",
"domain": [
0.6666666666666666,
0.98
],
"matches": "x",
"title": {
"text": "epoch"
}
},
"xaxis4": {
"anchor": "y4",
"domain": [
0,
0.3133333333333333
],
"matches": "x",
"showticklabels": false
},
"xaxis5": {
"anchor": "y5",
"domain": [
0.3333333333333333,
0.6466666666666666
],
"matches": "x",
"showticklabels": false
},
"xaxis6": {
"anchor": "y6",
"domain": [
0.6666666666666666,
0.98
],
"matches": "x",
"showticklabels": false
},
"xaxis7": {
"anchor": "y7",
"domain": [
0,
0.3133333333333333
],
"matches": "x",
"showticklabels": false
},
"xaxis8": {
"anchor": "y8",
"domain": [
0.3333333333333333,
0.6466666666666666
],
"matches": "x",
"showticklabels": false
},
"xaxis9": {
"anchor": "y9",
"domain": [
0.6666666666666666,
0.98
],
"matches": "x",
"showticklabels": false
},
"yaxis": {
"anchor": "x",
"domain": [
0,
0.3133333333333333
],
"title": {
"text": "accuracy"
}
},
"yaxis2": {
"anchor": "x2",
"domain": [
0,
0.3133333333333333
],
"matches": "y",
"showticklabels": false
},
"yaxis3": {
"anchor": "x3",
"domain": [
0,
0.3133333333333333
],
"matches": "y",
"showticklabels": false
},
"yaxis4": {
"anchor": "x4",
"domain": [
0.34333333333333327,
0.6566666666666665
],
"matches": "y",
"title": {
"text": "accuracy"
}
},
"yaxis5": {
"anchor": "x5",
"domain": [
0.34333333333333327,
0.6566666666666665
],
"matches": "y",
"showticklabels": false
},
"yaxis6": {
"anchor": "x6",
"domain": [
0.34333333333333327,
0.6566666666666665
],
"matches": "y",
"showticklabels": false
},
"yaxis7": {
"anchor": "x7",
"domain": [
0.6866666666666665,
0.9999999999999998
],
"matches": "y",
"title": {
"text": "accuracy"
}
},
"yaxis8": {
"anchor": "x8",
"domain": [
0.6866666666666665,
0.9999999999999998
],
"matches": "y",
"showticklabels": false
},
"yaxis9": {
"anchor": "x9",
"domain": [
0.6866666666666665,
0.9999999999999998
],
"matches": "y",
"showticklabels": false
}
}
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"runs_df = pd.concat(results)\n",
"\n",
"# plot training loss and test loss over time\n",
"px.line(\n",
" runs_df,\n",
" line_group=\"run_id\",\n",
" x=\"epoch\",\n",
" y=\"loss\",\n",
" color=\"type\",\n",
" hover_data=[\"batch_size\", \"learning_rate\", \"dropout_fraction\"],\n",
" facet_row=\"learning_rate\",\n",
" facet_col=\"batch_size\",\n",
" width=500,\n",
").show()\n",
"\n",
"# plot accuracy over time\n",
"px.line(\n",
" runs_df,\n",
" line_group=\"run_id\",\n",
" x=\"epoch\",\n",
" y=\"accuracy\",\n",
" color=\"type\",\n",
" hover_data=[\"batch_size\", \"learning_rate\", \"dropout_fraction\"],\n",
" facet_row=\"learning_rate\",\n",
" facet_col=\"batch_size\",\n",
" width=500,\n",
").show()\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "MiBQDcMPpgkS"
},
"source": [
"## 8. Plot the before & after, using the best matrix found during training\n",
"\n",
"The better the matrix, the more cleanly it separates the cosine similarities of similar pairs (label 1) from those of dissimilar pairs (label -1)."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "hzjoyLDOpgkS"
},
"outputs": [],
"source": [
"# pick the run and epoch with the highest accuracy (runs_df holds both train and test rows)\n",
"best_run = runs_df.sort_values(by=\"accuracy\", ascending=False).iloc[0]\n",
"best_matrix = best_run[\"matrix\"]\n",
"# apply the chosen matrix to the original data\n",
"apply_matrix_to_embeddings_dataframe(best_matrix, df)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "nLnvABnXpgkS",
"outputId": "0c070faa-6e3e-4765-b082-565c72a609be"
},
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=1<br>dataset=train<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "1",
"marker": {
"color": "#636efa",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "1",
"offsetgroup": "1",
"orientation": "v",
"showlegend": true,
"type": "histogram",
"x": [
0.9267355919345323,
0.8959824209230313,
0.9119725922265434,
0.854066984766886,
0.892887342475088,
0.9197115510183956,
0.8645429616364463,
0.8314164166177409,
0.7243313891695883,
0.8819971504547897,
0.7956215051893748,
0.7959481856578454,
0.8682525487547492,
0.8973704561532523,
0.8648042605746787,
0.9236698981903941,
0.9834804742323746,
0.8152447410515193,
0.8251720025778885,
0.8138195587168373,
0.8041885875094662,
0.9329690881953521,
0.9560346908585096,
0.9727875559800526,
0.8739787488907723,
0.8208200937845748,
0.7246913159387628,
0.9324916305400826,
0.8285737172791849,
0.8797008558439083,
0.820333215489803,
0.9370111008629193,
0.8983827475968527,
0.8312111255338568,
0.8164052516323846,
0.8908148647559723,
0.7466264012908705,
0.749651964301876,
0.87375582683117,
0.7849398161817545,
0.8309506403579568,
0.9307212180139316,
0.8281747408531835,
0.9529528469109777,
0.7828662080109449,
0.887100957550481,
0.9000278356226864,
0.8805448739521785,
0.930239131235984,
0.8801954897699975,
0.8529206900756547,
0.9467797362617967,
0.950367690540818,
0.7030531843080301,
0.8643992401506337,
0.8536886651933216,
0.9619331104636477,
0.9798279220663803,
0.8545739729953584,
0.8957115039259408,
0.8241137169318955,
0.8234984837204449,
0.8936706246811023,
0.8987178417246647,
0.9081806514389928,
0.9208852073112466,
0.8961858080980447,
0.8831329491347061,
0.9282623092166472,
0.8990849228018752,
0.8284548395672304,
0.8202091311790463,
0.8647762700375631,
0.840136958242196,
0.9887387562448309,
0.833342655983456,
0.828533112298018,
0.9117757392019592,
0.8706628948983062,
0.9279786042901599,
0.7389559895894061,
0.8433932035736679,
0.9240307526722153,
0.950769936374633,
0.8586024431762636,
0.8685107120156628,
0.875535036209879,
0.9894909159640456,
0.8279650063634755,
0.9108703739124493,
0.9090161898700972,
0.8603952576151226,
0.7791958170479588,
0.8800175078297319,
0.8442387843179446,
0.7672266106337577,
0.9379753906877134,
0.8637536227406404,
0.9190295684769377,
0.813748744480001,
0.9134886375270357,
0.8043760086358624,
0.8792300492020616,
0.8716796296133559,
0.8669146731224541,
0.7736224659067634,
0.9437235456545872,
0.905686329875817,
0.9534823418409002,
0.9150626356433706,
0.9409873570135594,
0.8111212499450311,
0.9171209901271343,
0.9126582215550586,
0.8337042981493767,
0.7317859043960847,
0.8444929460069593,
0.8561920422842992,
0.7765276739806221,
0.8526116779814181,
0.9178549171995068,
0.9238337665547537,
0.7218806795387105,
0.8180162420894919,
0.9687438998025378,
0.8354559162714565,
0.9146160667909478,
0.8082103463995212,
0.9563444955040953,
0.9066029892990775,
0.8485102485675593,
0.8154210965438722,
0.8860338155558548,
0.9280705420457523,
0.9835182434488767,
0.9653797793986199,
0.7815047672209323,
0.7156150747063292,
0.9256075945804699,
0.8135073898727201,
0.965501517353668,
0.8222681603043882,
0.907212186549752,
0.8611990749004483,
0.9075083701591861,
0.9452697087352507,
0.8792221638449876,
0.9261547230875113,
0.8628843081873457,
0.7825774684929468,
0.826587869384389,
0.7399697967202743,
0.7855475169419611,
0.9111614040328055,
0.8871057104665023,
0.8824403169032524,
0.8618250237448342,
0.9787037901590712,
0.8066230600465234,
0.82769109211187,
0.9246432066709112,
0.8840853142719706,
0.786484352300887,
0.9106863314452291,
0.9342800776988389,
0.8573335079738363,
0.7780871062391552,
0.7913314665963198,
0.8574377397709955,
0.9078366114351649,
0.7529739278117176,
0.8630106564789936,
0.9051526759332171,
0.7715467442267975,
0.894146587858075,
0.8095881901541373,
0.7733578316514722,
0.7600408374028372,
0.7819972017869313,
0.9003461714649055,
0.7428048011790054,
0.8645936498599726,
0.8158769881622202,
0.8338827592994027,
0.8272653967731632,
0.9017517376316297,
0.8480852025321463,
0.7970818331146111,
0.84837067058732,
0.9272909953469497,
0.9511439765702141,
0.8796630924442385,
0.8297595339070903,
0.8132311695727563,
0.846096509352579,
0.8787645385610889,
0.8591367322994465,
0.8452813265723063,
0.708120854745005,
0.8769677220320785,
0.957621649201,
0.7463356296247886,
0.8618039385828121,
0.9560112462834625,
0.8478374736767776,
0.769289015775232,
0.8456866935685449,
0.9014601942765245,
0.8816990623836058,
0.8836365012367311,
0.8078009762348856,
0.898471670333294,
0.9064470715463968,
0.8762712610709371,
0.9178852315161201,
0.7896235958446132,
0.8939345739482307,
0.9534018415944343,
0.8358882135213556,
0.9488657107840998,
0.9046799883600772,
0.7583576517359092,
0.9080459939022663,
0.7709722685822517,
0.9635512477648485,
0.9792712672362888,
0.8526700765974843,
0.8278133097990042,
0.9735858611728696,
0.721230194834802,
0.8257425311085005,
0.9243205490037274,
0.9183796453898277,
0.9029146937248601,
0.9410246048181872,
0.9609604036664618,
0.7467407974920351,
0.8831901217813377,
0.8173287200784538,
0.8067347032404598,
0.7921957436175584,
0.9110994793014074,
0.8678737504217436,
0.9117743220816726,
0.781256498099708,
0.8553931170417658,
0.8798565764815073,
0.8485358179238083,
0.7748765495278522,
0.9432062986428599,
0.8328320716129907,
0.7983629771463491,
0.9345589964322458,
0.780034700168418,
0.9894680324844076,
0.8239308905190431,
0.8236003498823307,
0.8346101074957768,
0.8273793488443836,
0.7872103673165679,
0.9502897884724039,
0.83306630344504,
0.9346568240975981,
0.8082083566882774,
0.8920672687911747,
0.8566523137465843,
0.7636170858064102,
0.8271498146927989,
0.8450776675259235,
0.9045266249905214,
0.8578964053464326,
0.8673866119197355,
0.8804224181917263,
0.8199459545727489,
0.932410033971128,
0.9096821262750205,
0.8658255622247675,
0.9386720385226357,
0.8517830114204678,
0.889433736353776,
0.978847593529176,
0.8369738178686744,
0.8438616304896431,
0.9457050132083015,
0.8699723446274389,
0.7795221418941651,
0.9136284834408402,
0.839461038218975,
0.9453279809479284,
0.7899532071809374,
0.9078373589637019,
0.8434980548175782,
0.8112068688921613,
0.9466506411385276,
0.931413666561485,
0.7932453730948084,
0.8205411414836395,
0.9243834389669592,
0.719616209472987,
0.7552985082978257,
0.9593440978701407,
0.917557937781671,
0.8643861907339583,
0.8315201130646898,
0.7608819746989568,
0.9704324558544556,
0.8037085307118571,
0.7785270629127378,
0.8044961889638995,
0.8313307525316546,
0.8064106357955508,
0.9291149169587132,
0.8412940941318643,
0.6917091086045903,
0.895204432897799,
0.8182250729145697,
0.8645847241904496,
0.8532020284403578,
0.8143634597102418,
0.8825747762544934,
0.7764652529619696,
0.850099367461302,
0.8616919096196407,
0.9257293995599788,
0.935772204341753,
0.7742657205171264,
0.7898710076766889,
0.8590438495382458,
0.9317809680608741,
0.9087109944408146,
0.949297999227784,
0.8813316524294974,
0.7372081408133433,
0.8838176418404169
],
"xaxis": "x2",
"yaxis": "y2"
},
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=1<br>dataset=test<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "1",
"marker": {
"color": "#636efa",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "1",
"offsetgroup": "1",
"orientation": "v",
"showlegend": false,
"type": "histogram",
"x": [
0.9424796847531719,
0.9078956620721231,
0.8334324872928349,
0.9352180099421901,
0.9055462980371385,
0.8981939710469806,
0.8310153256803275,
0.8504676050313356,
0.8456281884302819,
0.8845204616348977,
0.9575409743129409,
0.8867362111499236,
0.8268148049156373,
0.9197424487445726,
0.7868932880014918,
0.7584994087629051,
0.9184151106611779,
0.8634069821364065,
0.8347803687511831,
0.8293627315810226,
0.9290633380383377,
0.8385821691321573,
0.9389267221382397,
0.890818442649049,
0.86634760562554,
0.8406483291061461,
0.8084243397427517,
0.8909500804242533,
0.9262896019507018,
0.8955541231110083,
0.8055268512037249,
0.758607678624549,
0.9609058491722305,
0.91495905751538,
0.8670137150812314,
0.8813831602682358,
0.8602255150075953,
0.9239960998195261,
0.91732217769794,
0.8037285395908439,
0.9196344979660896,
0.8179495018382927,
0.9015423003762686,
0.9054394610623084,
0.9309412941515826,
0.9421722892808463,
0.7632823176731237,
0.8622055676377434,
0.9855273108710355,
0.9144155566597573,
0.916057392559463,
0.8027504546689012,
0.7131090055200247,
0.86174194818737,
0.982873171115144,
0.8100227539053508,
0.8923878603823862,
0.8096643432711604,
0.8707613707684018,
0.8786740148818871,
0.8274639911687568,
0.8927098767487393,
0.9565597075186062,
0.9060728095743347,
0.738307517030644,
0.9645943657917937,
0.8755564011033787,
0.879644342330288,
0.8679709669180851,
0.9304235148160597,
0.8902804960376421,
0.8748369557157689,
0.9999999999999999,
0.7979398167195422,
0.8182553472910762,
0.7782108671759803,
0.8427610550014142,
0.8696408842646327,
0.8747903026537825,
0.914973368517656,
0.9651568968718233,
0.9775547988384379,
0.8964005885715726,
0.8689760349348122,
0.8501707274497268,
0.9069421089316828,
0.7682621587597694,
0.9658683147667629,
0.8946443487011166,
0.7855154267442119,
0.8963791544804274,
0.8062904921192978,
0.8165205978948239,
0.8392522254625653,
0.9456080863923884,
0.7904904126133966,
0.833126792097794,
0.7852156051245299,
0.7859162398886373,
0.9097674992318959,
0.8868692153350769,
0.9391826646753169,
0.9428151202937937,
0.7923603877885725,
0.9018727189911658,
0.9723161942344292,
0.7820369113228325,
0.9667234201162873,
0.978769627389609,
0.9155729781931277,
0.8273013970664075,
0.9603319621485501,
0.9298975009081959,
0.8775117467919693,
0.8614509568162568,
0.9144155657624043,
0.7783710792186642,
0.9701880190033496,
0.7858944693777298,
0.9278353490304795,
0.9472367444800102,
0.7834809788888359,
0.7997358978342995,
0.8459052935830644,
0.8612076995259889,
0.8470901722260982,
0.824037271004761,
0.865608650068826,
0.8023193246081538,
0.7836788857151791,
0.880404135119559,
0.8491559249509729,
0.788345270395367,
0.9461393747813323,
0.835123385080455,
0.8158174048388718,
0.8604581300972295,
0.9623616555004862,
0.8564688397784819,
0.8576867681204893,
0.8973905356978807,
0.8634447095761868,
0.8149528594606367,
0.873171253801101,
0.8653347676818378,
0.929525558735665,
0.8358267202818717,
0.971888682386554,
0.8500189240129448,
0.6201715858790215,
0.8982737437906866,
0.8919523978597029,
0.7327218620224804,
0.8329671225228632,
0.9265589851585685,
0.8976605724257473,
0.8865148832512937,
0.7893917264917192,
0.7303107673595587,
0.8428958487387793,
0.8712646524042454,
0.9726111208329614,
0.9368020234351375,
0.9270010843622818,
0.8900608737128451,
0.7975173141451626,
0.9403308743666237,
0.8484005148509558,
0.9285585494125882,
0.8461714640960527,
0.9301612552439464,
0.9840391344904358,
0.8305503032909712,
0.8985536902480058,
0.9476344055766088,
0.9342892661789283,
0.8849523247638441,
0.7736620851030366,
0.8083290901088768,
0.9510007696957726,
0.8677438102591386,
0.8324233959261911,
0.7379868650067231,
0.9049462205283083,
0.9044068964151009,
0.7810399099399672,
0.9040280419401343,
0.7720832557964016,
0.7168259249496903,
0.8657076231674912,
0.9689982290529879,
0.9330371348632155,
0.7014093149115691,
0.9056081832768474,
0.8483474394912361,
0.8729108893663646,
0.8494252835407102,
0.8702668029663239,
0.8703072652243842,
0.9279473628052111,
0.8615930026010632,
0.7590822846968434,
0.8435232136599701,
0.8264379743713927,
0.8793126202650794,
0.847452301256391,
0.7546334370590053,
0.8870818568791698,
0.8349553719854912,
0.9232007589383142,
0.7924421895458662,
0.8556103146657943,
0.8397958720336148,
0.9358165878534705,
0.904577353436614,
0.9022537114781464,
0.775603917181226,
0.946091618927627,
0.8264119461162666,
0.8261258105029201,
0.8605336598142989,
0.7518422489499549,
0.849587557879856,
0.9922578397415717,
0.7499254104864422,
0.8845204616348977,
0.8361936554693772,
0.9172228808488442,
0.8068135566092208,
0.795739929008372,
0.8632611464779958,
0.7612462576141025,
0.9589125421073597,
0.9555759038945358,
0.8822980105025566,
0.9663740139243834,
0.9071760951442449,
0.9335338894977118,
0.8042262160785201,
0.9399607295068667,
0.8318513711395181,
0.8697471272873054,
0.9103391819002411,
0.8272582065159091,
0.7868989539137853,
0.7416168920325092,
0.8828593510646834,
0.9141342991325323,
0.7259887492833588,
0.9478299721997572,
0.843766518385904,
0.919830425538435,
0.9069062941852448,
0.9036466185261808,
0.9817542893707696,
0.8833292621745382,
0.832556614533968,
0.8135910443631924,
0.9628932969024508,
0.9450804655496136,
0.9226384091529121,
0.8401818103049188,
0.723691406760321,
0.6828741135249211,
0.834410523069395,
0.9959256404542386,
0.952870396433041,
0.969514692554192,
0.9220387806044666,
0.9511950116111251,
0.87442203104197,
0.8399026046246612,
0.9029483760650204,
0.9097073428917352,
0.8651925582004045,
0.9178332691819033,
0.7556713752294848,
0.8601740894614596,
0.8250804250840363,
0.799473306929639,
0.8911389639861809,
0.915913776235107,
0.7867422041389165,
0.8035116695233039,
0.7702882636946234,
0.9060460430333088,
0.7214029229730072,
0.8607904806397634,
0.8228468643082103,
0.8900020169140401,
0.9343567736626528,
0.9305049279291139,
0.9664193138643489,
0.9008537853184969,
0.7625840742620333,
0.815302054727336,
0.9215061720798934,
0.7192673801865671,
0.8949994067748966,
0.9367566547265034,
0.7602684166275758,
0.8184439767612992,
0.8361983856596491,
0.7761725471827079,
0.7724780968772909,
0.9249211346782868,
0.8718843131924867,
0.8522890335712519,
0.9015475867709434,
0.8720699810318118,
0.8937599387455695,
0.8721713573852221,
0.8100783166142076,
1.0000000000000002,
0.8213222537973748,
0.8361185401136565,
0.8371907459006128,
0.9065697385076582,
0.752240671472798,
0.8283078905766531,
0.8499886819287953,
0.9097932369637356,
0.9529813104528191,
0.8449289750214674,
1,
0.8302949362084788,
0.7741532046500113,
0.8743828037305432,
0.8201855611163867,
0.8194689758101458,
0.7925076796225758,
0.8748126117575765,
0.8299510305557958,
0.9619426561868236,
0.8627070029199212
],
"xaxis": "x",
"yaxis": "y"
},
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=-1<br>dataset=train<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "-1",
"marker": {
"color": "#EF553B",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "-1",
"offsetgroup": "-1",
"orientation": "v",
"showlegend": true,
"type": "histogram",
"x": [
0.7299945446332757,
0.761829365033925,
0.676235270353631,
0.7023016593603112,
0.7350156032869306,
0.7735236691667362,
0.7187241641280292,
0.8063818744486255,
0.7115273138749911,
0.7402039545462252,
0.7372226375565185,
0.719606336264161,
0.8344052997670786,
0.6881482918877712,
0.650513335277676,
0.7182967436895931,
0.7659970793010594,
0.6361476177437694,
0.7957002158983555,
0.7395402073375513,
0.7614384854511241,
0.6835300790002988,
0.6209194558826169,
0.7907726220786139,
0.7215750502680501,
0.7271947538454276,
0.6962333614625641,
0.7517476038206041,
0.7135529871009506,
0.7522919529124952,
0.7120639628754573,
0.7623014780353815,
0.7939574492876662,
0.6998873138766412,
0.7700594720774577,
0.7766161530342343,
0.7285080945118152,
0.7562511166739386,
0.8086622737220109,
0.7565297004511734,
0.7315242427462245,
0.791225232078265,
0.7281092467134649,
0.7675685917886341,
0.7122436997016454,
0.7600255476588682,
0.769465992253443,
0.7552507233110458,
0.7373719614604455,
0.7681449827665134,
0.7760194768221379,
0.8035677492769473,
0.7771752906080505,
0.6514739683312004,
0.744787832649546,
0.7700232252518491,
0.6901772464238046,
0.721128796446804,
0.686814078974691,
0.796786900996621,
0.8242176320872988,
0.7806742901384496,
0.7697696656361902,
0.7868668497422822,
0.7304548331410279,
0.7296767182508448,
0.699401040331297,
0.8332660282838579,
0.6513224421864793,
0.6927364371405096,
0.7491793002300279,
0.7909034218879171,
0.7754176150761527,
0.8004827006684735,
0.6659923526948868,
0.8129618884980553,
0.7496476113488948,
0.6519665061955423,
0.7319506042291191,
0.7367099004846358,
0.8590817275589286,
0.6684490623528697,
0.7469140136676841,
0.7269016830127544,
0.6834261236240542,
0.6390380393123325,
0.7129209978685723,
0.6879274953497815,
0.6720427402498903,
0.7343923261152341,
0.5977941468082364,
0.7574521074337901,
0.7302129263006347,
0.7380209108764002,
0.7280177437983031,
0.6870880300968067,
0.6928024686173956,
0.6403900073451696,
0.7209247906698092,
0.666799070142464,
0.7233569397428488,
0.7555267048881131,
0.7275931437975757,
0.7562785127332186,
0.736649021802634,
0.762569980781232,
0.7741923768467133,
0.6669773662198453,
0.7499779113599104,
0.7783371410262362,
0.7798574909618832,
0.7068277719647446,
0.7718066533969561,
0.7078874155127759,
0.7238814624322412,
0.8729026036201937,
0.7106759057210158,
0.7585767440008555,
0.6923565683937588,
0.6693898561996529,
0.7219944003835542,
0.7188064794656569,
0.7491451951131076,
0.6750197249394477,
0.7009868838756784,
0.6519589230722103,
0.775109860666869,
0.6682014813660948,
0.6618358724923147,
0.6218362070243146,
0.7518134154555007,
0.7571307427362168,
0.7546823360234822,
0.7416435744760188,
0.7676464608286054,
0.5799855321635428,
0.796191792361925,
0.6845373357552841,
0.7667984177770908,
0.7393609881148844,
0.7544580057165962,
0.7397495323890664,
0.7658298386950907,
0.6611125314078552,
0.7977728876184589,
0.831090991824215,
0.6982347705041605,
0.6312226759947304,
0.6482907496231076,
0.7658362595299695,
0.7518400125995974,
0.8025480920087656,
0.7461501494015976,
0.7239149641844659,
0.7090055516721123,
0.6846622430587006,
0.7112212118684095,
0.6794460508450644,
0.8344462829585355,
0.7638782230495382,
0.6558255308015893,
0.6799836520005059,
0.7148088830660722,
0.7658007439571484,
0.6581857665503666,
0.648734015773256,
0.8156725527309459,
0.6929202590284855,
0.7490919589112273,
0.7090723359022927,
0.7105572886194162,
0.7461374422467866,
0.7084742440808511,
0.6889378818898818,
0.7551565238112727,
0.7198789279593328,
0.7270482774940944,
0.6971190427893721,
0.7391610904601401,
0.6344734499604212,
0.6719507302378157,
0.6861159059531602,
0.6516118314317232,
0.7199095105818035,
0.6817881072485698,
0.7207373940156253,
0.7745467156569463,
0.6258059783246434,
0.7481056513643776,
0.7183327989261993,
0.7624705969064652,
0.703717229048042,
0.7487365873620485,
0.7495555007485976,
0.6409624588579408,
0.7176245775200687,
0.7537100520717849,
0.7569868075002099,
0.7392135510982174,
0.7188471960732984,
0.697413550280765,
0.7138614496887067,
0.7057641350396069,
0.7675079665142936,
0.7310427541091186,
0.7808018818735418,
0.7567255895826445,
0.7035962262766099,
0.7040750813384501,
0.8183159919038065,
0.7953911933398697,
0.7464891038038547,
0.6751591827598264,
0.7849943676377955,
0.7155963442284841,
0.7428993249606122,
0.7131100645054201,
0.7227595311429733,
0.6519548345531954,
0.7201522118536183,
0.86540799716664,
0.8128819241371503,
0.7278912446692862,
0.7305867175950502,
0.7171875192153516,
0.6755179003509543,
0.7256221359402913,
0.7003814947129137,
0.7486334697158199,
0.7232489166529666,
0.7347697330652992,
0.6493837702986368,
0.6454310256268904,
0.7085305859966166,
0.7709963397003181,
0.7628122486461532,
0.7260869667056613,
0.7656074896314918,
0.7309944223135776,
0.7575162043117213,
0.7425954755181217,
0.7978452334414571,
0.7414129597626153,
0.7369987033441427,
0.7249664966482501,
0.663939118162477,
0.7490232329485363,
0.7532303509685747,
0.6505502824396713,
0.6820403873171862,
0.7458589415356044,
0.65106761846338,
0.8190794585362886,
0.6404595320063431,
0.7620011212588133,
0.6793344580417779,
0.7470455239529016,
0.6254743025101126,
0.714021296346165,
0.7624376857959331,
0.6124088443110871,
0.7190909082546953,
0.7667977482791689,
0.7390282391537781,
0.731411776312689,
0.6910654011387792,
0.7669731150300952,
0.7473599833172766,
0.7757742306046117,
0.7524096654164605,
0.668078105815945,
0.6797298899209671,
0.734572374544983,
0.7851097143298676,
0.7342031538272321,
0.7372538892798539,
0.7335209046034747,
0.6838366965805461,
0.6892908218001774,
0.7368799079039347,
0.667817778038092,
0.7436021083467266,
0.643687287978847,
0.6459012534104801,
0.717961524900317,
0.651811280493696,
0.774663511332172,
0.6481450536649574,
0.7135017154296592,
0.718243043672562,
0.6840974559559992,
0.6237039667424505,
0.7511301324631847,
0.6731554748777204,
0.7311433676647258,
0.7407740108803432,
0.7219713524118947,
0.6945284597258313,
0.7403964086997485,
0.7281416758951161,
0.767616849591671,
0.7623013461443361,
0.794278871148869,
0.7094595086377886,
0.6363511838820268,
0.6955392554405656,
0.7448982185809608,
0.7328610107493276,
0.7208624134565648,
0.6762963031433926,
0.7755406771546928,
0.7045876515082797,
0.6244682388287953,
0.6742169930443296,
0.8182351707309709,
0.7329206562104693,
0.7750198586773644,
0.7686712995024062,
0.7382706738683358,
0.6670365140953741,
0.7122098843228497,
0.720363344417827,
0.7260325476617868,
0.7455849615803621,
0.7135971488970759,
0.7597698332957011,
0.7261113201174606,
0.7802411718313292,
0.6937507195359754,
0.7842882648198792,
0.6900501446041246,
0.760860218311663,
0.7134088358525988,
0.7053629634417926
],
"xaxis": "x2",
"yaxis": "y2"
},
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=-1<br>dataset=test<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "-1",
"marker": {
"color": "#EF553B",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "-1",
"offsetgroup": "-1",
"orientation": "v",
"showlegend": false,
"type": "histogram",
"x": [
0.698216609902851,
0.7175312484264973,
0.7186004945024,
0.8000305631283213,
0.6982596730429885,
0.7632305498724974,
0.7138465594512551,
0.6788567152998355,
0.7373203069430523,
0.6873036529619025,
0.6503465541274354,
0.7365009792588864,
0.741093678604772,
0.735107851927508,
0.7789882788330387,
0.6997359150722298,
0.7996799028712657,
0.7467428244049107,
0.6687862526227949,
0.743064589979991,
0.8567601871608912,
0.7518934598338161,
0.7026922402796603,
0.7211365080382613,
0.706354045904417,
0.779427003097291,
0.7362962393365906,
0.7291751132805836,
0.7122378769301289,
0.7140477948952889,
0.7173405960056081,
0.7634875929311733,
0.7977581390590794,
0.7301463738402032,
0.7615332057968679,
0.682227985030492,
0.7334635989895589,
0.7386028577038499,
0.659084291835728,
0.7899820114980755,
0.7247172517704278,
0.7438155210487466,
0.7005346860296618,
0.6648553956533266,
0.7701566644966253,
0.7514961574415904,
0.7587991656983686,
0.7001882273133521,
0.6910707516646375,
0.7394355361240693,
0.7276824899179835,
0.6759744362016779,
0.8302185470787163,
0.6928641094502374,
0.7120538723839331,
0.7224960785372221,
0.7198045816069757,
0.6813558762031737,
0.7165801628045559,
0.7120832723046919,
0.8420619371167495,
0.946387336435492,
0.7554566849916029,
0.6539169025880401,
0.7809846972343978,
0.7724403150471626,
0.7005276086814838,
0.6393757830582598,
0.8206678516495857,
0.7220887623700465,
0.6457268309147112,
0.7355556783165726,
0.7154485704709247,
0.736485552747626,
0.6279868336962336,
0.6826307018159569,
0.6893086023402188,
0.6662701662224271,
0.7867533724923507,
0.7767747518359883,
0.8265509786609969,
0.7191298707058883,
0.7022356632617615,
0.7327905404619434,
0.744068997548466,
0.7610196098150734,
0.7115456073990181,
0.7432332956176703,
0.7893785433034154,
0.7290401089931655,
0.6577253300110991,
0.7003577024217508,
0.6656632326716906,
0.777774068734494,
0.7332058736487923,
0.6832255453323396,
0.7959026283113317,
0.7447573588360193,
0.6641622569109379,
0.6536973982676278,
0.6665423422110455,
0.6839542281076658,
0.7423443203765382,
0.7548386221527327,
0.6027110439271006,
0.6910860362919979,
0.6562817652314046,
0.716178983425107,
0.724403115068019,
0.7259909803954812,
0.7571530348849562,
0.733442357026098,
0.7422263713152575,
0.7194490356714829,
0.8324274302968815,
0.7326854438693811,
0.6965534375434462,
0.6112368704742529,
0.7304384905441154,
0.7720040972080942,
0.6935032899149339,
0.6744196671055273,
0.7147585872166342,
0.7752945283086001,
0.7247252468203336,
0.7487128692844189,
0.7073695766484023,
0.7770002807438351,
0.6765039364068596,
0.6337555851438214,
0.7894145395726685,
0.808127657824718,
0.6806212157836168,
0.6631556694778886,
0.7026864982467919,
0.6616453779258454,
0.7884566439105933,
0.7066759939380831,
0.6337579382878964,
0.6487574469000645,
0.772145520185077,
0.7232462166936153,
0.7901145921506159,
0.6972515066616096,
0.7122390390469722,
0.7192738985202007,
0.6592857746430886,
0.7609468966795828,
0.6603956483288737,
0.7457664593808938,
0.6826889303879785,
0.7934081087742163,
0.6763383894378445,
0.6839731303810913,
0.6687692061586001,
0.6991702352595418,
0.7797875942740597,
0.7663428993137384,
0.7674609664886143,
0.7252797254797089,
0.6976652869077468,
0.7158816756428659,
0.718241456712601,
0.7632800029357677,
0.7220393239650442,
0.7734594213360495,
0.6831354341128338,
0.8050740585092945,
0.8304016729411579,
0.6942297180429666,
0.7777210539532041,
0.7971492766742209,
0.7551377672686813,
0.756111346085797,
0.7377139073001521,
0.7292914327831885,
0.6619852966467218,
0.6633387313486273,
0.7753943622795527,
0.6931366760284077,
0.6879938635549627,
0.7934277451864097,
0.7095961561996172,
0.6721647904665006,
0.6639348167332335,
0.7190874751718003,
0.6487682662731885,
0.7712237228183192,
0.7195541308325332,
0.7624245070018526,
0.7066568592895756,
0.6955819267956074,
0.762644689565747,
0.696550099953965,
0.72160309057877,
0.6589853583691466,
0.7781076723886485,
0.7844353288099747,
0.6499942545071061,
0.7586643115109818,
0.7851245713510661,
0.6825431110733673,
0.7920473550917971,
0.7505505200683399,
0.7112992195413225,
0.6872424297540068,
0.6629403755138552,
0.7754417757832819,
0.7445419843378304,
0.7064660255551196,
0.7102764764345906,
0.7166584523700176,
0.745706231193076,
0.7628022956052315,
0.6960882904963708,
0.6837468719098787,
0.7520487263229376,
0.7604129787986132,
0.7522054011542562,
0.6898973175025894,
0.712053903439596,
0.7339122116041967,
0.6789458999380491,
0.763449576758855,
0.7406926320689465,
0.6976029213628422,
0.7734670110437211,
0.7670042286239444,
0.7194794729796166,
0.8039337600845294,
0.822887238017454,
0.7546355748553022,
0.712175872919759,
0.6283408438534046,
0.7277722488820856,
0.7848289730316366,
0.6470175548703326,
0.7175970646556328,
0.6982508276209234,
0.6931585425755658,
0.7105706262503181,
0.6541269887120206,
0.8404400660965669,
0.7278163563567496,
0.8022018950050529,
0.7449136144274442,
0.7254549036563473,
0.7708530061010334,
0.7226568414320923,
0.7174427406313559,
0.7242867487973796,
0.7465548093208114,
0.6481248100208876,
0.715420475795481,
0.7818995378198694,
0.710140386310007,
0.779070581516581,
0.7022765286826387,
0.7920014977874291,
0.7028249663680173,
0.7464011128583546,
0.6874796697586977,
0.7834259382246206,
0.7487992683674399,
0.6256280566008573,
0.7248416704075088,
0.6787298488859722,
0.7604099689852636,
0.7563309422531382,
0.6536692254847523,
0.7277402265583088,
0.7961681595257355,
0.7183023940359037,
0.8578324194628058,
0.6890352710148969,
0.711616174722821,
0.6560825239577808,
0.7235723022023411,
0.6649236563739442,
0.6800589793852263,
0.7785177771732936,
0.8277144895457349,
0.7047472917613488,
0.6981993581133783,
0.69214194925211,
0.7175225477364914,
0.6821700037384272,
0.6934882153010823,
0.6671459724713757,
0.7577056235720239,
0.7452347481085622,
0.7540647847248615,
0.7623045528688165,
0.8419961516314336,
0.7607631404114222,
0.7104592749279437,
0.7907219223857475,
0.6626370861321504,
0.7293323972767943,
0.6747790790615374,
0.7205393564241827,
0.7182141925188444,
0.6698620455142402,
0.7774615433926413,
0.68441903533162,
0.7195176784475413,
0.7765542578440102,
0.7653003235671018,
0.6588957811563716,
0.7049538466814345,
0.6767019827253833,
0.6852115350974048,
0.7159808946533304,
0.6275008181698795,
0.6641598464064111,
0.7653064009797307,
0.7846062245731015,
0.7131195190890488,
0.7388407888274415,
0.7078575690975646,
0.7922969773693673,
0.6399205949452071,
0.7522331808600956,
0.756127025845852,
0.7527950868321375,
0.7791392496558245,
0.7388745760013306,
0.6739605493576779,
0.6432673241279167,
0.7124181751534668,
0.669456709871883,
0.7067049471522552,
0.6685698115209102,
0.7430777123010961,
0.7510627360284545
],
"xaxis": "x",
"yaxis": "y"
}
],
"layout": {
"annotations": [
{
"font": {},
"showarrow": false,
"text": "dataset=test",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.2425,
"yanchor": "middle",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "dataset=train",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.7575000000000001,
"yanchor": "middle",
"yref": "paper"
}
],
"barmode": "overlay",
"legend": {
"title": {
"text": "label"
},
"tracegroupgap": 0
},
"margin": {
"t": 60
},
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"width": 500,
"xaxis": {
"anchor": "y",
"domain": [
0,
0.98
],
"title": {
"text": "cosine_similarity"
}
},
"xaxis2": {
"anchor": "y2",
"domain": [
0,
0.98
],
"matches": "x",
"showticklabels": false
},
"yaxis": {
"anchor": "x",
"domain": [
0,
0.485
],
"title": {
"text": "count"
}
},
"yaxis2": {
"anchor": "x2",
"domain": [
0.515,
1
],
"matches": "y",
"title": {
"text": "count"
}
}
}
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 88.8% ± 2.4%\n"
]
},
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=1<br>dataset=train<br>cosine_similarity_custom=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "1",
"marker": {
"color": "#636efa",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "1",
"offsetgroup": "1",
"orientation": "v",
"showlegend": true,
"type": "histogram",
"x": [
0.7582729,
0.77073956,
0.65517455,
0.5354905,
0.7025479,
0.709722,
0.5727541,
0.412897,
0.30631626,
0.7510688,
0.4988944,
0.5430158,
0.6839247,
0.6838065,
0.73403287,
0.7992136,
0.94956607,
0.6076832,
0.6175979,
0.4301836,
0.41434872,
0.73805547,
0.8479354,
0.92995274,
0.6952341,
0.48424515,
0.37089303,
0.77779347,
0.46015453,
0.64896333,
0.5164182,
0.7807624,
0.744742,
0.60417265,
0.59455043,
0.69891477,
0.37484196,
0.42779225,
0.6438146,
0.5732102,
0.6928373,
0.8111064,
0.63694704,
0.8399712,
0.53589267,
0.72180825,
0.70790404,
0.6756934,
0.7727933,
0.7064891,
0.64490205,
0.8398692,
0.87921053,
0.27489004,
0.6942742,
0.57597476,
0.8717559,
0.9125966,
0.5301582,
0.60662705,
0.57393736,
0.6068608,
0.7286386,
0.70121765,
0.6917445,
0.8071895,
0.67622876,
0.79283214,
0.7993787,
0.72133046,
0.50220287,
0.57344776,
0.55474436,
0.6362008,
0.95821553,
0.49900997,
0.57039934,
0.733116,
0.78354275,
0.78056777,
0.05022569,
0.5561661,
0.77611494,
0.8746339,
0.66644144,
0.7035711,
0.6177985,
0.9615531,
0.54877394,
0.7237992,
0.80131793,
0.6017557,
0.5219256,
0.75939846,
0.5138731,
0.47525474,
0.8088046,
0.55364305,
0.8155018,
0.64197356,
0.78343874,
0.48175102,
0.7740524,
0.5643503,
0.5785323,
0.4045787,
0.844991,
0.78484446,
0.83895594,
0.779593,
0.8160813,
0.501844,
0.72787094,
0.74717385,
0.56744146,
0.23557578,
0.508822,
0.5945875,
0.30390128,
0.6806088,
0.73289305,
0.7714907,
0.35027343,
0.62142855,
0.89361686,
0.6107199,
0.80126244,
0.36980844,
0.8348204,
0.69843924,
0.40595374,
0.60189885,
0.67977715,
0.8707584,
0.94712615,
0.84084773,
0.38264382,
0.2997883,
0.82825613,
0.62465394,
0.8830344,
0.476409,
0.8145352,
0.5575372,
0.5749161,
0.8105687,
0.60813963,
0.8235246,
0.59305763,
0.35663584,
0.62649107,
0.3989369,
0.5393962,
0.85537386,
0.59931004,
0.7605111,
0.632939,
0.9163666,
0.4586151,
0.63326573,
0.82228065,
0.62719053,
0.44674626,
0.6449413,
0.7844299,
0.7109407,
0.42483827,
0.40851548,
0.6666175,
0.72658247,
0.29441452,
0.7422972,
0.6994506,
0.18581374,
0.678699,
0.6461491,
0.47299898,
0.3536248,
0.4810702,
0.74285,
0.31706694,
0.6093468,
0.49133518,
0.5890032,
0.4567156,
0.76003814,
0.60391974,
0.38637605,
0.5953039,
0.7995997,
0.8849446,
0.5667035,
0.65296185,
0.607227,
0.5879688,
0.71059513,
0.65723944,
0.607201,
0.2025656,
0.7741568,
0.8627294,
0.20287104,
0.7483574,
0.8585248,
0.47140092,
0.31818846,
0.57475,
0.79051167,
0.71378,
0.61821806,
0.63847613,
0.70790005,
0.70595086,
0.6324633,
0.8411397,
0.6010942,
0.6198959,
0.8601436,
0.506412,
0.8787982,
0.7363141,
0.5014173,
0.81275815,
0.43802464,
0.8606852,
0.90823215,
0.6048983,
0.5862797,
0.89965594,
0.21609125,
0.58538985,
0.74388975,
0.6436246,
0.68614113,
0.8133605,
0.8900723,
0.33205867,
0.5844268,
0.57690537,
0.55082875,
0.18703146,
0.7354434,
0.7188371,
0.73535186,
0.49179193,
0.60141474,
0.739191,
0.60902834,
0.5209424,
0.86558294,
0.54551184,
0.35854483,
0.74914163,
0.5216272,
0.9627959,
0.61666507,
0.36979172,
0.5368949,
0.43976095,
0.49865422,
0.8608738,
0.5655695,
0.81297505,
0.6252572,
0.6706519,
0.63214654,
0.50118244,
0.62449926,
0.70553565,
0.71667975,
0.6725781,
0.5215899,
0.64888376,
0.58514893,
0.8322057,
0.7555814,
0.62717474,
0.85005945,
0.5694707,
0.7400664,
0.94494027,
0.51169676,
0.34803283,
0.8389519,
0.5855034,
0.44634232,
0.78732824,
0.5224927,
0.8763063,
0.55516154,
0.77261907,
0.6934301,
0.39077526,
0.8408532,
0.8121041,
0.62674487,
0.51155764,
0.82464856,
0.17810936,
0.45419854,
0.90070426,
0.7507616,
0.6304127,
0.5382616,
0.2548589,
0.90965825,
0.5757083,
0.3881809,
0.47091305,
0.6259747,
0.549566,
0.78464717,
0.67623615,
0.22947542,
0.6341044,
0.47200313,
0.48406413,
0.54176223,
0.55456895,
0.70727295,
0.24309844,
0.5912585,
0.6418238,
0.7596782,
0.8406271,
0.38865283,
0.5345519,
0.67705786,
0.8471737,
0.78848296,
0.8296676,
0.6814048,
0.1439574,
0.76609397
],
"xaxis": "x2",
"yaxis": "y2"
},
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=1<br>dataset=test<br>cosine_similarity_custom=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "1",
"marker": {
"color": "#636efa",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "1",
"offsetgroup": "1",
"orientation": "v",
"showlegend": false,
"type": "histogram",
"x": [
0.8055613,
0.6490477,
0.53790224,
0.86933863,
0.7265006,
0.7500095,
0.5136602,
0.6108247,
0.6824276,
0.6863151,
0.81834275,
0.70145804,
0.6070372,
0.76361525,
0.47412366,
0.29016504,
0.7311523,
0.6745492,
0.5820018,
0.45364782,
0.7563268,
0.6551417,
0.8229127,
0.71966314,
0.65637314,
0.6056125,
0.40986726,
0.6351686,
0.8003523,
0.6638293,
0.5493354,
0.38898873,
0.89240587,
0.7085296,
0.6742074,
0.73229283,
0.6573821,
0.8563852,
0.73161244,
0.5125263,
0.73639345,
0.5608091,
0.6822197,
0.6509677,
0.77809495,
0.7077787,
0.42401087,
0.55358726,
0.96326643,
0.73163295,
0.7535971,
0.36794028,
0.25175893,
0.5587677,
0.9381429,
0.65131617,
0.707767,
0.404175,
0.738107,
0.5416304,
0.57372445,
0.76368237,
0.8565733,
0.7438249,
0.12734012,
0.88206637,
0.6993351,
0.6296972,
0.61613667,
0.7972349,
0.69560605,
0.6654787,
0.99999994,
0.48609376,
0.47301903,
0.55419683,
0.6296634,
0.4366933,
0.5387273,
0.7703018,
0.9061766,
0.936411,
0.67313033,
0.6194193,
0.37134537,
0.69975835,
0.47905326,
0.91620946,
0.7396347,
0.4008342,
0.69292057,
0.45092002,
0.48777297,
0.52802837,
0.821909,
0.57007617,
0.55647916,
0.5132801,
0.43960932,
0.67652094,
0.7167868,
0.7868749,
0.8414399,
0.30066487,
0.749063,
0.9248404,
0.29511026,
0.8829613,
0.9339626,
0.80454856,
0.66526824,
0.88261116,
0.8127759,
0.63907516,
0.5948788,
0.779537,
0.2734612,
0.91090304,
0.6181431,
0.8291171,
0.8786744,
0.45176527,
0.4650637,
0.52472,
0.6086645,
0.61167514,
0.618851,
0.59550864,
0.5683689,
0.54923105,
0.72792804,
0.5108976,
0.17235357,
0.8798169,
0.45582134,
0.6490313,
0.5987218,
0.8585789,
0.57275724,
0.6698848,
0.708311,
0.5580988,
0.5629255,
0.7218422,
0.61439306,
0.7661999,
0.5670317,
0.87961054,
0.57681245,
-0.025263075,
0.7305135,
0.72332287,
0.23190439,
0.49458456,
0.7727562,
0.63174117,
0.63469034,
0.24959645,
0.22246554,
0.6352285,
0.7173436,
0.8997552,
0.71509236,
0.68947405,
0.66640544,
0.2979133,
0.7765941,
0.61110735,
0.7824376,
0.45629472,
0.7814283,
0.94727206,
0.57099396,
0.7175106,
0.8542127,
0.7793517,
0.63087165,
0.3020511,
0.4713252,
0.8319655,
0.65548015,
0.43472487,
0.44796726,
0.73886794,
0.65364546,
0.5894948,
0.67677295,
0.34056446,
0.2893894,
0.67523193,
0.9004053,
0.7664299,
0.070429526,
0.7358473,
0.596747,
0.64145684,
0.70690864,
0.6161911,
0.63442343,
0.8034927,
0.65551484,
0.15030503,
0.6618132,
0.28821555,
0.6077155,
0.67070943,
0.23067664,
0.6557098,
0.55279106,
0.7644635,
0.28628072,
0.6755361,
0.5833505,
0.8240328,
0.67054516,
0.61436915,
0.51077014,
0.8207636,
0.61361855,
0.46712258,
0.40974268,
0.5336698,
0.558656,
0.97240293,
0.32733288,
0.6863151,
0.4985424,
0.79023767,
0.54902637,
0.34437552,
0.5104175,
0.1298324,
0.8800378,
0.8841439,
0.61748284,
0.9165488,
0.6865411,
0.8628597,
0.5251879,
0.8125164,
0.673494,
0.6672625,
0.73910636,
0.49335945,
0.27169147,
0.27906692,
0.676355,
0.72769314,
0.009032255,
0.86168426,
0.554593,
0.74324536,
0.78169584,
0.7702316,
0.95651525,
0.66262144,
0.6954001,
0.5656505,
0.8851712,
0.8271407,
0.7820752,
0.5980351,
0.37550944,
-0.056717057,
0.6385865,
0.9871788,
0.8663082,
0.9173605,
0.7812852,
0.8634348,
0.6151981,
0.4764769,
0.7129371,
0.78437144,
0.55460554,
0.765633,
0.3909395,
0.5777054,
0.46684957,
0.42080644,
0.73171866,
0.7438681,
0.46907872,
0.5971506,
0.35307956,
0.70854855,
0.14134333,
0.63074076,
0.6841741,
0.69037616,
0.8382874,
0.8383749,
0.8832028,
0.7200073,
0.3631567,
0.3802071,
0.7435165,
0.31047463,
0.7060255,
0.78007823,
0.42565107,
0.58386827,
0.4447159,
0.5319243,
0.2536062,
0.7522945,
0.6332803,
0.7163948,
0.71887827,
0.7136859,
0.70331776,
0.7141161,
0.51288867,
1,
0.65905374,
0.558384,
0.543979,
0.6952025,
0.36829284,
0.66087174,
0.7381764,
0.8295663,
0.83220756,
0.49746192,
1,
0.5527734,
0.3635958,
0.552087,
0.4612422,
0.55174565,
0.37825137,
0.6775494,
0.28022328,
0.8762227,
0.6699667
],
"xaxis": "x",
"yaxis": "y"
},
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=-1<br>dataset=train<br>cosine_similarity_custom=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "-1",
"marker": {
"color": "#EF553B",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "-1",
"offsetgroup": "-1",
"orientation": "v",
"showlegend": true,
"type": "histogram",
"x": [
0.14302981,
-0.11057069,
-0.17724767,
-0.24500036,
-0.062143553,
0.14316264,
0.15678164,
0.23773395,
-0.07846958,
-0.05204554,
-0.06655996,
-0.1291204,
0.443245,
-0.23869778,
-0.1645332,
-0.11218425,
0.23315865,
-0.38848343,
0.114264876,
-0.05154337,
0.18390255,
-0.27475196,
-0.05900202,
0.22430505,
-0.08546959,
-0.19892967,
-0.009420975,
-0.025093991,
-0.031790845,
-0.0051786974,
-0.076038264,
-0.024751754,
0.18652028,
0.011254758,
0.108140975,
0.09609464,
-0.07145849,
-0.06851679,
0.26567802,
-0.08995881,
-0.0607494,
0.10250999,
-0.010349991,
0.025486441,
-0.036893804,
0.020803962,
0.072955474,
0.12435959,
-0.052231353,
0.37205556,
0.17689727,
-0.0055993614,
0.19630852,
-0.26762488,
-0.0044234265,
0.021829708,
-0.09977905,
-0.103702135,
-0.08513733,
0.43749526,
0.36857948,
0.10853024,
0.13605508,
0.15389544,
-0.13927853,
-0.14963916,
-0.3644624,
0.52444035,
-0.40716577,
-0.084710844,
-0.07702993,
0.28454077,
0.17952761,
0.19037494,
-0.119846344,
0.3372389,
-0.25199848,
-0.14442691,
-0.21720155,
-0.09363778,
0.43611217,
-0.25409278,
0.059805665,
0.08246046,
-0.13380173,
-0.13812494,
-0.050808128,
-0.06082993,
-0.29077193,
-0.28214246,
-0.26683524,
0.15480486,
-0.19142698,
-0.07857721,
0.059667934,
-0.27533692,
0.0012983828,
-0.41322538,
-0.042272426,
-0.14912425,
-0.13140962,
0.33333275,
-0.07854002,
0.17386948,
0.12033723,
0.16654651,
-0.10265535,
-0.12654941,
-0.029241713,
0.14599761,
0.020325862,
-0.07572659,
-0.044500634,
-0.25713974,
0.12043542,
0.57479286,
0.00869727,
-0.054210614,
-0.1706504,
-0.32414228,
-0.2668199,
-0.26956692,
-0.05117539,
-0.06610332,
-0.10072504,
-0.17387511,
-0.08241673,
-0.15317225,
-0.1389603,
-0.23053482,
-0.03335168,
0.0077261436,
-0.108453214,
-0.16304685,
0.101268604,
-0.48224252,
0.35240975,
-0.12512998,
-0.078767695,
-0.02886478,
-0.11151735,
0.21486124,
0.06401499,
-0.07141201,
0.24201383,
0.38163814,
-0.18371437,
-0.33098587,
-0.17749745,
0.088400126,
0.0389295,
0.30001613,
0.14689615,
0.2980527,
-0.23508236,
-0.20070519,
-0.20603697,
0.09391333,
0.413654,
0.057752814,
-0.0879763,
-0.085120164,
-0.1986266,
0.112001404,
-0.24664646,
-0.2028898,
0.21875925,
-0.21640524,
0.040345095,
-0.22236083,
-0.028793441,
0.026597003,
-0.17050122,
-0.17038809,
-0.02601072,
-0.11057518,
0.04983874,
-0.09951373,
-0.05255925,
-0.21587878,
-0.042706307,
-0.13658924,
-0.27548197,
-0.080267474,
0.08046054,
-0.1778766,
0.05119922,
-0.29639244,
-0.13204463,
0.019229155,
-0.029514017,
-0.168669,
-0.1870547,
0.030565858,
-0.3171055,
-0.03315684,
-0.13100462,
-0.025322285,
-0.0279369,
-0.13280292,
-0.071367,
-0.17220059,
-0.11124851,
0.16763113,
0.06615649,
0.00736435,
-0.005432845,
0.0028309082,
-0.14925274,
0.29655105,
0.28778684,
0.1029157,
-0.22709346,
0.12467997,
-0.26696387,
-0.030379076,
-0.099905916,
0.030037213,
-0.19075271,
-0.1794763,
0.49299556,
0.3293071,
-0.10895353,
0.12664062,
-0.092162006,
-0.13777371,
-0.034772262,
-0.05871062,
0.026836451,
0.02839845,
0.13685997,
-0.045252584,
-0.14783676,
-0.1717493,
0.05255844,
0.3095329,
0.10439775,
0.28798553,
-0.12638304,
0.20911206,
0.012737125,
0.0957955,
-0.13159595,
0.063837364,
0.3295994,
-0.13373746,
0.29288846,
-0.029724248,
-0.0043318802,
-0.14764512,
-0.11397636,
-0.37147662,
0.34722033,
-0.18721312,
0.12718824,
-0.09654416,
-0.0766707,
-0.2977832,
-0.14676157,
0.26951638,
-0.082651325,
-0.27545825,
0.13462485,
-0.103285775,
0.06710768,
0.0015975507,
-0.14853989,
0.034589108,
0.04793806,
-0.0091663385,
-0.013575685,
-0.042185366,
-0.084682934,
0.058529433,
-0.088299036,
-0.04100522,
-0.1625054,
-0.27100295,
0.11074404,
0.092317745,
0.018090539,
-0.088989764,
-0.22713757,
-0.2248414,
-0.23065864,
-0.12176347,
0.4832462,
-0.22033268,
-0.065273635,
0.10010773,
-0.029044893,
-0.39913887,
0.122752,
-0.3432,
-0.15346213,
0.038391877,
-0.106522396,
-0.24408461,
-0.2250687,
0.13498728,
0.089975856,
-0.0010886639,
0.14626358,
-0.25450054,
-0.16206416,
-0.35668114,
-0.027360622,
-0.042533685,
-0.028601661,
-0.2779112,
-0.028973868,
-0.26466352,
-0.24255127,
-0.22748213,
0.28742003,
0.4260002,
0.12781607,
0.014489925,
0.09618315,
-0.13999633,
-0.056763954,
0.010561457,
0.021302488,
-0.11718368,
-0.13106598,
-0.06693572,
-0.22939485,
0.10453988,
-0.03318918,
0.29774678,
-0.30688155,
-0.006827436,
-0.023383835,
-0.24362536
],
"xaxis": "x2",
"yaxis": "y2"
},
{
"alignmentgroup": "True",
"bingroup": "x",
"hovertemplate": "label=-1<br>dataset=test<br>cosine_similarity_custom=%{x}<br>count=%{y}<extra></extra>",
"legendgroup": "-1",
"marker": {
"color": "#EF553B",
"opacity": 0.5,
"pattern": {
"shape": ""
}
},
"name": "-1",
"offsetgroup": "-1",
"orientation": "v",
"showlegend": false,
"type": "histogram",
"x": [
0.10617652,
-0.031975385,
-0.10883323,
0.32207173,
-0.31140262,
0.11705714,
-0.15120678,
-0.13226521,
0.32406852,
0.028926875,
-0.12701692,
0.14731352,
0.15226293,
-0.14873207,
0.044569373,
-0.195492,
0.31632814,
0.049260046,
-0.1941448,
0.12955306,
0.39960444,
0.10171481,
0.01060061,
-0.08512529,
0.20605738,
0.10371626,
0.031458076,
-0.01223364,
-0.22145674,
-0.09562346,
-0.038915318,
-0.036689993,
0.12696612,
-0.062234133,
0.23239751,
-0.16191426,
-0.112516925,
0.26545197,
-0.06769511,
0.017749624,
0.175943,
0.17765644,
-0.11271724,
-0.24490058,
0.23518375,
0.004909072,
-0.05079494,
-0.18667297,
0.0043339524,
-0.16339056,
-0.0025770266,
-0.032812953,
0.34513906,
-0.1910095,
0.07079941,
-0.07229687,
-0.09512156,
-0.16745691,
-0.1844601,
-0.101349175,
0.41526136,
0.81717646,
-0.0026841194,
-0.09593156,
0.0574869,
-0.024282647,
-0.013999336,
-0.40602204,
0.569725,
-0.053840384,
-0.20432736,
-0.008172799,
-0.00547958,
0.180258,
-0.12665462,
-0.21722375,
-0.19613074,
-0.023697829,
0.24777885,
0.1711332,
0.537012,
-0.087072074,
-0.07477147,
0.05843446,
0.025428122,
0.09440827,
-0.007145553,
-0.10771193,
0.25761753,
0.044429805,
-0.03361663,
-0.029891007,
-0.06391836,
0.1614778,
0.0022980184,
0.009103167,
0.03071209,
-0.1727816,
-0.20178483,
-0.13833278,
-0.22644272,
-0.23083316,
0.03205073,
0.33865592,
-0.11224474,
-0.18638353,
-0.15613213,
0.13247007,
0.091726765,
-0.08427659,
0.04926182,
-0.06531185,
0.08664228,
0.061082914,
0.3546836,
0.13567112,
-0.12794621,
-0.25035465,
-0.101458125,
-0.07636069,
-0.00021772343,
-0.06313226,
-0.14774837,
0.015537703,
-0.05826063,
0.19063057,
-0.18808654,
0.09029963,
-0.2199365,
-0.3122693,
0.23579088,
0.31103528,
-0.041199364,
-0.00906378,
0.21824978,
-0.08886902,
0.18546428,
0.20851673,
-0.15233244,
-0.3224695,
-0.0325812,
-0.04210101,
0.16766939,
0.26867718,
0.111526266,
-0.036602326,
-0.14308666,
0.15467033,
-0.24415994,
0.24594295,
-0.24345881,
0.32660398,
0.03280199,
-0.09476015,
-0.29096526,
0.032068066,
0.17987496,
0.03262933,
0.20223673,
-0.120644696,
0.29335102,
-0.14304209,
-0.021362929,
0.06404272,
0.05794023,
0.0831995,
0.0718886,
0.2677285,
0.35200343,
-0.035561264,
0.18149698,
0.23688413,
-0.1472639,
0.057499476,
0.08719631,
-0.09575894,
-0.08939485,
-0.14045444,
0.026247796,
-0.11188459,
-0.07919278,
0.31483328,
0.039055776,
-0.12993088,
-0.1622256,
0.31707028,
-0.3364733,
0.24035016,
-0.027344713,
0.014690834,
-0.16855082,
-0.06932548,
-0.022266265,
-0.22546014,
0.008567488,
-0.102743,
0.13508673,
0.15486145,
-0.2813347,
0.1068097,
0.13032803,
-0.019828802,
0.11795447,
0.4221859,
0.000759042,
0.09165115,
-0.03338406,
0.08470532,
-0.1508812,
0.035786882,
0.030124856,
0.19506164,
0.26077172,
-0.124081574,
-0.31896383,
0.10112047,
-0.057609923,
-0.00080296077,
0.23779202,
-0.23625362,
-0.08160995,
0.3032903,
-0.08120042,
0.16656089,
0.10409328,
0.18034437,
0.046120126,
0.1544962,
-0.09288298,
0.2510819,
0.29124844,
0.17870253,
-0.005379772,
-0.30627653,
-0.06947586,
0.15925092,
-0.13554731,
0.043542545,
0.17783535,
-0.20150776,
-0.11547584,
-0.16441183,
0.4270623,
0.22010702,
0.25666305,
-0.04944679,
0.057062034,
0.0508959,
-0.16835836,
-0.19262494,
-0.1199862,
0.15262641,
-0.046507467,
0.034018174,
0.013183228,
0.12422307,
-0.047742013,
0.008541771,
0.28857812,
-0.062908895,
0.08314598,
-0.35084343,
0.061281245,
0.0785819,
-0.08573331,
-0.13729927,
-0.025910659,
0.11283098,
0.05220833,
-0.012076858,
-0.022399202,
0.27282238,
0.20037143,
0.49455974,
-0.25124055,
-0.15144002,
-0.17023633,
0.17223814,
-0.2117701,
-0.07220247,
0.031898964,
0.3294976,
-0.11548876,
0.11725828,
-0.20589651,
-0.20220196,
-0.120919146,
-0.17677295,
-0.07525886,
0.109793626,
-0.034417115,
0.16536875,
-0.023174297,
0.5393097,
-0.0039643934,
-0.07138218,
0.28882954,
-0.13014461,
-0.23138109,
-0.1327854,
0.20499766,
-0.055300497,
-0.0646167,
0.009579126,
0.041330144,
0.0024610406,
0.12112454,
0.07543635,
0.1213533,
0.08904015,
-0.08634728,
-0.06367639,
-0.06046631,
-0.19301783,
-0.14426412,
0.07088253,
0.38884938,
-0.06855903,
0.1762231,
-0.10999228,
0.14951923,
-0.29029095,
0.12920266,
0.035691362,
-0.14431979,
0.114437565,
-0.04307397,
-0.1095031,
-0.24052213,
-0.056761023,
-0.2482388,
-0.022682087,
0.02143445,
0.1074058,
0.3866141
],
"xaxis": "x",
"yaxis": "y"
}
],
"layout": {
"annotations": [
{
"font": {},
"showarrow": false,
"text": "dataset=test",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.2425,
"yanchor": "middle",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "dataset=train",
"textangle": 90,
"x": 0.98,
"xanchor": "left",
"xref": "paper",
"y": 0.7575000000000001,
"yanchor": "middle",
"yref": "paper"
}
],
"barmode": "overlay",
"legend": {
"title": {
"text": "label"
},
"tracegroupgap": 0
},
"margin": {
"t": 60
},
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"width": 500,
"xaxis": {
"anchor": "y",
"domain": [
0,
0.98
],
"title": {
"text": "cosine_similarity_custom"
}
},
"xaxis2": {
"anchor": "y2",
"domain": [
0,
0.98
],
"matches": "x",
"showticklabels": false
},
"yaxis": {
"anchor": "x",
"domain": [
0,
0.485
],
"title": {
"text": "count"
}
},
"yaxis2": {
"anchor": "x2",
"domain": [
0.515,
1
],
"matches": "y",
"title": {
"text": "count"
}
}
}
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy after customization: 93.6% ± 1.9%\n"
]
}
],
"source": [
"# plot similarity distribution BEFORE customization\n",
"px.histogram(\n",
" df,\n",
" x=\"cosine_similarity\",\n",
" color=\"label\",\n",
" barmode=\"overlay\",\n",
" width=500,\n",
" facet_row=\"dataset\",\n",
").show()\n",
"\n",
"test_df = df[df[\"dataset\"] == \"test\"]\n",
"a, se = accuracy_and_se(test_df[\"cosine_similarity\"], test_df[\"label\"])\n",
"print(f\"Test accuracy: {a:0.1%} ± {1.96 * se:0.1%}\")\n",
"\n",
"# plot similarity distribution AFTER customization\n",
"px.histogram(\n",
" df,\n",
" x=\"cosine_similarity_custom\",\n",
" color=\"label\",\n",
" barmode=\"overlay\",\n",
" width=500,\n",
" facet_row=\"dataset\",\n",
").show()\n",
"\n",
"a, se = accuracy_and_se(test_df[\"cosine_similarity_custom\"], test_df[\"label\"])\n",
"print(f\"Test accuracy after customization: {a:0.1%} ± {1.96 * se:0.1%}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "XO7iqiVjpgkT",
"outputId": "a100a9e0-d5aa-46ab-b8a7-4ec6f7bd1cec"
},
"outputs": [
{
"data": {
"text/plain": [
"array([[-1.2566795e+00, -1.5297449e+00, -1.3271648e-01, ...,\n",
" -1.2859761e+00, -5.3254390e-01, 4.8364732e-01],\n",
" [-1.4826347e+00, 9.2656955e-02, -4.2437232e-01, ...,\n",
" 1.1872858e+00, -1.0831847e+00, -1.0683593e+00],\n",
" [-2.2029283e+00, -1.9703420e+00, 3.1125939e-01, ...,\n",
" 2.2947595e+00, 5.5780332e-03, -6.0171342e-01],\n",
" ...,\n",
" [-1.1019799e-01, 1.3599515e+00, -4.7677776e-01, ...,\n",
" 6.5626711e-01, 7.2359240e-01, 3.0733588e+00],\n",
" [ 1.6624762e-03, 4.2648423e-01, -1.1380885e+00, ...,\n",
" 8.7202555e-01, 9.3173909e-01, -1.6760436e+00],\n",
" [ 7.7449006e-01, 4.9213606e-01, 3.5407653e-01, ...,\n",
" 1.3460466e+00, -1.9509128e-01, 7.7514690e-01]], dtype=float32)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_matrix  # the learned matrix: multiply your embeddings by it to get customized embeddings\n"
]
}
],
"metadata": {
"colab": {
"name": "customized_embeddings_example_with_synthetic_negatives.ipynb",
"provenance": []
},
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
},
"kernelspec": {
"display_name": "Python 3.9.9 ('openai')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 0
}