mirror of
https://github.com/mlabonne/llm-course.git
synced 2024-10-30 15:21:42 +00:00
228 lines
39 KiB
Plaintext
228 lines
39 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "view-in-github",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"<a href=\"https://colab.research.google.com/github/mlabonne/how-to-data-science/blob/main/Visualizing_GPT_2's_Loss_Landscape.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# Visualizing GPT-2's Loss Landscape\n",
|
|
"\n",
|
|
"❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne).\n",
|
|
"\n",
|
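|
"Simple perturbation-based calculation of the negative log-likelihood loss along two random directions in parameter space, given \"I have a dream\" as input: the pre-trained weights are shifted by a combination of the two norm-scaled directions, and the loss is evaluated at each point of a grid.\n",
|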
|
"\n",
|
|
"Reference: [Visualizing the Loss Landscape of Neural Nets](https://arxiv.org/abs/1712.09913), by Li et al. (2018)"
|
|
],
|
|
"metadata": {
|
|
"id": "Dgptqrg0zEY5"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"id": "lIYdn1woOS1n",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "ffd96eb4-2861-4658-e89d-383eaeb36c3b"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.2/7.2 MB\u001b[0m \u001b[31m47.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m14.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m75.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
|
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m67.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
|
"\u001b[?25h"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"%pip install -q transformers"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
"%%time\n",
"\n",
"import torch\n",
"from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from tqdm import tqdm\n",
"import imageio\n",
"import os\n",
"\n",
"# Seed used for the random directions, so the landscape is reproducible\n",
"SEED = 42\n",
"\n",
"# Load pre-trained model\n",
"model_name = 'gpt2'\n",
"model = GPT2LMHeadModel.from_pretrained(model_name)\n",
"tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n",
"\n",
"# Set model to evaluation mode (disables dropout, so every loss is deterministic)\n",
"model.eval()\n",
"\n",
"# Define our input\n",
"input_text = \"I have a dream\"\n",
"inputs = tokenizer(input_text, return_tensors=\"pt\")\n",
"\n",
"# Compute the original loss (we never backpropagate, so no autograd graph is needed)\n",
"with torch.no_grad():\n",
"    outputs = model(**inputs, labels=inputs[\"input_ids\"])\n",
"original_loss = outputs.loss.item()\n",
"\n",
"# Define two random directions in parameter space\n",
"torch.manual_seed(SEED)\n",
"direction1 = [torch.randn_like(p) for p in model.parameters()]\n",
"direction2 = [torch.randn_like(p) for p in model.parameters()]\n",
"\n",
"# Rescale each direction tensor to the norm of its parameter tensor\n",
"# (a per-tensor variant of the filter normalization of Li et al., 2018).\n",
"# no_grad keeps the directions out of the autograd graph.\n",
"with torch.no_grad():\n",
"    for p, d1, d2 in zip(model.parameters(), direction1, direction2):\n",
"        norm_p = torch.linalg.norm(p.flatten())\n",
"        d1.div_(torch.linalg.norm(d1.flatten())).mul_(norm_p)\n",
"        d2.div_(torch.linalg.norm(d2.flatten())).mul_(norm_p)\n",
"\n",
"# Snapshot the unperturbed weights so they can be set and restored exactly\n",
"# (repeated add-then-subtract in float32 accumulates rounding drift)\n",
"original_params = [p.detach().clone() for p in model.parameters()]\n",
"\n",
"# Define the range to explore\n",
"x = np.linspace(-1, 1, 20)\n",
"y = np.linspace(-1, 1, 20)\n",
"X, Y = np.meshgrid(x, y)\n",
"\n",
"# Prepare to collect the losses. meshgrid's default 'xy' layout gives\n",
"# X[j, i] == x[i] and Y[j, i] == y[j], so the loss at (x[i], y[j]) goes in Z[j, i]\n",
"Z = np.zeros_like(X)\n",
"\n",
"# Compute the loss at every grid point\n",
"try:\n",
"    with torch.no_grad():\n",
"        for i in tqdm(range(x.size), desc=\"x progress\"):\n",
"            for j in tqdm(range(y.size), desc=\"y progress\", leave=False):\n",
"                # Set the parameters to theta + x[i] * d1 + y[j] * d2\n",
"                for p, p0, d1, d2 in zip(model.parameters(), original_params, direction1, direction2):\n",
"                    p.copy_(p0).add_(d1, alpha=float(x[i])).add_(d2, alpha=float(y[j]))\n",
"\n",
"                # Compute the loss\n",
"                outputs = model(**inputs, labels=inputs['input_ids'])\n",
"                Z[j, i] = outputs.loss.item()\n",
"finally:\n",
"    # Restore the unperturbed weights, even if the loop is interrupted\n",
"    with torch.no_grad():\n",
"        for p, p0 in zip(model.parameters(), original_params):\n",
"            p.copy_(p0)\n"
],
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "4Wud37sZa1Y7",
|
|
"outputId": "c4e9c839-2938-4df8-c54f-96f5792cc0ed"
|
|
},
|
|
"execution_count": 11,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stderr",
|
|
"text": [
|
|
"100%|██████████| 20/20 [12:42<00:00, 38.11s/it]\n"
|
|
]
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
"# Create 3D plot of the loss surface over the two direction coefficients\n",
"fig = go.Figure(data=[go.Surface(z=Z, x=X, y=Y, showscale=False)])\n",
"fig.update_layout(\n",
"    title=\"GPT-2's Loss Landscape\",\n",
"    autosize=True,\n",
"    width=1000,\n",
"    height=600,\n",
"    scene=dict(\n",
"        xaxis=dict(title=\"Direction 1 coefficient\"),\n",
"        yaxis=dict(title=\"Direction 2 coefficient\"),\n",
"        zaxis=dict(title=\"Loss (NLL)\"),\n",
"    ),\n",
")\n",
"fig.show()"
],
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 617
|
|
},
|
|
"id": "5zcGwU4ji67L",
|
|
"outputId": "f149f7eb-abfb-4b06-aead-ace3772c5379"
|
|
},
|
|
"execution_count": 32,
|
|
"outputs": [
|
|
{
|
|
"output_type": "display_data",
|
|
"data": {
|
|
"text/html": [
|
|
"<html>\n",
|
|
"<head><meta charset=\"utf-8\" /></head>\n",
|
|
"<body>\n",
|
|
" <div> <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script> <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
|
|
" <script src=\"https://cdn.plot.ly/plotly-2.18.2.min.js\"></script> <div id=\"f166e894-966e-4e31-85e8-e626977b9a7f\" class=\"plotly-graph-div\" style=\"height:600px; width:1000px;\"></div> <script type=\"text/javascript\"> window.PLOTLYENV=window.PLOTLYENV || {}; if (document.getElementById(\"f166e894-966e-4e31-85e8-e626977b9a7f\")) { Plotly.newPlot( \"f166e894-966e-4e31-85e8-e626977b9a7f\", [{\"showscale\":false,\"x\":[[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.36
8421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.473684210526
3157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527
,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0],[-1.0,-0.8947368421052632,-0.7894736842105263,-0.6842105263157895,-0.5789473684210527,-0.4736842105263158,-0.368421052631579,-0.26315789473684215,-0.1578947368421053,-0.052631578947368474,0.05263157894736836,0.1578947368421053,0.26315789473684204,0.36842105263157876,0.4736842105263157,0.5789473684210527,0.6842105263157894,0.7894736842105261,0.894736842105263,1.0]],\"y\":[[-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0],[-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8
947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632,-0.8947368421052632],[-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263,-0.7894736842105263],[-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895,-0.6842105263157895],[-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527,-0.5789473684210527],[-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158,-0.4736842105263158],[-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-
0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579,-0.368421052631579],[-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215,-0.26315789473684215],[-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053,-0.1578947368421053],[-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474,-0.052631578947368474],[0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836,0.05263157894736836],[0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.157894
7368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053,0.1578947368421053],[0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204,0.26315789473684204],[0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876,0.36842105263157876],[0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157,0.4736842105263157],[0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527,0.5789473684210527],[0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0
.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894,0.6842105263157894],[0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261,0.7894736842105261],[0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263,0.894736842105263],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]],\"z\":[[124.08859252929688,113.14456939697266,103.28240203857422,94.4238510131836,87.30879974365234,82.4040298461914,78.12612915039062,75.68881225585938,74.09918975830078,72.6692886352539,71.25695037841797,66.36454010009766,63.68756103515625,63.36842727661133,62.8138542175293,62.80119705200195,62.57830810546875,63.3427619934082,67.92422485351562,74.73457336425781],[122.952392578125,111.9377670288086,100.87358856201172,90.76703643798828,83.27447509765625,76.47315979003906,71.21940612792969,70.4721908569336,69.53540802001953,68.83709716796875,66.06153106689453,59.85396194458008,58.418731689453125,57.37763595581055,55.536678314208984,53.610076904296875,52.831512451171875,54.60457992553711,58.69700241088867,64.93589782714844],[122.18196868896484,109.2790756225586,97.35919952392578,86.6955337524414,77.68022918701172,69.21306610107422,63.30632400512695,62.368804931640625,62.70637512207031,63.374267578125,58.80855178833
008,53.64743423461914,52.24203872680664,50.33491897583008,46.6462516784668,43.33146286010742,44.38566207885742,46.76642990112305,52.05381393432617,59.38408279418945],[120.36943817138672,106.97021484375,93.5135726928711,81.31269836425781,71.42364501953125,61.79465866088867,54.28132247924805,52.88187026977539,53.26190185546875,54.09297180175781,49.44447326660156,46.71321105957031,44.1356086730957,41.73820877075195,35.417510986328125,34.49213790893555,36.3136100769043,40.87216567993164,47.35464096069336,54.595123291015625],[116.46591186523438,103.7930679321289,89.98480987548828,76.76641082763672,66.09087371826172,55.41984176635742,44.969635009765625,41.581878662109375,42.358768463134766,45.69941329956055,40.58510208129883,39.72469711303711,35.148624420166016,32.051029205322266,25.919771194458008,27.014490127563477,31.265336990356445,37.1475715637207,43.88414001464844,51.1301383972168],[112.07797241210938,100.25452423095703,87.85061645507812,75.65353393554688,64.2408447265625,53.68330764770508,41.01625061035156,31.303848266601562,27.711809158325195,37.01311492919922,31.192136764526367,31.47548484802246,26.66499900817871,22.72065544128418,19.525075912475586,23.21742820739746,28.599946975708008,34.868526458740234,41.70822525024414,49.02685546875],[109.92477416992188,99.4665298461914,88.09374237060547,76.76531219482422,65.88333129882812,55.7800407409668,45.8292121887207,29.421175003051758,19.93309783935547,22.15559196472168,26.65926170349121,23.809167861938477,17.965394973754883,15.095309257507324,16.851932525634766,21.2945499420166,26.97292137145996,33.305755615234375,40.396976470947266,48.10235595703125],[111.68887329101562,99.18631744384766,87.61377716064453,76.50119018554688,65.69596099853516,56.5683479309082,49.60874938964844,37.333465576171875,19.218774795532227,14.587769508361816,15.539047241210938,15.001593589782715,11.779308319091797,12.917278289794922,15.809394836425781,20.308408737182617,26.035852432250977,32.643646240234375,40.069210052490234,48.31789016723633]
,[117.71771240234375,101.69365692138672,87.3941650390625,75.95032501220703,65.956787109375,57.6050910949707,49.75065612792969,40.915069580078125,17.368896484375,14.981727600097656,14.482171058654785,16.103912353515625,9.687844276428223,12.200180053710938,15.276966094970703,19.57913589477539,25.395662307739258,32.31587600708008,40.158382415771484,49.03691482543945],[118.1960220336914,103.9806137084961,89.5818099975586,76.6880111694336,66.22270965576172,58.12995910644531,48.7950325012207,37.264225006103516,14.356136322021484,8.659367561340332,7.356786727905273,13.174872398376465,12.816079139709473,11.342301368713379,14.775760650634766,19.029531478881836,24.5596866607666,31.630361557006836,39.78668212890625,49.1077880859375],[119.9597396850586,104.84210205078125,92.23938751220703,80.29007720947266,68.23020935058594,56.669403076171875,44.59339904785156,31.761032104492188,11.549056053161621,7.461862564086914,4.824077606201172,10.210274696350098,13.88726806640625,10.440917015075684,13.822260856628418,18.20302391052246,23.925195693969727,31.107412338256836,39.413204193115234,48.84455490112305],[125.70914459228516,109.4720458984375,97.4897689819336,88.52359771728516,77.2088851928711,63.65207290649414,49.28334045410156,35.76978302001953,18.345705032348633,12.721267700195312,10.522377967834473,6.378732204437256,16.288347244262695,9.830000877380371,13.098152160644531,17.514440536499023,23.164459228515625,30.32402229309082,38.69839859008789,48.12251663208008],[132.0594024658203,115.45111083984375,102.75662994384766,94.99996185302734,84.9287338256836,71.2254409790039,56.8242073059082,42.57889938354492,26.049936294555664,14.473036766052246,9.229398727416992,9.731602668762207,18.076862335205078,9.981461524963379,12.647774696350098,17.02879524230957,22.126724243164062,28.497629165649414,36.47734069824219,46.04225540161133],[137.8777313232422,122.7239761352539,109.92413330078125,98.54663848876953,88.29483795166016,74.89309692382812,62.7603759765625,49.20086669921875,35.3382453918457
,22.581893920898438,23.33271026611328,25.858871459960938,19.89449119567871,12.72039794921875,14.089183807373047,18.10186195373535,22.778898239135742,28.125471115112305,34.3931770324707,42.67856979370117],[143.72544860839844,129.7960968017578,118.26416015625,106.40345001220703,96.23077392578125,81.42809295654297,68.5083999633789,56.28574752807617,46.05338668823242,39.848514556884766,37.25785446166992,33.0616569519043,25.557775497436523,17.1937313079834,16.964567184448242,21.668039321899414,26.427968978881836,31.44476890563965,37.063846588134766,43.386592864990234],[150.3450469970703,139.55322265625,128.7618408203125,116.16324615478516,105.4486312866211,90.4364013671875,76.72982025146484,65.39386749267578,57.195556640625,51.66306686401367,46.667694091796875,40.457435607910156,32.94401931762695,23.583776473999023,19.742597579956055,24.681692123413086,30.13435935974121,35.83777618408203,41.87080383300781,47.91225051879883],[160.451416015625,150.7127685546875,139.3443603515625,126.38531494140625,113.51985931396484,100.2462387084961,85.9513168334961,74.37586212158203,67.3287353515625,62.603179931640625,55.78438949584961,48.97542953491211,41.7438850402832,32.506317138671875,25.081880569458008,26.7265682220459,33.343048095703125,39.7734260559082,46.153011322021484,52.596923828125],[174.1394500732422,162.62489318847656,149.16004943847656,134.97364807128906,120.92447662353516,108.06343841552734,94.86676788330078,83.82272338867188,77.7781753540039,71.95187377929688,64.88156127929688,57.77402877807617,50.665042877197266,42.31514358520508,33.75063705444336,30.504684448242188,35.91915512084961,42.74984359741211,49.84637451171875,57.03544998168945],[184.509765625,171.5378875732422,157.13401794433594,142.78167724609375,129.5974884033203,117.06934356689453,104.20243072509766,94.47362518310547,88.49263763427734,80.9236831665039,73.77084350585938,67.1082992553711,59.94290542602539,52.192867279052734,43.48586654663086,38.224056243896484,38.90983200073242,45.6534538269043,53.63026428222
656,62.99666213989258],[190.7437744140625,177.5851287841797,163.74839782714844,149.82400512695312,137.7917022705078,125.81805419921875,113.71285247802734,105.0035400390625,98.73310089111328,90.45975494384766,82.31554412841797,76.87224578857422,69.7837142944336,61.58315658569336,52.8474006652832,47.5811653137207,46.31401062011719,50.27994155883789,57.28769302368164,66.44225311279297]],\"type\":\"surface\"}], {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.444444444
4444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size
\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":
\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaul
ts\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"title\":{\"text\":\"GPT-2's Loss Landscape\"},\"autosize\":true,\"width\":1000,\"height\":600}, {\"responsive\": true} ).then(function(){\n",
|
|
" \n",
|
|
"var gd = document.getElementById('f166e894-966e-4e31-85e8-e626977b9a7f');\n",
|
|
"var x = new MutationObserver(function (mutations, observer) {{\n",
|
|
" var display = window.getComputedStyle(gd).display;\n",
|
|
" if (!display || display === 'none') {{\n",
|
|
" console.log([gd, 'removed!']);\n",
|
|
" Plotly.purge(gd);\n",
|
|
" observer.disconnect();\n",
|
|
" }}\n",
|
|
"}});\n",
|
|
"\n",
|
|
"// Listen for the removal of the full notebook cells\n",
|
|
"var notebookContainer = gd.closest('#notebook-container');\n",
|
|
"if (notebookContainer) {{\n",
|
|
" x.observe(notebookContainer, {childList: true});\n",
|
|
"}}\n",
|
|
"\n",
|
|
"// Listen for the clearing of the current output cell\n",
|
|
"var outputEl = gd.closest('.output');\n",
|
|
"if (outputEl) {{\n",
|
|
" x.observe(outputEl, {childList: true});\n",
|
|
"}}\n",
|
|
"\n",
|
|
" }) }; </script> </div>\n",
|
|
"</body>\n",
|
|
"</html>"
|
|
]
|
|
},
|
|
"metadata": {}
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [],
|
|
"metadata": {
|
|
"id": "5Gvy1Bkut13M"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": [],
|
|
"gpuType": "T4",
|
|
"include_colab_link": true
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
} |