{ "cells": [ { "cell_type": "markdown", "id": "a07e0f5e", "metadata": { "id": "a07e0f5e" }, "source": [ "
\n", " \n", "
\n", "\n", "# Distributed LLaMA for Text Classification using Prompt Tuning\n", "\n", "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [LLaMA](https://github.com/facebookresearch/llama) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the LLaMA blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", "\n", "We will adapt LLaMA for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n", "\n", "To use this notebook in Colab:\n", "\n", "1. Follow this link: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)\n", "2. Go to **Runtime** -> **Change runtime type** and select the GPU accelerator." ] }, { "cell_type": "markdown", "id": "a3f8526f", "metadata": { "id": "a3f8526f" }, "source": [ "First, we have to prepare all dependencies." ] }, { "cell_type": "code", "execution_count": null, "id": "73bbc648", "metadata": { "id": "73bbc648" }, "outputs": [], "source": [ "%pip install -q datasets wandb scikit-learn\n", "%pip install -q git+https://github.com/bigscience-workshop/petals@main" ] }, { "cell_type": "code", "execution_count": null, "id": "b4ab6ca7", "metadata": { "id": "b4ab6ca7" }, "outputs": [], "source": [ "import os\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import transformers\n", "import wandb\n", "from datasets import load_dataset, load_metric\n", "from tqdm import tqdm\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader\n", "from transformers import LlamaTokenizer, get_scheduler, set_seed\n", "\n", "from petals import DistributedLlamaForSequenceClassification\n", "\n", "set_seed(0)" ] }, { "cell_type": "markdown", "id": "1bf07b5d", "metadata": { "id": "1bf07b5d" }, "source": [ "Let's set some hyperparameters for training:" ] }, { "cell_type": "code", "execution_count": null, "id": "f04ba4d2", "metadata": { "id": "f04ba4d2" }, "outputs": [], "source": [ "MODEL_NAME = \"enoch/llama-65b-hf\"\n", "\n", "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n", "# The latter fine-tunes separate prefixes for each transformer block,\n", "# so prompt-tuning will take more time but yield better results.\n", "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n", "TUNING_MODE = 'ptune'\n", "\n", "NUM_PREFIX_TOKENS = 8\n", "DEVICE = 'cuda'\n", "BATCH_SIZE = 32\n", "LR = 1e-2\n", "WEIGHT_DECAY = 0.0\n", "NUM_EPOCHS = 3\n", "SEED = 42\n", "MODEL_MAX_LENGTH = 64" ] }, { "cell_type": "markdown", "id": "d38316bd", "metadata": { "id": "d38316bd" }, "source": [ "Here, we prepare tokenizer and distributed model and connect it to the public swarm." 
] }, { "cell_type": "code", "execution_count": null, "id": "03c6e53e", "metadata": { "id": "03c6e53e" }, "outputs": [], "source": [ "tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)\n", "tokenizer.padding_side = 'right'\n", "tokenizer.model_max_length = MODEL_MAX_LENGTH\n", "tokenizer.pad_token = tokenizer.unk_token\n", "model = DistributedLlamaForSequenceClassification.from_pretrained(\n", " MODEL_NAME,\n", " pre_seq_len=NUM_PREFIX_TOKENS,\n", " tuning_mode=TUNING_MODE\n", ").float().to(DEVICE)\n", "model.config.pad_token_id = tokenizer.pad_token_id" ] }, { "cell_type": "markdown", "id": "042e3786", "metadata": { "id": "042e3786" }, "source": [ "Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset." ] }, { "cell_type": "code", "execution_count": null, "id": "9c44d516", "metadata": { "id": "9c44d516" }, "outputs": [], "source": [ "task = 'sst2'\n", "\n", "dataset = load_dataset(\"glue\", task)\n", "\n", "def preprocess_function(examples):\n", " return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True, return_token_type_ids=False)\n", "\n", "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n", "tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n", "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n", "tokenized_datasets.set_format(\"torch\")\n", "\n", "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n", "valid_dataset = tokenized_datasets[\"validation\"].shuffle(seed=SEED)\n", "\n", "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)\n", "valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)" ] }, { "cell_type": "markdown", "id": "2a3f3590", "metadata": { "id": "2a3f3590" }, "source": [ "To monitor training, we need the metric function. For SST-2, the target metric is accuracy. We will load it from the datasets library." ] }, { "cell_type": "code", "execution_count": null, "id": "1e1812be", "metadata": { "id": "1e1812be" }, "outputs": [], "source": [ "metric = load_metric('glue', task)\n", "\n", "def eval_metrics(model, dataloader, device='cpu'):\n", " model.eval()\n", " for batch in dataloader:\n", " batch = {k: v.to(device) for k, v in batch.items()}\n", "\n", " with torch.no_grad():\n", " outputs = model(**batch)\n", "\n", " logits = outputs.logits\n", " predictions = torch.argmax(logits, dim=-1)\n", " metric.add_batch(predictions=predictions, references=batch[\"labels\"])\n", " model.train()\n", " return metric.compute()" ] }, { "cell_type": "markdown", "id": "ef4323fd", "metadata": { "id": "ef4323fd" }, "source": [ "Before setting up optimizers, let's check the model parameters that will be trained." ] }, { "cell_type": "code", "execution_count": null, "id": "9cc0ba34", "metadata": { "id": "9cc0ba34" }, "outputs": [], "source": [ "for n, p in model.named_parameters():\n", " if p.requires_grad:\n", " print(n, p.requires_grad, p.device)" ] }, { "cell_type": "markdown", "id": "59cffce7", "metadata": { "id": "59cffce7" }, "source": [ "The optimizer will only work on **prompts and classifier head**: they are only trainable parameters. Let's initialize the optimizer and the learning rate scheduler." 
] }, { "cell_type": "code", "execution_count": null, "id": "ef9bf344", "metadata": { "id": "ef9bf344" }, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", "\n", "lr_scheduler = get_scheduler(\n", " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS\n", ")" ] }, { "cell_type": "markdown", "id": "423c56d5", "metadata": { "id": "423c56d5" }, "source": [ "Let's initialize wandb for logging and start the training loop!" ] }, { "cell_type": "code", "execution_count": null, "id": "d9e46807", "metadata": { "id": "d9e46807" }, "outputs": [], "source": [ "wandb.init(\n", " project=\"bloom-sst-2\",\n", " config={\n", " \"num_epochs\": NUM_EPOCHS,\n", " \"batch_size\": BATCH_SIZE,\n", " \"learning_rate\": LR,\n", " \"weight_decay\": WEIGHT_DECAY,\n", " \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n", " \"model_name\": MODEL_NAME,\n", " \"seed\": SEED,\n", " }\n", ")\n", "\n", "scaler = torch.cuda.amp.GradScaler()\n", "\n", "for epoch in range(NUM_EPOCHS):\n", " model.train()\n", " for batch in tqdm(train_dataloader):\n", " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", "\n", " with torch.autocast(device_type=DEVICE, dtype=torch.float16):\n", " outputs = model(**batch)\n", " loss = outputs.loss\n", " scaler.scale(loss).backward()\n", "\n", " scaler.step(optimizer)\n", " scaler.update()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", "\n", " wandb.log({\"Train Loss\": loss.detach()})\n", "\n", " accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n", " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)" ] }, { "cell_type": "markdown", "id": "51770911", "metadata": { "id": "51770911" }, "source": [ "Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](https://health.petals.dev/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!" ] }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [], "metadata": { "collapsed": false } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } }, "colab": { "provenance": [], "gpuType": "T4" }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 5 }