"# Distributed Bloom for Text Generation using Prompt Tuning\n",
"\n",
"In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
"We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n",
"\n",
"To open this notebook in colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)"
"Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization."
"Let's initialize wandb for logging and start the training loop!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e46807",
"metadata": {},
"outputs": [],
"source": [
"wandb.init(\n",
" project=\"bloom-personachat\",\n",
" config={\n",
" \"num_samples\": NUM_SAMPLES,\n",
" \"batch_size\": BATCH_SIZE,\n",
" \"learning_rate\": LR,\n",
" \"weight_decay\": WEIGHT_DECAY,\n",
" \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
" \"model_name\": MODEL_NAME,\n",
" \"seed\": SEED,\n",
" }\n",
")\n",
"\n",
"for batch in tqdm(train_dataloader):\n",
" batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
"\n",
" model.train()\n",
" outputs = model(**batch)\n",
" loss = outputs.loss\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" lr_scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
" wandb.log({\"Train Loss\": loss})"
]
},
{
"cell_type": "markdown",
"id": "0f36cb80",
"metadata": {},
"source": [
"Try to talk with the trained model! Submit an empty input to stop the execution.\n",
"\n",
"\n",
"__Note__: In this example, we the whole dialogue as a prefix when generating each new replica. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new replica will be able to reuse inference caches from the previous replica."