diff --git a/.gitignore b/.gitignore index 7114a35..d8c10af 100644 --- a/.gitignore +++ b/.gitignore @@ -126,3 +126,5 @@ dmypy.json # Pyre type checker .pyre/ + +.idea/ diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index c5dac6a..876db8f 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -3,17 +3,19 @@ { "cell_type": "markdown", "id": "a07e0f5e", - "metadata": {}, + "metadata": { + "id": "a07e0f5e" + }, "source": [ "
\n", " \n", "
\n", "\n", - "# Distributed Bloom for Text Classification using Prompt Tuning\n", + "# Distributed LLaMA for Text Classification using Prompt Tuning\n", "\n", - "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", + "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [LLaMA](https://github.com/facebookresearch/llama) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the LLaMA blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", "\n", - "We will adapt BLOOM for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n", + "We will adapt LLaMA for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n", "\n", "To use this notebook in Colab:\n", "\n", @@ -24,7 +26,9 @@ { "cell_type": "markdown", "id": "a3f8526f", - "metadata": {}, + "metadata": { + "id": "a3f8526f" + }, "source": [ "First, we have to prepare all dependencies." ] @@ -33,17 +37,22 @@ "cell_type": "code", "execution_count": null, "id": "73bbc648", - "metadata": {}, + "metadata": { + "id": "73bbc648" + }, "outputs": [], "source": [ - "%pip install -q petals datasets wandb scikit-learn" + "%pip install -q datasets wandb scikit-learn\n", + "%pip install -q git+https://github.com/bigscience-workshop/petals@main" ] }, { "cell_type": "code", "execution_count": null, "id": "b4ab6ca7", - "metadata": {}, + "metadata": { + "id": "b4ab6ca7" + }, "outputs": [], "source": [ "import os\n", @@ -57,15 +66,19 @@ "from tqdm import tqdm\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader\n", - "from transformers import BloomTokenizerFast, get_scheduler\n", + "from transformers import LlamaTokenizer, get_scheduler, set_seed\n", "\n", - "from petals import DistributedBloomForSequenceClassification" + "from petals import DistributedLlamaForSequenceClassification\n", + "\n", + "set_seed(0)" ] }, { "cell_type": "markdown", "id": "1bf07b5d", - "metadata": {}, + "metadata": { + "id": "1bf07b5d" + }, "source": [ "Let's set some hyperparameters for training:" ] @@ -74,14 +87,15 @@ "cell_type": "code", "execution_count": null, "id": "f04ba4d2", - "metadata": {}, + "metadata": { + "id": "f04ba4d2" + }, "outputs": [], "source": [ "# Choose a model you'd like to prompt-tune. 
We recommend starting with\n", - "# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n", - "# Once your code is ready, you can switch to full-scale\n", - "# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n", - "MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n", + "# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n", + "# The code below uses LLaMA-65B.\n", + "MODEL_NAME = \"enoch/llama-65b-hf\"\n", "\n", "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n", "# The latter fine-tunes separate prefixes for each transformer block,\n", @@ -89,9 +103,9 @@ "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n", "TUNING_MODE = 'ptune'\n", "\n", - "NUM_PREFIX_TOKENS = 16\n", + "NUM_PREFIX_TOKENS = 8\n", "DEVICE = 'cuda'\n", - "BATCH_SIZE = 16\n", + "BATCH_SIZE = 32\n", "LR = 1e-2\n", "WEIGHT_DECAY = 0.0\n", "NUM_EPOCHS = 3\n", @@ -102,32 +116,40 @@ { "cell_type": "markdown", "id": "d38316bd", - "metadata": {}, + "metadata": { + "id": "d38316bd" + }, "source": [ - "Prepare tokenizer and distributed model, connect it to servers." + "Here, we prepare tokenizer and distributed model and connect it to the public swarm." ] }, { "cell_type": "code", "execution_count": null, "id": "03c6e53e", - "metadata": {}, + "metadata": { + "id": "03c6e53e" + }, "outputs": [], "source": [ - "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n", + "tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)\n", "tokenizer.padding_side = 'right'\n", "tokenizer.model_max_length = MODEL_MAX_LENGTH\n", - "model = DistributedBloomForSequenceClassification.from_pretrained(\n", + "tokenizer.pad_token = tokenizer.unk_token\n", + "model = DistributedLlamaForSequenceClassification.from_pretrained(\n", " MODEL_NAME,\n", " pre_seq_len=NUM_PREFIX_TOKENS,\n", " tuning_mode=TUNING_MODE\n", - ").to(DEVICE)" + ").float().to(DEVICE)\n", + "model.config.pad_token_id = tokenizer.pad_token_id" ] }, { "cell_type": "markdown", "id": "042e3786", - "metadata": {}, + "metadata": { + "id": "042e3786" + }, "source": [ "Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset." ] @@ -136,7 +158,9 @@ "cell_type": "code", "execution_count": null, "id": "9c44d516", - "metadata": {}, + "metadata": { + "id": "9c44d516" + }, "outputs": [], "source": [ "task = 'sst2'\n", @@ -144,7 +168,7 @@ "dataset = load_dataset(\"glue\", task)\n", "\n", "def preprocess_function(examples):\n", - " return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True)\n", + " return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True, return_token_type_ids=False)\n", "\n", "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n", "tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n", @@ -161,16 +185,20 @@ { "cell_type": "markdown", "id": "2a3f3590", - "metadata": {}, + "metadata": { + "id": "2a3f3590" + }, "source": [ - "To check training, we need a metric function. For SST-2 task is accuracy. We will load it from the datasets library." + "To monitor training, we need the metric function. For SST-2, the target metric is accuracy. We will load it from the datasets library." 
] }, { "cell_type": "code", "execution_count": null, "id": "1e1812be", - "metadata": {}, + "metadata": { + "id": "1e1812be" + }, "outputs": [], "source": [ "metric = load_metric('glue', task)\n", @@ -179,7 +207,7 @@ " model.eval()\n", " for batch in dataloader:\n", " batch = {k: v.to(device) for k, v in batch.items()}\n", - " \n", + "\n", " with torch.no_grad():\n", " outputs = model(**batch)\n", "\n", @@ -193,16 +221,20 @@ { "cell_type": "markdown", "id": "ef4323fd", - "metadata": {}, + "metadata": { + "id": "ef4323fd" + }, "source": [ - "Before setting up optimizers, check the model parameters that will be trained." + "Before setting up optimizers, let's check the model parameters that will be trained." ] }, { "cell_type": "code", "execution_count": null, "id": "9cc0ba34", - "metadata": {}, + "metadata": { + "id": "9cc0ba34" + }, "outputs": [], "source": [ "for n, p in model.named_parameters():\n", @@ -213,29 +245,35 @@ { "cell_type": "markdown", "id": "59cffce7", - "metadata": {}, + "metadata": { + "id": "59cffce7" + }, "source": [ - "The optimizer will only work on **prompts**, they are only trainable parameters. Let's initialize optimizer and learning rate scheduler." + "The optimizer will only work on **prompts and classifier head**: they are only trainable parameters. Let's initialize the optimizer and the learning rate scheduler." ] }, { "cell_type": "code", "execution_count": null, "id": "ef9bf344", - "metadata": {}, + "metadata": { + "id": "ef9bf344" + }, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", "\n", "lr_scheduler = get_scheduler(\n", - " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", + " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS\n", ")" ] }, { "cell_type": "markdown", "id": "423c56d5", - "metadata": {}, + "metadata": { + "id": "423c56d5" + }, "source": [ "Let's initialize wandb for logging and start the training loop!" ] @@ -244,7 +282,9 @@ "cell_type": "code", "execution_count": null, "id": "d9e46807", - "metadata": {}, + "metadata": { + "id": "d9e46807" + }, "outputs": [], "source": [ "wandb.init(\n", @@ -260,20 +300,24 @@ " }\n", ")\n", "\n", + "scaler = torch.cuda.amp.GradScaler()\n", + "\n", "for epoch in range(NUM_EPOCHS):\n", + " model.train()\n", " for batch in tqdm(train_dataloader):\n", " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", "\n", - " model.train()\n", - " outputs = model(**batch)\n", + " with torch.autocast(device_type=DEVICE, dtype=torch.float16):\n", + " outputs = model(**batch)\n", " loss = outputs.loss\n", - " loss.backward()\n", + " scaler.scale(loss).backward()\n", "\n", - " optimizer.step()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", "\n", - " wandb.log({\"Train Loss\": loss})\n", + " wandb.log({\"Train Loss\": loss.detach()})\n", "\n", " accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n", " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)" @@ -282,184 +326,26 @@ { "cell_type": "markdown", "id": "51770911", - "metadata": {}, - "source": [ - "Our model have been trained!" - ] - }, - { - "cell_type": "markdown", - "id": "1bbf014f", - "metadata": {}, - "source": [ - "## Beyond soft-prompt tuning\n", - "\n", - "Let's try to tune model using adapters in the middle of the model." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3bea4391", - "metadata": {}, - "outputs": [], - "source": [ - "class BloomBasedClassifier(nn.Module):\n", - " def __init__(\n", - " self,\n", - " model,\n", - " intermediate_size: int = 32,\n", - " num_classes: int = 2,\n", - " adapter_layer_position: int = 6,\n", - " head_layer_position: int = 10\n", - " ):\n", - " super().__init__()\n", - " self.distributed_layers = model.transformer.h\n", - "\n", - " self.hidden_size = model.config.hidden_size\n", - " self.dtype = model.config.torch_dtype\n", - " self.intermediate_size = intermediate_size\n", - " self.num_classes = num_classes\n", - " self.adapter_layer_position = adapter_layer_position\n", - " self.head_layer_position = head_layer_position\n", - " \n", - " self.word_embeddings = model.transformer.word_embeddings\n", - " self.adapter = nn.Sequential(\n", - " nn.Linear(self.hidden_size, self.intermediate_size),\n", - " nn.Linear(self.intermediate_size, self.hidden_size),\n", - " ).to(self.dtype)\n", - " self.head = nn.Sequential(\n", - " nn.LayerNorm(self.hidden_size),\n", - " nn.Linear(self.hidden_size, self.num_classes),\n", - " ).to(self.dtype)\n", - " \n", - " def forward(self, embeddings):\n", - " before_layers = self.distributed_layers[0:self.adapter_layer_position]\n", - " after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n", - " \n", - " hidden_states = before_layers(embeddings)\n", - " hidden_states = self.adapter(hidden_states)\n", - " hidden_states = after_layers(hidden_states)\n", - " pooled_states = torch.mean(hidden_states, dim=1)\n", - " return self.head(pooled_states)" - ] - }, - { - "cell_type": "markdown", - "id": "15299620", - "metadata": {}, - "source": [ - "Clear model and device memory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa27b168", - "metadata": {}, - "outputs": [], - "source": [ - "del model, optimizer, lr_scheduler\n", - "torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "markdown", - "id": "5406390f", - "metadata": {}, - "source": [ - "Create new model with adapters." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a251db80", - "metadata": {}, - "outputs": [], - "source": [ - "INTERMEDIATE_SIZE = 32\n", - "ADAPTER_LAYER_POSITION = 6\n", - "HEAD_LAYER_POSITION = 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3578df3a", - "metadata": {}, - "outputs": [], - "source": [ - "cls_model = BloomBasedClassifier(\n", - " DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),\n", - " intermediate_size=INTERMEDIATE_SIZE,\n", - " adapter_layer_position=ADAPTER_LAYER_POSITION,\n", - " head_layer_position=HEAD_LAYER_POSITION,\n", - ").to(DEVICE)\n", - "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", - "cls_criterion = nn.CrossEntropyLoss()\n", - "\n", - "lr_scheduler = get_scheduler(\n", - " name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "a40468b9", - "metadata": {}, + "metadata": { + "id": "51770911" + }, "source": [ - "And start training our new adapted model." + "Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](http://health.petals.ml/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!" 
] }, { "cell_type": "code", "execution_count": null, - "id": "ed051a5d", - "metadata": {}, "outputs": [], - "source": [ - "wandb.init(\n", - " project=\"bloom_based_cls-sst-2\",\n", - " config={\n", - " \"num_epochs\": NUM_EPOCHS,\n", - " \"batch_size\": BATCH_SIZE,\n", - " \"learning_rate\": LR,\n", - " \"weight_decay\": WEIGHT_DECAY,\n", - " \"model_name\": MODEL_NAME,\n", - " \"seed\": SEED,\n", - " \"intermediate_size\": INTERMEDIATE_SIZE,\n", - " \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n", - " \"head_layer_position\": HEAD_LAYER_POSITION,\n", - " }\n", - ")\n", - "\n", - "for epoch in range(NUM_EPOCHS):\n", - " for batch in tqdm(train_dataloader):\n", - " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", - "\n", - " cls_model.train()\n", - " with torch.no_grad():\n", - " embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n", - " outputs = cls_model(embeddings_output)\n", - " loss = cls_criterion(outputs, batch[\"labels\"])\n", - " loss.backward()\n", - "\n", - " cls_optimizer.step()\n", - " lr_scheduler.step()\n", - " cls_optimizer.zero_grad()\n", - "\n", - " wandb.log({\"Train Loss\": loss})\n", - "\n", - " accuracy = eval_metrics(cls_model, valid_dataloader, device=DEVICE)\n", - " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)" - ] + "source": [], + "metadata": { + "collapsed": false + } } ], "metadata": { "kernelspec": { "display_name": "Python 3", - "language": "python", "name": "python3" }, "language_info": { @@ -478,7 +364,12 @@ "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } - } + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 5