{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Decoding Strategies in Large Language Models\n",
"\n",
"> A Guide to Text Generation From Beam Search to Nucleus Sampling\n",
"\n",
"❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne).\n",
"\n",
"Companion notebook to execute the code from the following article: https://mlabonne.github.io/blog/decoding/"
],
"metadata": {
"id": "qaLKx40NbTD6"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cBr2gYVBVDka"
},
"outputs": [],
"source": [
"%%capture\n",
"\n",
"# Install transformers and graphviz\n",
"!sudo apt-get install graphviz graphviz-dev\n",
"!pip install transformers pygraphviz\n",
"\n",
"# Make sure we're using UTF-8 as encoding\n",
"import locale\n",
"locale.getpreferredencoding = lambda: \"UTF-8\"\n",
"\n",
"# Set seed\n",
"import torch\n",
"torch.manual_seed(42)\n",
"torch.cuda.manual_seed(42)\n",
"torch.cuda.manual_seed_all(42)\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.backends.cudnn.benchmark = False"
]
},
{
"cell_type": "markdown",
"source": [
"## 📚 Background"
],
"metadata": {
"id": "7R4IKg4lbNcN"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 229,
"referenced_widgets": [
"9713d401d53341e38141920ea03dbd5a",
"7a4520eef6b046049ac015596f45e765",
"f630193167224e92aef625d875d74f03",
"8367416a0fae4c46887b947b4627c9ed",
"3d8a0c7938274453989b93a819dfbaae",
"24dfcd472872492ab9dd78704d01f391",
"928ac9d1df6348da8b937dd4baed73bc",
"df124877f6fd4d71be3dc3081577d963",
"abdd1746c797421688fff22da1c24471",
"69fe044f596f44d48ad092a6a605b8a7",
"d30b4657252f44cc9f18dfce198fa490",
"64efc321afdd4b3782765a02888a772b",
"17831ee3579f4013a5381400fd70b38d",
"9a35e7be906641fdbf6170cd73a15fa6",
"5a78d9f4fa474d68a1dea3999bb961da",
"ddb595c5abd54c81a6589063a0e9c472",
"9ae75f0fed3846298fa003bc700cb30c",
"974039b5318043569c2e5f04ab321efe",
"d44421e59afd4cafa21cde90bd16bd1a",
"b3dc8bf4a0814f2ba6ece0544d6ffba1",
"6b27b86bcd4f440da024f952cde0fb86",
"feda016e22184435b0beb5ee6a6c70a3",
"5443950734244bb58cc0d16f3a8e6431",
"7518763d64a643f88c6951d4a44da53c",
"c705fa5b9cbf4234a1b6fb5f534b6c58",
"08ef10859c1642c6a4062efdc8adb94d",
"762ea8e309a94a3584228c7102d6468c",
"33703304a4a74f608ca5387effe26df0",
"698d554641de4a518ac8387ec70fbde0",
"d1403cebed8f4963b1749398d17b1535",
"9eead7b6eca74c5ab5a48a1917f21df8",
"ef702a89b5b347ef93cffad017b14681",
"a74ffb45ba3742c587b42a18344bf688",
"a53f23ee62874333a78984b0043f04c0",
"56ba1e0e45884b0dbb21c090f2969532",
"78db0251ac9146809faf1401165d1b4e",
"4e867868b7bc4b4c8fd7a91df446931a",
"3a54459246c6490faf11ea96d1ee0d8e",
"2fba37301c8a42e2b96c98028c78eee7",
"464ad324599048089f36a5a78d27afd4",
"1c19b681477d45e098ab58d3a85dcf31",
"87ceaae75dda4979b97c63958929b3c6",
"504924887692443e9791a861e79e3683",
"a53a20a3248d460bb728b8931629328b",
"85b0326d03b94a2a81a4af035d011cad",
"285e6fce050c4bb591df0f93be22e4c3",
"b00326a362f24a098ddee078671e2efe",
"1e3601441cd0480aa6443b934d38bce1",
"ce38696b9d844577ba0a930e50b65f26",
"02a3d9c1a17141e589bc3cc938a60b05",
"cae537a671ec4796b414e67987e40108",
"6c7928af450e4cf89420dcafe21cef60",
"9db081d87e174740b247f80a48065170",
"0805e191fe564eaca6ff8545512c366b",
"1833da9abbfe4090b09c8c9ddcce41a2"
]
},
"id": "1LS5sCUPwzaD",
"outputId": "147addb9-bbf7-4eeb-bfc0-3b46b06b076d"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Generated text: I have a dream of being a doctor.\n"
]
}
],
"source": [
"from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
"import torch\n",
"\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)\n",
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
"model.eval()\n",
"\n",
"text = \"I have a dream\"\n",
"input_ids = tokenizer.encode(text, return_tensors='pt').to(device)\n",
"\n",
"outputs = model.generate(input_ids, max_length=len(input_ids.squeeze())+5)\n",
"generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"print(f\"Generated text: {generated_text}\")"
]
},
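{
"cell_type": "markdown",
"source": [
"Every decoding strategy starts from the same object: the probability distribution the model assigns to the next token. The sketch below (an addition to this notebook, reusing the `model`, `tokenizer`, and `input_ids` defined above) prints the five most likely continuations of \"I have a dream\" to make that distribution concrete."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: inspect the raw next-token distribution that every\n",
"# decoding strategy searches or samples over.\n",
"with torch.no_grad():\n",
"    next_token_logits = model(input_ids).logits[0, -1, :]\n",
"\n",
"# Softmax turns the logits into probabilities over the whole vocabulary\n",
"next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
"top_probs, top_ids = torch.topk(next_token_probs, 5)\n",
"\n",
"for prob, token_id in zip(top_probs, top_ids):\n",
"    print(f\"{tokenizer.decode(token_id)!r}: {prob.item():.4f}\")"
]
},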
{
"cell_type": "markdown",
"source": [
"## 🏃♂️ Greedy Search"
],
"metadata": {
"id": "xw7mElwjbPYW"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zm_AiUhK8yrP",
"outputId": "74bef07c-cc93-4643-d67c-5ba67dfed8eb"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Generated text: I have a dream of being a doctor.\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import networkx as nx\n",
"import numpy as np\n",
"import time\n",
"\n",
"def get_log_prob(logits, token_id):\n",
" # Compute the softmax of the logits\n",
" probabilities = torch.nn.functional.softmax(logits, dim=-1)\n",
" log_probabilities = torch.log(probabilities)\n",
" \n",
" # Get the log probability of the token\n",
" token_log_probability = log_probabilities[token_id].item()\n",
" return token_log_probability\n",
"\n",
"def greedy_search(input_ids, node, length=5):\n",
" if length == 0:\n",
" return input_ids\n",
"\n",
" outputs = model(input_ids)\n",
" predictions = outputs.logits\n",
"\n",
" # Get the predicted next sub-word (here we use top-k search)\n",
" logits = predictions[0, -1, :]\n",
" token_id = torch.argmax(logits).unsqueeze(0)\n",
"\n",
" # Compute the score of the predicted token\n",
" token_score = get_log_prob(logits, token_id)\n",
"\n",
" # Add the predicted token to the list of input ids\n",
" new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)\n",
"\n",
" # Add node and edge to graph\n",
" next_token = tokenizer.decode(token_id, skip_special_tokens=True)\n",
" current_node = list(graph.successors(node))[0]\n",
" graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100\n",
" graph.nodes[current_node]['token'] = next_token + f\"_{length}\"\n",
"\n",
" # Recursive call\n",
" input_ids = greedy_search(new_input_ids, current_node, length-1)\n",
" \n",
" return input_ids\n",
"\n",
"# Parameters\n",
"length = 5\n",
"beams = 1\n",
"\n",
"# Create a balanced tree with height 'length'\n",
"graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())\n",
"\n",
"# Add 'tokenscore', 'cumscore', and 'token' attributes to each node\n",
"for node in graph.nodes:\n",
" graph.nodes[node]['tokenscore'] = 100\n",
" graph.nodes[node]['token'] = text\n",
"\n",
"# Start generating text\n",
"output_ids = greedy_search(input_ids, 0, length=length)\n",
"output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)\n",
"print(f\"Generated text: {output}\")"
]
},
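{
"cell_type": "markdown",
"source": [
"As a quick sanity check (an addition to this notebook, not part of the original article), Hugging Face's `generate()` performs exactly this greedy decoding when sampling is disabled, so its output should match the text produced by our `greedy_search()` function."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Greedy decoding with the built-in generate(): with do_sample=False it\n",
"# takes the argmax at each step, just like greedy_search() above.\n",
"hf_output = model.generate(input_ids, max_length=len(input_ids.squeeze()) + length, do_sample=False)\n",
"print(f\"Generated text: {tokenizer.decode(hf_output[0], skip_special_tokens=True)}\")"
]
},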
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "NeMNyl5maqTc",
"outputId": "0d4a5d12-0c05-416f-e5d7-99721b7f7214"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"