diff --git a/Decoding_Strategies_in_Large_Language Models.ipynb b/Decoding_Strategies_in_Large_Language Models.ipynb
new file mode 100644
index 0000000..f628ad7
--- /dev/null
+++ b/Decoding_Strategies_in_Large_Language Models.ipynb
@@ -0,0 +1,2885 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Decoding Strategies in Large Language Models\n",
+ "\n",
+ "> A Guide to Text Generation From Beam Search to Nucleus Sampling\n",
+ "\n",
+ "❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne).\n",
+ "\n",
+ "Companion notebook to execute the code from the following article: https://mlabonne.github.io/blog/decoding/"
+ ],
+ "metadata": {
+ "id": "qaLKx40NbTD6"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cBr2gYVBVDka"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "\n",
+ "# Install transformers and pygraphviz (plus the graphviz system libraries)\n",
+ "!sudo apt-get install graphviz graphviz-dev\n",
+ "!pip install transformers pygraphviz\n",
+ "\n",
+ "# Make sure we're using UTF-8 as encoding\n",
+ "import locale\n",
+ "locale.getpreferredencoding = lambda: \"UTF-8\"\n",
+ "\n",
+ "# Set seed\n",
+ "import torch\n",
+ "torch.manual_seed(42)\n",
+ "torch.cuda.manual_seed(42)\n",
+ "torch.cuda.manual_seed_all(42)\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "torch.backends.cudnn.benchmark = False"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 📚 Background"
+ ],
+ "metadata": {
+ "id": "7R4IKg4lbNcN"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 229,
+ "referenced_widgets": [
+ "9713d401d53341e38141920ea03dbd5a",
+ "7a4520eef6b046049ac015596f45e765",
+ "f630193167224e92aef625d875d74f03",
+ "8367416a0fae4c46887b947b4627c9ed",
+ "3d8a0c7938274453989b93a819dfbaae",
+ "24dfcd472872492ab9dd78704d01f391",
+ "928ac9d1df6348da8b937dd4baed73bc",
+ "df124877f6fd4d71be3dc3081577d963",
+ "abdd1746c797421688fff22da1c24471",
+ "69fe044f596f44d48ad092a6a605b8a7",
+ "d30b4657252f44cc9f18dfce198fa490",
+ "64efc321afdd4b3782765a02888a772b",
+ "17831ee3579f4013a5381400fd70b38d",
+ "9a35e7be906641fdbf6170cd73a15fa6",
+ "5a78d9f4fa474d68a1dea3999bb961da",
+ "ddb595c5abd54c81a6589063a0e9c472",
+ "9ae75f0fed3846298fa003bc700cb30c",
+ "974039b5318043569c2e5f04ab321efe",
+ "d44421e59afd4cafa21cde90bd16bd1a",
+ "b3dc8bf4a0814f2ba6ece0544d6ffba1",
+ "6b27b86bcd4f440da024f952cde0fb86",
+ "feda016e22184435b0beb5ee6a6c70a3",
+ "5443950734244bb58cc0d16f3a8e6431",
+ "7518763d64a643f88c6951d4a44da53c",
+ "c705fa5b9cbf4234a1b6fb5f534b6c58",
+ "08ef10859c1642c6a4062efdc8adb94d",
+ "762ea8e309a94a3584228c7102d6468c",
+ "33703304a4a74f608ca5387effe26df0",
+ "698d554641de4a518ac8387ec70fbde0",
+ "d1403cebed8f4963b1749398d17b1535",
+ "9eead7b6eca74c5ab5a48a1917f21df8",
+ "ef702a89b5b347ef93cffad017b14681",
+ "a74ffb45ba3742c587b42a18344bf688",
+ "a53f23ee62874333a78984b0043f04c0",
+ "56ba1e0e45884b0dbb21c090f2969532",
+ "78db0251ac9146809faf1401165d1b4e",
+ "4e867868b7bc4b4c8fd7a91df446931a",
+ "3a54459246c6490faf11ea96d1ee0d8e",
+ "2fba37301c8a42e2b96c98028c78eee7",
+ "464ad324599048089f36a5a78d27afd4",
+ "1c19b681477d45e098ab58d3a85dcf31",
+ "87ceaae75dda4979b97c63958929b3c6",
+ "504924887692443e9791a861e79e3683",
+ "a53a20a3248d460bb728b8931629328b",
+ "85b0326d03b94a2a81a4af035d011cad",
+ "285e6fce050c4bb591df0f93be22e4c3",
+ "b00326a362f24a098ddee078671e2efe",
+ "1e3601441cd0480aa6443b934d38bce1",
+ "ce38696b9d844577ba0a930e50b65f26",
+ "02a3d9c1a17141e589bc3cc938a60b05",
+ "cae537a671ec4796b414e67987e40108",
+ "6c7928af450e4cf89420dcafe21cef60",
+ "9db081d87e174740b247f80a48065170",
+ "0805e191fe564eaca6ff8545512c366b",
+ "1833da9abbfe4090b09c8c9ddcce41a2"
+ ]
+ },
+ "id": "1LS5sCUPwzaD",
+ "outputId": "147addb9-bbf7-4eeb-bfc0-3b46b06b076d"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/665 [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "9713d401d53341e38141920ea03dbd5a"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading pytorch_model.bin: 0%| | 0.00/548M [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "64efc321afdd4b3782765a02888a772b"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)neration_config.json: 0%| | 0.00/124 [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "5443950734244bb58cc0d16f3a8e6431"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)olve/main/vocab.json: 0%| | 0.00/1.04M [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "a53f23ee62874333a78984b0043f04c0"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)olve/main/merges.txt: 0%| | 0.00/456k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "85b0326d03b94a2a81a4af035d011cad"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
+ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Generated text: I have a dream of being a doctor.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
+ "import torch\n",
+ "\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)\n",
+ "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
+ "model.eval()\n",
+ "\n",
+ "text = \"I have a dream\"\n",
+ "input_ids = tokenizer.encode(text, return_tensors='pt').to(device)\n",
+ "\n",
+ "outputs = model.generate(input_ids, max_length=len(input_ids.squeeze())+5)\n",
+ "generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
+ "print(f\"Generated text: {generated_text}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 🏃‍♂️ Greedy Search"
+ ],
+ "metadata": {
+ "id": "xw7mElwjbPYW"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "zm_AiUhK8yrP",
+ "outputId": "74bef07c-cc93-4643-d67c-5ba67dfed8eb"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Generated text: I have a dream of being a doctor.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import networkx as nx\n",
+ "import numpy as np\n",
+ "import time\n",
+ "\n",
+ "def get_log_prob(logits, token_id):\n",
+ " # Compute the softmax of the logits\n",
+ " probabilities = torch.nn.functional.softmax(logits, dim=-1)\n",
+ " log_probabilities = torch.log(probabilities)\n",
+ " \n",
+ " # Get the log probability of the token\n",
+ " token_log_probability = log_probabilities[token_id].item()\n",
+ " return token_log_probability\n",
+ "\n",
+ "def greedy_search(input_ids, node, length=5):\n",
+ " if length == 0:\n",
+ " return input_ids\n",
+ "\n",
+ " outputs = model(input_ids)\n",
+ " predictions = outputs.logits\n",
+ "\n",
+ "    # Get the predicted next sub-word (greedy: take the argmax over the vocabulary)\n",
+ " logits = predictions[0, -1, :]\n",
+ " token_id = torch.argmax(logits).unsqueeze(0)\n",
+ "\n",
+ " # Compute the score of the predicted token\n",
+ " token_score = get_log_prob(logits, token_id)\n",
+ "\n",
+ " # Add the predicted token to the list of input ids\n",
+ " new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)\n",
+ "\n",
+ "    # Store the token and its score on the corresponding graph node\n",
+ " next_token = tokenizer.decode(token_id, skip_special_tokens=True)\n",
+ " current_node = list(graph.successors(node))[0]\n",
+ " graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100\n",
+ " graph.nodes[current_node]['token'] = next_token + f\"_{length}\"\n",
+ "\n",
+ " # Recursive call\n",
+ " input_ids = greedy_search(new_input_ids, current_node, length-1)\n",
+ " \n",
+ " return input_ids\n",
+ "\n",
+ "# Parameters\n",
+ "length = 5\n",
+ "beams = 1\n",
+ "\n",
+ "# Create a balanced tree with height 'length'\n",
+ "graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())\n",
+ "\n",
+ "# Add 'tokenscore' and 'token' attributes to each node\n",
+ "for node in graph.nodes:\n",
+ " graph.nodes[node]['tokenscore'] = 100\n",
+ " graph.nodes[node]['token'] = text\n",
+ "\n",
+ "# Start generating text\n",
+ "output_ids = greedy_search(input_ids, 0, length=length)\n",
+ "output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)\n",
+ "print(f\"Generated text: {output}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "NeMNyl5maqTc",
+ "outputId": "0d4a5d12-0c05-416f-e5d7-99721b7f7214"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "