You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
llm-course/Visualizing_GPT_2's_Loss_La...

228 lines
39 KiB
Plaintext

1 year ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/mlabonne/how-to-data-science/blob/main/Visualizing_GPT_2's_Loss_Landscape.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Visualizing GPT-2's Loss Landscape\n",
"\n",
"❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne).\n",
"\n",
"A simple perturbation-based visualization: GPT-2's weights are shifted along two random directions in parameter space, and the negative log-likelihood loss on the input \"I have a dream\" is computed at each point of the resulting 2D grid.\n",
"\n",
"Reference: [Visualizing the Loss Landscape of Neural Nets](https://arxiv.org/abs/1712.09913), by Li et al. (2018)"
],
"metadata": {
"id": "Dgptqrg0zEY5"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lIYdn1woOS1n",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ffd96eb4-2861-4658-e89d-383eaeb36c3b"
},
"outputs": [],
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"# TODO: pin the transformers version (transformers==X.Y.Z) for reproducibility.\n",
"%pip install -q transformers"
]
},
{
"cell_type": "code",
"source": [
"%%time\n",
"\n",
"import torch\n",
"from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from tqdm import tqdm\n",
"import imageio\n",
"import os\n",
"\n",
"# Fix the seed so the two random directions (and therefore the plot) are reproducible\n",
"torch.manual_seed(0)\n",
"\n",
"# Run on the GPU when one is available (this Colab notebook requests a T4)\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# Load pre-trained model\n",
"model_name = 'gpt2'\n",
"model = GPT2LMHeadModel.from_pretrained(model_name).to(device)\n",
"tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n",
"\n",
"# Set model to evaluation mode (disables dropout so each loss is deterministic)\n",
"model.eval()\n",
"\n",
"# Define our input\n",
"input_text = \"I have a dream\"\n",
"inputs = tokenizer(input_text, return_tensors=\"pt\").to(device)\n",
"\n",
"# Everything below is inference only: disable autograd so no graphs are built\n",
"with torch.no_grad():\n",
"    # Compute the original loss\n",
"    outputs = model(**inputs, labels=inputs[\"input_ids\"])\n",
"    original_loss = outputs.loss.item()\n",
"\n",
"    # Snapshot the trained weights. Each grid point is set relative to this copy,\n",
"    # which avoids the rounding drift of repeatedly adding and subtracting offsets.\n",
"    params = list(model.parameters())\n",
"    original_params = [p.detach().clone() for p in params]\n",
"\n",
"    # Define two random directions\n",
"    direction1 = [torch.randn_like(p) for p in params]\n",
"    direction2 = [torch.randn_like(p) for p in params]\n",
"\n",
"    # Rescale each direction tensor to the norm of its parameter tensor\n",
"    # (per-tensor normalization, coarser than the per-filter scheme of Li et al.)\n",
"    for p, d1, d2 in zip(params, direction1, direction2):\n",
"        norm_p = torch.linalg.norm(p.flatten())\n",
"        d1.div_(torch.linalg.norm(d1.flatten())).mul_(norm_p)\n",
"        d2.div_(torch.linalg.norm(d2.flatten())).mul_(norm_p)\n",
"\n",
"# Define the range to explore\n",
"x = np.linspace(-1, 1, 20)\n",
"y = np.linspace(-1, 1, 20)\n",
"X, Y = np.meshgrid(x, y)\n",
"\n",
"# Prepare to collect the losses\n",
"Z = np.zeros_like(X)\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def loss_at(alpha, beta):\n",
"    \"\"\"Return the loss with the weights set to theta + alpha * d1 + beta * d2.\n",
"\n",
"    theta is the snapshot in `original_params`; d1 and d2 are `direction1` and\n",
"    `direction2`. The model is left perturbed; the caller restores it.\n",
"    \"\"\"\n",
"    for p, p0, d1, d2 in zip(params, original_params, direction1, direction2):\n",
"        p.copy_(p0 + alpha * d1 + beta * d2)\n",
"    return model(**inputs, labels=inputs['input_ids']).loss.item()\n",
"\n",
"\n",
"# Compute loss for each direction\n",
"try:\n",
"    for i in tqdm(range(x.size), desc=\"x progress\"):\n",
"        for j in tqdm(range(y.size), desc=\"y progress\", leave=False):\n",
"            # np.meshgrid defaults to 'xy' indexing (X[j, i] == x[i], Y[j, i] == y[j]),\n",
"            # so the loss at (x[i], y[j]) belongs in Z[j, i] to line up with the plot.\n",
"            Z[j, i] = loss_at(float(x[i]), float(y[j]))\n",
"finally:\n",
"    # Restore the trained weights even if the loop is interrupted\n",
"    with torch.no_grad():\n",
"        for p, p0 in zip(params, original_params):\n",
"            p.copy_(p0)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4Wud37sZa1Y7",
"outputId": "c4e9c839-2938-4df8-c54f-96f5792cc0ed"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Create 3D plot: x/y are the step sizes along the two random directions\n",
"fig = go.Figure(data=[go.Surface(z=Z, x=X, y=Y, showscale=False)])\n",
"fig.update_layout(\n",
"    title=\"GPT-2's Loss Landscape\",\n",
"    width=1000,\n",
"    height=600,\n",
"    scene=dict(\n",
"        xaxis_title=\"Direction 1 (α)\",\n",
"        yaxis_title=\"Direction 2 (β)\",\n",
"        zaxis_title=\"Loss\",\n",
"    ),\n",
")\n",
"fig.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 617
},
"id": "5zcGwU4ji67L",
"outputId": "f149f7eb-abfb-4b06-aead-ace3772c5379"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}