{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"source": [
"# Introduction to Weight Quantization\n",
"> Reducing the size of Large Language Models with 8-bit quantization\n",
"\n",
"❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne).\n",
"\n",
"Companion notebook to execute the code from the following article: https://mlabonne.github.io/blog/intro_weight_quantization/"
],
"metadata": {
"id": "yG1VY-TJoxix"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WMVwLxdUzlq2"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def absmax_quantize(X):\n",
" # Calculate scale\n",
" scale = 127 / torch.max(torch.abs(X))\n",
"\n",
" # Quantize\n",
" X_quant = (scale * X).round()\n",
"\n",
" # Dequantize\n",
" X_dequant = X_quant / scale\n",
"\n",
" return X_quant.to(torch.int8), X_dequant"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CE7XqWOR6oCa"
},
"outputs": [],
"source": [
"def zeropoint_quantize(X):\n",
" # Calculate value range (denominator)\n",
" x_range = torch.max(X) - torch.min(X)\n",
" x_range = 1 if x_range == 0 else x_range\n",
"\n",
" # Calculate scale\n",
" scale = 255 / x_range\n",
"\n",
" # Shift by zero-point\n",
" zeropoint = (-scale * torch.min(X) - 128).round()\n",
"\n",
" # Scale and round the inputs\n",
" X_quant = torch.clip((X * scale + zeropoint).round(), -128, 127)\n",
"\n",
" # Dequantize\n",
" X_dequant = (X_quant - zeropoint) / scale\n",
"\n",
" return X_quant.to(torch.int8), X_dequant"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lIYdn1woOS1n"
},
"outputs": [],
"source": [
"!pip install -q bitsandbytes>=0.39.0\n",
"!pip install -q git+https://github.com/huggingface/accelerate.git\n",
"!pip install -q git+https://github.com/huggingface/transformers.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 792,
"referenced_widgets": [
"e7e1636dbc8944c49b375286b8d89fe4",
"ca714b7b6dd94fedb669ce39c6796bc4",
"71616c99117145adaca155701522460b",
"d7364ae105524cfba0eed4095273ed57",
"eeb266ac1aba409184d9ac1fbd433e9c",
"fd1a537d4c1c49b6972f7bad58a24417",
"bb06d9a3f9f64821b7839ce777580c4b",
"c53b9dc02f144a91aeca032b8b67259b",
"e5a0a603b6ee48c8ab7e5f5bcb62c3ff",
"15a5a3eb178348ef83cb7d02198636b3",
"47b5db2b2caa48cca14ea162ed5e7bd3",
"7784bd5d3f1b4ac890a711c01bd653cb",
"7626f77200244f489150c2c4c6794e53",
"1482e979e3ae4731b9b673b5a7db44ab",
"d95190bba0364de99f04bb9adad173b5",
"2f4124fead964700a4158cf1373bf74e",
"6288899dd1c64a779d79e17095b8ea9d",
"c013211a09134f42934ac041f86b5cfe",
"85e096686a604059a0a61afacedac67b",
"0a23f8cda4484b6ca80c14b6691cc56d",
"dcb679577da0472c841ab985ef92618f",
"c7438dd71b964dea809e9f4a95046439",
"0e047a8a2a0e4337b17f3e8612044967",
"69e9a5bb2972442b9895e70210274c8d",
"768c433b555b43669de698cf9c738d79",
"43a348d835cd4eb4936d8d89c6999de7",
"715ab0815132494892001b7a15ebd9ed",
"9290f81e9e5d4a39b28a3836d6472886",
"fba8130749824daaa73c5890c773e900",
"de639b2a33e34fc593136c7bb07da47b",
"d804702825694057a213fbae380b94d6",
"b1d0c88f5b87449380a29956147b867a",
"38888cadff0a472e96c4925e2881a755",
"681a9a78878945b7b6afb2d87b769146",
"55add7bdffbe4ddea4fb7407aa61fbc8",
"a3a20248b4e843249ec3a10d7c8e84ad",
"e1ef7e3213a446a4815a84b8aab67576",
"6a9c13356d424bd6bad9565a14f28f16",
"c4f633e0dcb74c8992c482efc80ebe31",
"d8f5dafc06ca4bf0b28a7101ebe7a07e",
"d12f1b2228d444948862d94d769d0b0d",
"216cb64c2b4a41eb8114176608f6a0ca",
"0747780fac22461a8d8ef53dbb18ca39",
"4fdb1280c19a4df6ba35a95abf9862f0",
"cea83a47549a4ddb91eae020d1cd943c",
"f1f75a95e2094ebb9d5a891447008a7b",
"12a91a0d7ff94165a75640b285c08a52",
"dd3fa44234334118918ef5c632ee65d1",
"2463480560e14307864ba9743dc5d41d",
"06d0c57030474d008ff4bbafa1e35695",
"450cb6fd2bec48dfb73869cffd3d5c9f",
"a893cd0d7cd74a53b0a6504af7580f65",
"290836d667df420597ccbe2b934ca5ed",
"1d57938333654dd49156e6d488688d13",
"1ca5165f5f4443d5a20cf361b01922d1",
"90cf2d52029d4392aaea970508506261",
"c0c5a26685dc414f9a1295f077c81488",
"82d9759ea89a44feb72b57ca67bc7f2a",
"f8b92ab82eb64c57bafb1f37ef34af1b",
"8d78e5c72f894275b504948461193b66",
"77b9f13345b24caaa7afc3bd43e33eba",
"57b079495f844af684c24bb1cd711bf0",
"0227537a636e41cb93abc9e416015cce",
"b45e6e8aa21f47c1821f2b24a2f46944",
"8257dc1927514846b51511bc387a54db",
"7e7d6bc9b8544e078003d2fe6c74dcef"
]
},
"id": "NTDg7uUOGBmS",
"outputId": "cc48b090-31d1-41ae-ca5c-dbffcb67bcb6"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)lve/main/config.json: 0%| | 0.00/665 [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "e7e1636dbc8944c49b375286b8d89fe4"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"===================================BUG REPORT===================================\n",
"Welcome to bitsandbytes. For bug reports, please run\n",
"\n",
"python -m bitsandbytes\n",
"\n",
" and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
"================================================================================\n",
"bin /usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cuda118.so\n",
"CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...\n",
"CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so.11.0\n",
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
"CUDA SETUP: Detected CUDA version 118\n",
"CUDA SETUP: Loading binary /usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cuda118.so...\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: /usr/lib64-nvidia did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! Searching further paths...\n",
" warn(msg)\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/sys/fs/cgroup/memory.events /var/colab/cgroup/jupyter-children/memory.events')}\n",
" warn(msg)\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('http'), PosixPath('//172.28.0.1'), PosixPath('8013')}\n",
" warn(msg)\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('--logtostderr --listen_host=172.28.0.12 --target_host=172.28.0.12 --tunnel_background_save_url=https'), PosixPath('//colab.research.google.com/tun/m/cc48301118ce562b961b3c22d803539adc1e0c19/gpu-t4-s-20b5bv2xvtu9a --tunnel_background_save_delay=10s --tunnel_periodic_background_save_frequency=30m0s --enable_output_coalescing=true --output_coalescing_required=true')}\n",
" warn(msg)\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/env/python')}\n",
" warn(msg)\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('module'), PosixPath('//ipykernel.pylab.backend_inline')}\n",
" warn(msg)\n",
"/usr/local/lib/python3.10/dist-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: Found duplicate ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] files: {PosixPath('/usr/local/cuda/lib64/libcudart.so.11.0'), PosixPath('/usr/local/cuda/lib64/libcudart.so')}.. We'll flip a coin and try one of these, in order to fail forward.\n",
"Either way, this might cause trouble in the future:\n",
"If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.\n",
" warn(msg)\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading model.safetensors: 0%| | 0.00/548M [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "7784bd5d3f1b4ac890a711c01bd653cb"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)neration_config.json: 0%| | 0.00/124 [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "0e047a8a2a0e4337b17f3e8612044967"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)olve/main/vocab.json: 0%| | 0.00/1.04M [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "681a9a78878945b7b6afb2d87b769146"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)olve/main/merges.txt: 0%| | 0.00/456k [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "cea83a47549a4ddb91eae020d1cd943c"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)/main/tokenizer.json: 0%| | 0.00/1.36M [00:00, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "90cf2d52029d4392aaea970508506261"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model size: 510,342,192 bytes\n"
]
}
],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"import torch\n",
"torch.manual_seed(0)\n",
"\n",
"# Set device to CPU for now\n",
"device = 'cpu'\n",
"\n",
"# Load model and tokenizer\n",
"model_id = 'gpt2'\n",
"model = AutoModelForCausalLM.from_pretrained(model_id).to(device)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"\n",
"# Print model size\n",
"print(f\"Model size: {model.get_memory_footprint():,} bytes\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YPI1EaimHyHm",
"outputId": "977e9b34-9426-46a1-d6c2-da80884b7483"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Original weights:\n",
"tensor([[-0.4738, -0.2614, -0.0978, ..., 0.0513, -0.0584, 0.0250],\n",
" [ 0.0874, 0.1473, 0.2387, ..., -0.0525, -0.0113, -0.0156],\n",
" [ 0.0039, 0.0695, 0.3668, ..., 0.1143, 0.0363, -0.0318],\n",
" ...,\n",
" [-0.2592, -0.0164, 0.1991, ..., 0.0095, -0.0516, 0.0319],\n",
" [ 0.1517, 0.2170, 0.1043, ..., 0.0293, -0.0429, -0.0475],\n",
" [-0.4100, -0.1924, -0.2400, ..., -0.0046, 0.0070, 0.0198]])\n",
"\n",
"Absmax quantized weights:\n",
"tensor([[-21, -12, -4, ..., 2, -3, 1],\n",
" [ 4, 7, 11, ..., -2, -1, -1],\n",
" [ 0, 3, 16, ..., 5, 2, -1],\n",
" ...,\n",
" [-12, -1, 9, ..., 0, -2, 1],\n",
" [ 7, 10, 5, ..., 1, -2, -2],\n",
" [-18, -9, -11, ..., 0, 0, 1]], dtype=torch.int8)\n",
"\n",
"Zero-point quantized weights:\n",
"tensor([[-20, -11, -3, ..., 3, -2, 2],\n",
" [ 5, 8, 12, ..., -1, 0, 0],\n",
" [ 1, 4, 18, ..., 6, 3, 0],\n",
" ...,\n",
" [-11, 0, 10, ..., 1, -1, 2],\n",
" [ 8, 11, 6, ..., 2, -1, -1],\n",
" [-18, -8, -10, ..., 1, 1, 2]], dtype=torch.int8)\n"
]
}
],
"source": [
"# Extract weights of the first layer\n",
"weights = model.transformer.h[0].attn.c_attn.weight.data\n",
"print(\"Original weights:\")\n",
"print(weights)\n",
"\n",
"# Quantize layer using absmax quantization\n",
"weights_abs_quant, _ = absmax_quantize(weights)\n",
"print(\"\\nAbsmax quantized weights:\")\n",
"print(weights_abs_quant)\n",
"\n",
"# Quantize layer using absmax quantization\n",
"weights_zp_quant, _ = zeropoint_quantize(weights)\n",
"print(\"\\nZero-point quantized weights:\")\n",
"print(weights_zp_quant)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5i2N7HC9Mmn7"
},
"outputs": [],
"source": [
"import numpy as np\n",
"from copy import deepcopy\n",
"\n",
"# Store original weights\n",
"weights = [param.data.clone() for param in model.parameters()]\n",
"\n",
"# Create model to quantize\n",
"model_abs = deepcopy(model)\n",
"\n",
"# Quantize all model weights\n",
"weights_abs = []\n",
"for param in model_abs.parameters():\n",
" _, dequantized = absmax_quantize(param.data)\n",
" param.data = dequantized\n",
" weights_abs.append(dequantized)\n",
"\n",
"# Create model to quantize\n",
"model_zp = deepcopy(model)\n",
"\n",
"# Quantize all model weights\n",
"weights_zp = []\n",
"for param in model_zp.parameters():\n",
" _, dequantized = zeropoint_quantize(param.data)\n",
" param.data = dequantized\n",
" weights_zp.append(dequantized)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "FlM_jWwpHh34",
"outputId": "0705932d-ec5a-4cb1-cc92-08072c014ee7"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"