diff --git a/README.md b/README.md index a363222..d244600 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,9 @@ loss.backward() print("Gradients (norm):", model.transformer.word_embeddings.weight.grad.norm()) ``` -Of course, this is a simplified code snippet. For actual training, see our example on "deep" prompt-tuning here: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb). +Of course, this is a simplified code snippet. For actual training, see the example notebooks with "deep" prompt-tuning: +- Simple text semantic classification: [examples/prompt-tuning-sst2.ipynb](./examples/prompt-tuning-sst2.ipynb). +- A personified chatbot: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb). Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) that covers 8-bit quantization and best practices for running Petals. diff --git a/examples/prompt-tuning-personachat.ipynb b/examples/prompt-tuning-personachat.ipynb index 77312fd..3c74449 100644 --- a/examples/prompt-tuning-personachat.ipynb +++ b/examples/prompt-tuning-personachat.ipynb @@ -33,7 +33,6 @@ "metadata": {}, "outputs": [], "source": [ - "# This block is only need for colab users. It will change nothing if you are running this notebook locally.\n", "import subprocess\n", "import sys\n", "\n", @@ -41,14 +40,18 @@ "IN_COLAB = 'google.colab' in sys.modules\n", "\n", "if IN_COLAB:\n", - " subprocess.run(['git', 'clone', 'https://github.com/bigscience-workshop/petals'])\n", - " subprocess.run(['pip', 'install', '-r', 'petals/requirements.txt'])\n", - " subprocess.run(['pip', 'install', 'datasets', 'lib64'])\n", + " subprocess.run(\"git clone https://github.com/bigscience-workshop/petals\", shell=True)\n", + " subprocess.run(\"pip install -r petals/requirements.txt\", shell=True)\n", + " subprocess.run(\"pip install datasets wandb\", shell=True)\n", "\n", " try:\n", " subprocess.check_output([\"nvidia-smi\", \"-L\"])\n", " except subprocess.CalledProcessError as e:\n", - " subprocess.run(['rm', '-r', '/usr/local/cuda/lib64'])" + " subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)\n", + "\n", + " sys.path.insert(0, './petals/')\n", + "else:\n", + " sys.path.insert(0, \"..\")" ] }, { @@ -60,7 +63,6 @@ "source": [ "import os\n", "import sys\n", - "sys.path.insert(0, \"..\") # for colab change to sys.path.insert(0, './petals/')\n", " \n", "import torch\n", "import transformers\n", @@ -312,7 +314,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.10 64-bit", + "display_name": "Python 3.8.0 ('petals')", "language": "python", "name": "python3" }, @@ -326,11 +328,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.9" + "version": "3.8.0" }, "vscode": { "interpreter": { - "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + "hash": "a303c9f329a09f921588ea6ef03898c90b4a8e255a47e0bd6e36f6331488f609" } } }, diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb new file mode 100644 index 0000000..ed1de31 --- /dev/null +++ b/examples/prompt-tuning-sst2.ipynb @@ -0,0 +1,326 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a07e0f5e", + "metadata": {}, + "source": [ + "
\n", + " \n", + "
\n", + "\n", + "# Distributed Bloom for Text Classification using Prompt Tuning\n", + "\n", + "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", + "\n", + "We will adapt the BLOOM model for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n", + "\n", + "To open this notebook in colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "a3f8526f", + "metadata": {}, + "source": [ + "First, we have to prepare all dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73bbc648", + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "import sys\n", + "\n", + "\n", + "IN_COLAB = 'google.colab' in sys.modules\n", + "\n", + "if IN_COLAB:\n", + " subprocess.run(\"git clone https://github.com/bigscience-workshop/petals\", shell=True)\n", + " subprocess.run(\"pip install -r petals/requirements.txt\", shell=True)\n", + " subprocess.run(\"pip install datasets wandb\", shell=True)\n", + "\n", + " try:\n", + " subprocess.check_output([\"nvidia-smi\", \"-L\"])\n", + " except subprocess.CalledProcessError as e:\n", + " subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)\n", + "\n", + " sys.path.insert(0, './petals/')\n", + "else:\n", + " sys.path.insert(0, \"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4ab6ca7", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + " \n", + "import torch\n", + "import transformers\n", + "import wandb\n", + "from datasets import load_dataset, load_metric\n", + "from tqdm import tqdm\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader\n", + "from transformers import get_scheduler\n", + "\n", + "# Import a Petals model\n", + "from src.client.remote_model import DistributedBloomForSequenceClassification" + ] + }, + { + "cell_type": "markdown", + "id": "1bf07b5d", + "metadata": {}, + "source": [ + "Let's set some hyperparameters for training:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f04ba4d2", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_NAME = ... # select model you like\n", + "INITIAL_PEERS = [...] 
# add your peer addresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n",
+    "NUM_PREFIX_TOKENS = 16\n",
+    "DEVICE = 'cpu'\n",
+    "BATCH_SIZE = 4\n",
+    "LR = 1e-2\n",
+    "WEIGHT_DECAY = 0.0\n",
+    "NUM_SAMPLES = 1000\n",
+    "NUM_EPOCHS = 3\n",
+    "SEED = 42\n",
+    "MODEL_MAX_LENGTH = 64\n",
+    "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d38316bd",
+   "metadata": {},
+   "source": [
+    "Let's prepare the tokenizer and the distributed model, and connect to the Petals servers."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "03c6e53e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+    "tokenizer.padding_side = 'right'\n",
+    "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
+    "model = DistributedBloomForSequenceClassification.from_pretrained(\n",
+    "    MODEL_NAME,\n",
+    "    initial_peers=INITIAL_PEERS,\n",
+    "    pre_seq_len=NUM_PREFIX_TOKENS,\n",
+    "    tuning_mode=TUNING_MODE\n",
+    ").to(DEVICE)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "042e3786",
+   "metadata": {},
+   "source": [
+    "Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9c44d516",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "task = 'sst2'\n",
+    "\n",
+    "dataset = load_dataset(\"glue\", task)\n",
+    "\n",
+    "def preprocess_function(examples):\n",
+    "    return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True)\n",
+    "\n",
+    "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
+    "tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n",
+    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
+    "tokenized_datasets.set_format(\"torch\")\n",
+    "\n",
+    "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
+    "valid_dataset = tokenized_datasets[\"validation\"].shuffle(seed=SEED)\n",
+    "\n",
+    "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)\n",
+    "valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2a3f3590",
+   "metadata": {},
+   "source": [
+    "To monitor training, we need a metric function. For the SST-2 task, the metric is accuracy. We will load it from the `datasets` library."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1e1812be",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metric = load_metric('glue', task)\n",
+    "\n",
+    "def eval_metrics(model, dataloader, device='cpu'):\n",
+    "    model.eval()\n",
+    "    for batch in dataloader:\n",
+    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
+    "\n",
+    "        with torch.no_grad():\n",
+    "            outputs = model(**batch)\n",
+    "\n",
+    "        logits = outputs.logits\n",
+    "        predictions = torch.argmax(logits, dim=-1)\n",
+    "        metric.add_batch(predictions=predictions, references=batch[\"labels\"])\n",
+    "    model.train()\n",
+    "    return metric.compute()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ef4323fd",
+   "metadata": {},
+   "source": [
+    "Before setting up the optimizer, let's check which model parameters will actually be trained.\n",
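+    "\n",
+    "You should see only the parameters stored locally on the Petals client, i.e., the learnable prompt embeddings and, for sequence classification, a small classifier head. The BLOOM blocks themselves are hosted on the Petals servers and stay frozen, so they do not appear in this list."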
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9cc0ba34",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for n, p in model.named_parameters():\n",
+    "    if p.requires_grad:\n",
+    "        print(n, p.requires_grad, p.device)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "59cffce7",
+   "metadata": {},
+   "source": [
+    "The optimizer will only work on the **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ef9bf344",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+    "\n",
+    "# Decay the learning rate linearly over the whole run, not just the first epoch\n",
+    "lr_scheduler = get_scheduler(\n",
+    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "423c56d5",
+   "metadata": {},
+   "source": [
+    "Let's initialize wandb for logging and start the training loop!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9e46807",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "wandb.init(\n",
+    "    project=\"bloom-sst-2\",\n",
+    "    config={\n",
+    "        \"num_epochs\": NUM_EPOCHS,\n",
+    "        \"num_samples\": NUM_SAMPLES,\n",
+    "        \"batch_size\": BATCH_SIZE,\n",
+    "        \"learning_rate\": LR,\n",
+    "        \"weight_decay\": WEIGHT_DECAY,\n",
+    "        \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
+    "        \"model_name\": MODEL_NAME,\n",
+    "        \"seed\": SEED,\n",
+    "    }\n",
+    ")\n",
+    "\n",
+    "for epoch in range(NUM_EPOCHS):\n",
+    "    for batch in tqdm(train_dataloader):\n",
+    "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
+    "\n",
+    "        model.train()\n",
+    "        outputs = model(**batch)\n",
+    "        loss = outputs.loss\n",
+    "        loss.backward()\n",
+    "\n",
+    "        optimizer.step()\n",
+    "        lr_scheduler.step()\n",
+    "        optimizer.zero_grad()\n",
+    "\n",
+    "        wandb.log({\"Train Loss\": loss.item()})\n",
+    "\n",
+    "    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
+    "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "51770911",
+   "metadata": {},
+   "source": [
+    "Our model has been trained!"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.8.10 64-bit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.9"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}