{ "cells": [ { "cell_type": "markdown", "id": "a07e0f5e", "metadata": {}, "source": [ "# Distributed Bloom for Text Generation using Prompt Tuning\n", "\n", "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and gradient descent will learn a few prefix tokens stored on a Petals client.\n", "\n", "We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n", "\n", "To open this notebook in Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)" ] }, { "cell_type": "markdown", "id": "a3f8526f", "metadata": {}, "source": [ "First, we have to prepare all dependencies." 
] }, { "cell_type": "code", "execution_count": null, "id": "73bbc648", "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "import sys\n", "\n", "!pip install git+https://github.com/bigscience-workshop/petals\n", "!pip install datasets wandb\n", "\n", "IN_COLAB = 'google.colab' in sys.modules\n", "if IN_COLAB:  # Remove CUDA binaries on CPU-only colabs to not confuse bitsandbytes\n", "    try:\n", "        subprocess.check_output([\"nvidia-smi\", \"-L\"])\n", "    except (subprocess.CalledProcessError, FileNotFoundError):  # no GPU, or nvidia-smi is missing\n", "        subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)" ] },
{ "cell_type": "code", "execution_count": null, "id": "b4ab6ca7", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import torch\n", "import transformers\n", "import wandb\n", "from datasets import load_dataset\n", "from tqdm import tqdm\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader\n", "from transformers import BloomTokenizerFast, get_scheduler\n", "\n", "# Import a Petals model\n", "from petals.client.remote_model import DistributedBloomForCausalLM" ] },
{ "cell_type": "markdown", "id": "1bf07b5d", "metadata": {}, "source": [ "Let's set some hyperparameters for training:" ] },
{ "cell_type": "code", "execution_count": null, "id": "f04ba4d2", "metadata": {}, "outputs": [], "source": [ "MODEL_NAME = ...  # select the model you like\n", "INITIAL_PEERS = [...]  # add your peers' addresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n", "NUM_PREFIX_TOKENS = 16\n", "DEVICE = 'cpu'\n", "BATCH_SIZE = 4\n", "LR = 1e-2\n", "WEIGHT_DECAY = 0.0\n", "NUM_SAMPLES = 1000\n", "SEED = 42\n", "MODEL_MAX_LENGTH = 256\n", "TUNING_MODE = 'ptune'  # choose between ['ptune', 'deep_ptune']" ] },
{ "cell_type": "markdown", "id": "d38316bd", "metadata": {}, "source": [ "Prepare the tokenizer and the distributed model, and connect the model to the Petals servers." 
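, "\n", "\n", "Before loading the real model, it may help to see roughly what `pre_seq_len` and `tuning_mode='ptune'` do. The next cell is a minimal, self-contained sketch in plain PyTorch and does **not** use the Petals API. `ToyFrozenLM` and `PromptTuned` are hypothetical names, and a tiny GRU stands in for the frozen BLOOM blocks. The sketch prepends a few trainable prefix embeddings to the token embeddings and freezes every other weight. It then checks that one optimizer step updates only the prefix. As far as we understand, `'deep_ptune'` additionally learns prefix embeddings at the input of every transformer block. Treat this as an assumption and check the Petals source for the exact behavior." ] },
{ "cell_type": "code", "execution_count": null, "id": "7d1f3b90", "metadata": {}, "outputs": [], "source": [ "# Minimal prompt-tuning sketch: hypothetical toy model, NOT the Petals implementation\n", "import torch\n", "import torch.nn as nn\n", "\n", "VOCAB, HIDDEN, PREFIX = 100, 32, 4  # toy sizes chosen for illustration\n", "\n", "\n", "class ToyFrozenLM(nn.Module):\n", "    # Hypothetical stand-in for BLOOM: embeddings, a causal 'body', and an LM head.\n", "    # A GRU is used so that prefix positions can influence later tokens, as attention does.\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.embed = nn.Embedding(VOCAB, HIDDEN)\n", "        self.body = nn.GRU(HIDDEN, HIDDEN, batch_first=True)\n", "        self.lm_head = nn.Linear(HIDDEN, VOCAB)\n", "\n", "    def forward_embeds(self, embeds):\n", "        hidden, _ = self.body(embeds)  # [batch, seq, hidden]\n", "        return self.lm_head(hidden)  # [batch, seq, vocab]\n", "\n", "\n", "class PromptTuned(nn.Module):\n", "    # Prepends trainable prefix embeddings to the inputs of a frozen backbone.\n", "    def __init__(self, backbone, num_prefix_tokens):\n", "        super().__init__()\n", "        self.backbone = backbone\n", "        for p in self.backbone.parameters():\n", "            p.requires_grad = False  # the backbone stays unchanged during adaptation\n", "        self.prompts = nn.Parameter(0.02 * torch.randn(num_prefix_tokens, HIDDEN))\n", "\n", "    def forward(self, input_ids):\n", "        token_embeds = self.backbone.embed(input_ids)  # [batch, seq, hidden]\n", "        prefix = self.prompts.unsqueeze(0).expand(input_ids.shape[0], -1, -1)\n", "        logits = self.backbone.forward_embeds(torch.cat([prefix, token_embeds], dim=1))\n", "        return logits[:, self.prompts.shape[0]:]  # keep only the positions of real tokens\n", "\n", "\n", "torch.manual_seed(0)\n", "toy = PromptTuned(ToyFrozenLM(), num_prefix_tokens=PREFIX)\n", "print([n for n, p in toy.named_parameters() if p.requires_grad])  # only the prompts\n", "\n", "opt = torch.optim.AdamW([p for p in toy.parameters() if p.requires_grad], lr=1e-2)\n", "input_ids = torch.randint(0, VOCAB, (2, 8))\n", "logits = toy(input_ids)\n", "# Causal LM loss: the logits at position t predict token t + 1\n", "loss = nn.functional.cross_entropy(logits[:, :-1].reshape(-1, VOCAB), input_ids[:, 1:].reshape(-1))\n", "\n", "prompts_before = toy.prompts.detach().clone()\n", "head_before = toy.backbone.lm_head.weight.detach().clone()\n", "loss.backward()\n", "opt.step()\n", "print('prompts changed:', not torch.equal(prompts_before, toy.prompts.detach()))\n", "print('backbone changed:', not torch.equal(head_before, toy.backbone.lm_head.weight.detach()))" 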
] }, { "cell_type": "code", "execution_count": null, "id": "03c6e53e", "metadata": {}, "outputs": [], "source": [ "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n", "tokenizer.padding_side = 'right'\n", "tokenizer.model_max_length = MODEL_MAX_LENGTH\n", "model = DistributedBloomForCausalLM.from_pretrained(\n", "    MODEL_NAME,\n", "    initial_peers=INITIAL_PEERS,\n", "    pre_seq_len=NUM_PREFIX_TOKENS,\n", "    tuning_mode=TUNING_MODE,\n", ").to(DEVICE)" ] },
{ "cell_type": "markdown", "id": "042e3786", "metadata": {}, "source": [ "Let's prepare the Personachat dataset. We need two mapping functions: one that concatenates the dialogue history with each candidate answer, and another for tokenization." ] },
{ "cell_type": "code", "execution_count": null, "id": "9c44d516", "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"bavard/personachat_truecased\")\n", "\n", "\n", "def chunking(examples):\n", "    inputs = [\n", "        \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n", "        for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n", "        for candidate in candidates\n", "    ]\n", "    return {\"chunks\": inputs}\n", "\n", "\n", "def tokenize(examples):\n", "    outputs = {\n", "        \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n", "    }\n", "    outputs[\"labels\"] = outputs[\"input_ids\"]\n", "    return outputs\n", "\n", "\n", "tokenized_datasets = (\n", "    dataset\n", "    .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n", "    .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n", ")\n", "\n", "\n", "tokenized_datasets.set_format(\"torch\")\n", "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n", "train_dataloader = DataLoader(\n", "    train_dataset.select(list(range(NUM_SAMPLES))),\n", "    shuffle=True,\n", "    batch_size=BATCH_SIZE,\n", "    drop_last=True,\n", ")" ] },
{ "cell_type": "markdown", "id": "ef4323fd", "metadata": {}, "source": [ "Before setting up the optimizer, let's check which model parameters will be trained." ] },
{ "cell_type": "code", "execution_count": null, "id": "9cc0ba34", "metadata": {}, "outputs": [], "source": [ "for n, p in model.named_parameters():\n", "    if p.requires_grad:\n", "        print(n, p.requires_grad, p.device)" ] },
{ "cell_type": "markdown", "id": "59cffce7", "metadata": {}, "source": [ "The optimizer will only update the **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler." ] },
{ "cell_type": "code", "execution_count": null, "id": "ef9bf344", "metadata": {}, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", "\n", "lr_scheduler = get_scheduler(\n", "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", ")" ] },
{ "cell_type": "markdown", "id": "423c56d5", "metadata": {}, "source": [ "Let's initialize wandb for logging and start the training loop!" ] },
{ "cell_type": "code", "execution_count": null, "id": "d9e46807", "metadata": {}, "outputs": [], "source": [ "wandb.init(\n", "    project=\"bloom-personachat\",\n", "    config={\n", "        \"num_samples\": NUM_SAMPLES,\n", "        \"batch_size\": BATCH_SIZE,\n", "        \"learning_rate\": LR,\n", "        \"weight_decay\": WEIGHT_DECAY,\n", "        \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n", "        \"model_name\": MODEL_NAME,\n", "        \"seed\": SEED,\n", "    }\n", ")\n", "\n", "for batch in tqdm(train_dataloader):\n", "    batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", "\n", "    model.train()\n", "    outputs = model(**batch)\n", "    loss = outputs.loss\n", "    loss.backward()\n", "\n", "    optimizer.step()\n", "    lr_scheduler.step()\n", "    optimizer.zero_grad()\n", "\n", "    wandb.log({\"Train Loss\": loss.item()})" ] },
{ "cell_type": "markdown", "id": "0f36cb80", "metadata": {}, "source": [ "Try to talk with the trained model! Submit an empty input to stop the execution.\n", "\n", "__Note__: In this example, we use the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new reply will be able to reuse inference caches from the previous reply." ] },
{ "cell_type": "code", "execution_count": null, "id": "720181b7", "metadata": {}, "outputs": [], "source": [ "MAX_TOKENS = 16\n", "TOP_K = 100\n", "TEMPERATURE = 0.6\n", "dialog = \"\"\n", "\n", "while True:\n", "    user_phrase = input()\n", "    if len(user_phrase) == 0:\n", "        break\n", "    dialog += f\"{user_phrase}\\n-----\\n\"\n", "    inputs = tokenizer([dialog], return_tensors='pt')['input_ids'].to(DEVICE)\n", "    outputs = model.generate(\n", "        inputs,\n", "        temperature=TEMPERATURE,\n", "        do_sample=True,\n", "        top_k=TOP_K,\n", "        eos_token_id=tokenizer.eos_token_id,\n", "        max_new_tokens=MAX_TOKENS,\n", "    )\n", "    # Decode only the newly generated tokens (dropping special tokens such as EOS) and keep the first line\n", "    bloom_answer = tokenizer.decode(outputs[0, inputs.shape[1]:], skip_special_tokens=True)\n", "    bloom_answer = bloom_answer.split(\"\\n\")[0]\n", "    print(bloom_answer)\n", "    dialog += f\"{bloom_answer}\\n-----\\n\"" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.6.9 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 5 }