{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a07e0f5e",
   "metadata": {},
   "source": [
    "<div>\n",
    "<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\"> \n",
    "</div>\n",
    "\n",
    "# Distributed Bloom for Text Classification using Prompt Tuning\n",
    "\n",
    "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers host the BLOOM blocks, which stay unchanged during adaptation, while gradient descent learns a few prefix tokens stored on the Petals client.\n",
    "\n",
    "We will adapt BLOOM for a classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). SST-2 is a binary classification task: the goal is to predict whether a sentence is positive or negative. The dataset is a subset of the Stanford Sentiment Treebank and is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n",
    "\n",
    "To use this notebook in Colab:\n",
    "\n",
    "1. Follow this link: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)\n",
    "2. Go to **Runtime** -> **Change runtime type** and select the GPU accelerator."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f8526f",
   "metadata": {},
   "source": [
    "First, we have to prepare all dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73bbc648",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q petals datasets wandb scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4ab6ca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import transformers\n",
    "import wandb\n",
    "from datasets import load_dataset, load_metric\n",
    "from tqdm import tqdm\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import BloomTokenizerFast, get_scheduler\n",
    "\n",
    "from petals import DistributedBloomForSequenceClassification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bf07b5d",
   "metadata": {},
   "source": [
    "Let's set some hyperparameters for training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f04ba4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose a model you'd like to prompt-tune. We recommend starting with\n",
    "# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n",
    "# Once your code is ready, you can switch to full-scale\n",
    "# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n",
    "MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n",
    "\n",
    "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",
    "# The latter fine-tunes separate prefixes for each transformer block,\n",
    "# so prompt-tuning will take more time but yield better results.\n",
    "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n",
    "TUNING_MODE = 'ptune'\n",
    "\n",
    "NUM_PREFIX_TOKENS = 16\n",
    "DEVICE = 'cuda'\n",
    "BATCH_SIZE = 16\n",
    "LR = 1e-2\n",
    "WEIGHT_DECAY = 0.0\n",
    "NUM_EPOCHS = 3\n",
    "SEED = 42\n",
    "MODEL_MAX_LENGTH = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d38316bd",
   "metadata": {},
   "source": [
    "Prepare the tokenizer and the distributed model, then connect the model to the servers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03c6e53e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
    "tokenizer.padding_side = 'right'\n",
    "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
    "model = DistributedBloomForSequenceClassification.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    pre_seq_len=NUM_PREFIX_TOKENS,\n",
    "    tuning_mode=TUNING_MODE\n",
    ").to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "042e3786",
   "metadata": {},
   "source": [
    "Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c44d516",
   "metadata": {},
   "outputs": [],
   "source": [
    "task = 'sst2'\n",
    "\n",
    "dataset = load_dataset(\"glue\", task)\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True)\n",
    "\n",
    "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
    "tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
    "tokenized_datasets.set_format(\"torch\")\n",
    "\n",
    "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
    "valid_dataset = tokenized_datasets[\"validation\"].shuffle(seed=SEED)\n",
    "\n",
    "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)\n",
    "valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a3f3590",
   "metadata": {},
   "source": [
    "To monitor training, we need a metric function. For the SST-2 task, the metric is accuracy. We will load it from the datasets library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e1812be",
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = load_metric('glue', task)\n",
    "\n",
    "def eval_metrics(model, dataloader, device='cpu'):\n",
    "    model.eval()\n",
    "    for batch in dataloader:\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "\n",
    "        logits = outputs.logits\n",
    "        predictions = torch.argmax(logits, dim=-1)\n",
    "        metric.add_batch(predictions=predictions, references=batch[\"labels\"])\n",
    "    model.train()\n",
    "    return metric.compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef4323fd",
   "metadata": {},
   "source": [
    "Before setting up optimizers, check the model parameters that will be trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cc0ba34",
   "metadata": {},
   "outputs": [],
   "source": [
    "for n, p in model.named_parameters():\n",
    "    if p.requires_grad:\n",
    "        print(n, p.requires_grad, p.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59cffce7",
   "metadata": {},
   "source": [
    "The optimizer will only update the **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef9bf344",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
    "\n",
    "# The linear schedule spans all epochs; with a single epoch's worth of steps,\n",
    "# the learning rate would already reach zero after the first epoch.\n",
    "lr_scheduler = get_scheduler(\n",
    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0,\n",
    "    num_training_steps=len(train_dataloader) * NUM_EPOCHS\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "423c56d5",
   "metadata": {},
   "source": [
    "Let's initialize wandb for logging and start the training loop!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e46807",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(\n",
    "    project=\"bloom-sst-2\",\n",
    "    config={\n",
    "        \"num_epochs\": NUM_EPOCHS,\n",
    "        \"batch_size\": BATCH_SIZE,\n",
    "        \"learning_rate\": LR,\n",
    "        \"weight_decay\": WEIGHT_DECAY,\n",
    "        \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
    "        \"model_name\": MODEL_NAME,\n",
    "        \"seed\": SEED,\n",
    "    }\n",
    ")\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
    "\n",
    "        model.train()\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        wandb.log({\"Train Loss\": loss})\n",
    "\n",
    "    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
    "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51770911",
   "metadata": {},
   "source": [
    "Our model has been trained!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bbf014f",
   "metadata": {},
   "source": [
    "## Beyond soft-prompt tuning\n",
    "\n",
    "Let's try tuning the model with an adapter inserted in the middle of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bea4391",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BloomBasedClassifier(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model,\n",
    "        intermediate_size: int = 32,\n",
    "        num_classes: int = 2,\n",
    "        adapter_layer_position: int = 6,\n",
    "        head_layer_position: int = 10\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.distributed_layers = model.transformer.h\n",
    "\n",
    "        self.hidden_size = model.config.hidden_size\n",
    "        self.dtype = model.config.torch_dtype\n",
    "        self.intermediate_size = intermediate_size\n",
    "        self.num_classes = num_classes\n",
    "        self.adapter_layer_position = adapter_layer_position\n",
    "        self.head_layer_position = head_layer_position\n",
    "\n",
    "        self.word_embeddings = model.transformer.word_embeddings\n",
    "        self.adapter = nn.Sequential(\n",
    "            nn.Linear(self.hidden_size, self.intermediate_size),\n",
    "            nn.Linear(self.intermediate_size, self.hidden_size),\n",
    "        ).to(self.dtype)\n",
    "        self.head = nn.Sequential(\n",
    "            nn.LayerNorm(self.hidden_size),\n",
    "            nn.Linear(self.hidden_size, self.num_classes),\n",
    "        ).to(self.dtype)\n",
    "\n",
    "    def forward(self, embeddings):\n",
    "        before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
    "        after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n",
    "\n",
    "        hidden_states = before_layers(embeddings)\n",
    "        hidden_states = self.adapter(hidden_states)\n",
    "        hidden_states = after_layers(hidden_states)\n",
    "        pooled_states = torch.mean(hidden_states, dim=1)\n",
    "        return self.head(pooled_states)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15299620",
   "metadata": {},
   "source": [
    "Delete the previous model and free the device memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa27b168",
   "metadata": {},
   "outputs": [],
   "source": [
    "del model, optimizer, lr_scheduler\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5406390f",
   "metadata": {},
   "source": [
    "Create a new model with adapters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a251db80",
   "metadata": {},
   "outputs": [],
   "source": [
    "INTERMEDIATE_SIZE = 32\n",
    "ADAPTER_LAYER_POSITION = 6\n",
    "HEAD_LAYER_POSITION = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3578df3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_model = BloomBasedClassifier(\n",
    "    DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),\n",
    "    intermediate_size=INTERMEDIATE_SIZE,\n",
    "    adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
    "    head_layer_position=HEAD_LAYER_POSITION,\n",
    ").to(DEVICE)\n",
    "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
    "cls_criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# The linear schedule spans all epochs; with a single epoch's worth of steps,\n",
    "# the learning rate would already reach zero after the first epoch.\n",
    "lr_scheduler = get_scheduler(\n",
    "    name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0,\n",
    "    num_training_steps=len(train_dataloader) * NUM_EPOCHS\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a40468b9",
   "metadata": {},
   "source": [
    "Now start training the new adapter-based model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed051a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(\n",
    "    project=\"bloom_based_cls-sst-2\",\n",
    "    config={\n",
    "        \"num_epochs\": NUM_EPOCHS,\n",
    "        \"batch_size\": BATCH_SIZE,\n",
    "        \"learning_rate\": LR,\n",
    "        \"weight_decay\": WEIGHT_DECAY,\n",
    "        \"model_name\": MODEL_NAME,\n",
    "        \"seed\": SEED,\n",
    "        \"intermediate_size\": INTERMEDIATE_SIZE,\n",
    "        \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n",
    "        \"head_layer_position\": HEAD_LAYER_POSITION,\n",
    "    }\n",
    ")\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
    "\n",
    "        cls_model.train()\n",
    "        with torch.no_grad():\n",
    "            embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n",
    "        outputs = cls_model(embeddings_output)\n",
    "        loss = cls_criterion(outputs, batch[\"labels\"])\n",
    "        loss.backward()\n",
    "\n",
    "        cls_optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        cls_optimizer.zero_grad()\n",
    "\n",
    "        wandb.log({\"Train Loss\": loss})\n",
    "\n",
    "    # eval_metrics() calls model(**batch) and reads .logits, which BloomBasedClassifier\n",
    "    # does not support, so the validation pass is run inline here.\n",
    "    cls_model.eval()\n",
    "    with torch.no_grad():\n",
    "        for batch in valid_dataloader:\n",
    "            batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
    "            logits = cls_model(cls_model.word_embeddings(batch[\"input_ids\"]))\n",
    "            predictions = torch.argmax(logits, dim=-1)\n",
    "            metric.add_batch(predictions=predictions, references=batch[\"labels\"])\n",
    "    accuracy = metric.compute()\n",
    "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}