diff --git a/examples/prompt-tuning-personachat.ipynb b/examples/prompt-tuning-personachat.ipynb new file mode 100644 index 0000000..7231d5b --- /dev/null +++ b/examples/prompt-tuning-personachat.ipynb @@ -0,0 +1,307 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a07e0f5e", + "metadata": {}, + "source": [ + "
\n", + " \n", + "
\n", + "\n", + "# Distributed Bloom for Text Generation using Prompt Tuning\n", + "\n", + "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", + "\n", + "We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer." + ] + }, + { + "cell_type": "markdown", + "id": "a3f8526f", + "metadata": {}, + "source": [ + "First, we have to prepare all dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4ab6ca7", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "sys.path.insert(0, \"..\")\n", + " \n", + "import torch\n", + "import transformers\n", + "import wandb\n", + "from datasets import load_dataset\n", + "from tqdm import tqdm\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader\n", + "from transformers import get_scheduler\n", + "\n", + "# Import a Petals model\n", + "from src.client.remote_model import DistributedBloomForCausalLM" + ] + }, + { + "cell_type": "markdown", + "id": "1bf07b5d", + "metadata": {}, + "source": [ + "Let's set some hyperparameters for training:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f04ba4d2", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_NAME = ... # select model you like\n", + "INITIAL_PEERS = [...] # add your peers adresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n", + "NUM_PREFIX_TOKENS = 16\n", + "DEVICE = 'cpu'\n", + "BATCH_SIZE = 4\n", + "LR = 1e-2\n", + "WEIGHT_DECAY = 0.0\n", + "NUM_SAMPLES = 1000\n", + "SEED = 42\n", + "MODEL_MAX_LENGTH = 256\n", + "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] " + ] + }, + { + "cell_type": "markdown", + "id": "d38316bd", + "metadata": {}, + "source": [ + "Prepare tokenizer and distributed model, connect it to servers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03c6e53e", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n", + "tokenizer.padding_side = 'right'\n", + "tokenizer.model_max_length = MODEL_MAX_LENGTH\n", + "model = DistributedBloomForCausalLM.from_pretrained(\n", + " MODEL_NAME, \n", + " initial_peers=INITIAL_PEERS, \n", + " pre_seq_len=NUM_PREFIX_TOKENS, \n", + " tuning_mode=TUNING_MODE\n", + ").to(DEVICE)" + ] + }, + { + "cell_type": "markdown", + "id": "042e3786", + "metadata": {}, + "source": [ + "Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c44d516", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = load_dataset(\"bavard/personachat_truecased\")\n", + "\n", + "\n", + "def chunking(examples):\n", + " inputs = [\n", + " \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n", + " for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n", + " for candidate in candidates\n", + " ]\n", + " return {\"chunks\": inputs}\n", + "\n", + "\n", + "def tokenize(examples):\n", + " outputs = {\n", + " \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n", + " }\n", + " outputs[\"labels\"] = outputs[\"input_ids\"]\n", + " return outputs\n", + "\n", + "\n", + "tokenized_datasets = (\n", + " dataset\n", + " .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n", + " .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n", + ")\n", + "\n", + "\n", + "tokenized_datasets.set_format(\"torch\")\n", + "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n", + "train_dataloader = DataLoader(\n", + " train_dataset.select(list(range(NUM_SAMPLES))),\n", + " shuffle=True,\n", + " batch_size=BATCH_SIZE,\n", + " drop_last=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ef4323fd", + "metadata": {}, + "source": [ + "Before setting up the optimizer, let's check which model parameters will be trained." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cc0ba34", + "metadata": {}, + "outputs": [], + "source": [ + "for n, p in model.named_parameters():\n", + " if p.requires_grad:\n", + " print(n, p.requires_grad, p.device)" + ] + }, + { + "cell_type": "markdown", + "id": "59cffce7", + "metadata": {}, + "source": [ + "The optimizer will only update the **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef9bf344", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", + "\n", + "lr_scheduler = get_scheduler(\n", + " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "423c56d5", + "metadata": {}, + "source": [ + "Let's initialize wandb for logging and start the training loop!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9e46807", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.init(\n", + " project=\"bloom-personachat\",\n", + " config={\n", + " \"num_samples\": NUM_SAMPLES,\n", + " \"batch_size\": BATCH_SIZE,\n", + " \"learning_rate\": LR,\n", + " \"weight_decay\": WEIGHT_DECAY,\n", + " \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n", + " \"model_name\": MODEL_NAME,\n", + " \"seed\": SEED,\n", + " }\n", + ")\n", + "\n", + "for batch in tqdm(train_dataloader):\n", + " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", + "\n", + " model.train()\n", + " outputs = model(**batch)\n", + " loss = outputs.loss\n", + " loss.backward()\n", + "\n", + " optimizer.step()\n", + " lr_scheduler.step()\n", + " optimizer.zero_grad()\n", + "\n", + " wandb.log({\"Train Loss\": loss})" + ] + }, + { + "cell_type": "markdown", + "id": "0f36cb80", + "metadata": {}, + "source": [ + "Try to talk with the trained model!
Submit an empty input to stop the execution.\n", + "\n", + "\n", + "__Note__: In this example, we use the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new reply will be able to reuse inference caches from the previous replies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "720181b7", + "metadata": {}, + "outputs": [], + "source": [ + "MAX_TOKENS = 16\n", + "TOP_K = 100\n", + "TEMPERATURE = 0.6\n", + "dialog = \"\"\n", + "\n", + "while True:\n", + " user_phrase = input()\n", + " if len(user_phrase) == 0:\n", + " break\n", + " dialog += f\"{user_phrase}\\n-----\\n\"\n", + " inputs = tokenizer([dialog], return_tensors='pt')['input_ids']\n", + " outputs = model.generate(\n", + " inputs,\n", + " temperature=TEMPERATURE,\n", + " do_sample=True,\n", + " top_k=TOP_K,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " max_new_tokens=MAX_TOKENS,\n", + " )\n", + " bloom_answer = tokenizer.batch_decode(outputs)[0]\n", + " bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n", + " print(bloom_answer)\n", + " dialog += f\"{bloom_answer}\\n-----\\n\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}