mirror of
https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00
335 lines
10 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"id": "a07e0f5e",
"metadata": {},
"source": [
"<div>\n",
"<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\"> \n",
"</div>\n",
"\n",
"# Distributed Bloom for Text Generation using Prompt Tuning\n",
"\n",
"In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a 6B-parameter test version of the [BLOOM](https://huggingface.co/bigscience/bloom) model to a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will host the BLOOM blocks (they are kept unchanged during adaptation), and gradient descent will learn a few prefix tokens stored on the Petals client.\n",
"\n",
"We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n",
"\n",
"To open this notebook in Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)"
]
},
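The intro describes how the work is split. The frozen BLOOM blocks live on Petals servers, and the client owns a small table of trainable prefix vectors. Below is a minimal pure-Python sketch of the client-side step. For illustration only, it assumes that embeddings are plain lists of floats and that the prefix is simply prepended to the token embeddings in the simplest (`ptune`-style) setting. The helper name `prepend_prefix` is hypothetical and is not part of the Petals API.

```python
from typing import List

Vector = List[float]


def prepend_prefix(prefix: List[Vector], token_embeddings: List[Vector]) -> List[Vector]:
    """Return the sequence the frozen blocks would see: [prefix; tokens].

    Hypothetical helper: the prefix vectors are the only trainable state,
    while the token embeddings come from the frozen embedding table.
    """
    if prefix and token_embeddings and len(prefix[0]) != len(token_embeddings[0]):
        raise ValueError("prefix and token embeddings must share the hidden size")
    return [list(v) for v in prefix] + [list(v) for v in token_embeddings]


# Toy example: hidden size 3, 2 prefix vectors, 4 input tokens
hidden_size = 3
prefix = [[0.1] * hidden_size for _ in range(2)]
tokens = [[float(i)] * hidden_size for i in range(4)]
sequence = prepend_prefix(prefix, tokens)
print(len(sequence))  # 6 positions: 2 learned prefix vectors + 4 token embeddings
```

Under this assumption, only `prefix` would receive gradient updates. The servers' weights never change, and the client stores just a few vectors per prefix token.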
{
"cell_type": "markdown",
"id": "a3f8526f",
"metadata": {},
"source": [
"First, we have to prepare all dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73bbc648",
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import sys\n",
"\n",
"# Install Petals straight from GitHub (`-r` would expect a requirements file, not a package URL)\n",
"!pip install git+https://github.com/bigscience-workshop/petals\n",
"!pip install datasets wandb\n",
"\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"if IN_COLAB:  # Remove CUDA binaries on CPU-only Colab runtimes so they don't confuse bitsandbytes\n",
"    try:\n",
"        subprocess.check_output([\"nvidia-smi\", \"-L\"])\n",
"    except (subprocess.CalledProcessError, FileNotFoundError):\n",
"        subprocess.run(\"rm -r /usr/local/cuda/lib64\", shell=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4ab6ca7",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"import torch\n",
"import transformers\n",
"import wandb\n",
"from datasets import load_dataset\n",
"from tqdm import tqdm\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader\n",
"from transformers import get_scheduler\n",
"\n",
"# Import a Petals model\n",
"from petals.client.remote_model import DistributedBloomForCausalLM"
]
},
{
"cell_type": "markdown",
"id": "1bf07b5d",
"metadata": {},
"source": [
"Let's set some hyperparameters for training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f04ba4d2",
"metadata": {},
"outputs": [],
"source": [
"MODEL_NAME = ...  # select the model you like\n",
"INITIAL_PEERS = [...]  # add your peers' addresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n",
"NUM_PREFIX_TOKENS = 16\n",
"DEVICE = 'cpu'\n",
"BATCH_SIZE = 4\n",
"LR = 1e-2\n",
"WEIGHT_DECAY = 0.0\n",
"NUM_SAMPLES = 1000\n",
"SEED = 42\n",
"MODEL_MAX_LENGTH = 256\n",
"TUNING_MODE = 'ptune'  # choose between ['ptune', 'deep_ptune']"
]
},
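`NUM_PREFIX_TOKENS` and `TUNING_MODE` together decide how many numbers the client actually trains. Below is a back-of-the-envelope count. It assumes that `ptune` learns a single prefix of `NUM_PREFIX_TOKENS` vectors at the input, and that `deep_ptune` learns such a prefix for every transformer block. The second assumption is our reading of deep prompt tuning; check the Petals source for the exact parameter layout. The hidden size of 4096 and the 30 blocks below are placeholder values, not read from the selected model.

```python
def prompt_parameter_count(num_prefix_tokens: int, hidden_size: int,
                           num_blocks: int, tuning_mode: str) -> int:
    """Rough count of trainable prompt parameters under the assumed layouts."""
    per_prefix = num_prefix_tokens * hidden_size
    if tuning_mode == "ptune":
        return per_prefix  # assumed: one prefix at the input embeddings
    if tuning_mode == "deep_ptune":
        return per_prefix * num_blocks  # assumed: one prefix per transformer block
    raise ValueError(f"unknown tuning mode: {tuning_mode!r}")


# Placeholder architecture values (assumptions, not taken from MODEL_NAME)
HIDDEN_SIZE = 4096
NUM_BLOCKS = 30

for mode in ("ptune", "deep_ptune"):
    print(mode, prompt_parameter_count(16, HIDDEN_SIZE, NUM_BLOCKS, mode))
# ptune 65536
# deep_ptune 1966080
```

Either way, the count is tiny compared with the billions of frozen weights on the servers. This is consistent with the notebook keeping `DEVICE = 'cpu'` on the client.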
{
"cell_type": "markdown",
"id": "d38316bd",
"metadata": {},
"source": [
"Prepare the tokenizer and the distributed model, and connect the model to the servers."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03c6e53e",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
"tokenizer.padding_side = 'right'\n",
"tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
"model = DistributedBloomForCausalLM.from_pretrained(\n",
"    MODEL_NAME,\n",
"    initial_peers=INITIAL_PEERS,\n",
"    pre_seq_len=NUM_PREFIX_TOKENS,\n",
"    tuning_mode=TUNING_MODE\n",
").to(DEVICE)"
]
},
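`INITIAL_PEERS` expects libp2p multiaddresses shaped like the example in the hyperparameter cell. It can help to sanity-check them before connecting. Below is a simplified sketch that only handles the `/ip4/.../tcp/.../p2p/...` shape shown in this notebook. `parse_multiaddr` is a hypothetical helper, not part of Petals or hivemind, and the peer ID is a placeholder.

```python
from typing import List, Tuple


def parse_multiaddr(addr: str) -> List[Tuple[str, str]]:
    """Split a simple multiaddr such as /ip4/<host>/tcp/<port>/p2p/<peer id>
    into (protocol, value) pairs.

    Simplified: assumes every protocol takes exactly one value, which holds
    for ip4, tcp and p2p but not for every multiaddr protocol.
    """
    parts = addr.strip("/").split("/")
    if not addr.startswith("/") or len(parts) % 2 != 0:
        raise ValueError(f"malformed multiaddr: {addr!r}")
    return list(zip(parts[0::2], parts[1::2]))


example = "/ip4/192.168.1.2/tcp/31000/p2p/QmExamplePeerId"  # placeholder peer ID
print(parse_multiaddr(example))
# [('ip4', '192.168.1.2'), ('tcp', '31000'), ('p2p', 'QmExamplePeerId')]
```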
{
"cell_type": "markdown",
"id": "042e3786",
"metadata": {},
"source": [
"Let's prepare the Personachat dataset. We need two mapping functions: one to concatenate the dialogue history with each candidate answer, and another for tokenization."
]
},
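To see what `chunking` produces, here is a standalone copy of it applied to a hand-made batch. The toy `history` and `candidates` values are invented; only the field names come from the dataset used here. Every candidate answer becomes its own training string, so one dialogue with N candidates yields N examples.

```python
SEPARATOR = "\n-----\n"


def chunking(examples):
    # Same logic as the notebook: join the history with the separator,
    # then append each candidate answer as a separate example
    inputs = [
        SEPARATOR.join(history) + SEPARATOR + candidate
        for history, candidates in zip(examples["history"], examples["candidates"])
        for candidate in candidates
    ]
    return {"chunks": inputs}


toy_batch = {
    "history": [["Hi!", "Hello, how are you?"]],
    "candidates": [["I like pizza.", "I am fine, thanks."]],
}
for chunk in chunking(toy_batch)["chunks"]:
    print(repr(chunk))
# 'Hi!\n-----\nHello, how are you?\n-----\nI like pizza.'
# 'Hi!\n-----\nHello, how are you?\n-----\nI am fine, thanks.'
```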
{
"cell_type": "code",
"execution_count": null,
"id": "9c44d516",
"metadata": {},
"outputs": [],
"source": [
"dataset = load_dataset(\"bavard/personachat_truecased\")\n",
"\n",
"\n",
"def chunking(examples):\n",
"    inputs = [\n",
"        \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n",
"        for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n",
"        for candidate in candidates\n",
"    ]\n",
"    return {\"chunks\": inputs}\n",
"\n",
"\n",
"def tokenize(examples):\n",
"    outputs = {\n",
"        \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n",
"    }\n",
"    outputs[\"labels\"] = outputs[\"input_ids\"]\n",
"    return outputs\n",
"\n",
"\n",
"tokenized_datasets = (\n",
"    dataset\n",
"    .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n",
"    .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n",
")\n",
"\n",
"\n",
"tokenized_datasets.set_format(\"torch\")\n",
"train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
"train_dataloader = DataLoader(\n",
"    train_dataset.select(list(range(NUM_SAMPLES))),\n",
"    shuffle=True,\n",
"    batch_size=BATCH_SIZE,\n",
"    drop_last=True,\n",
")"
]
},
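`tokenize` copies `input_ids` into `labels` unchanged. With `padding='max_length'`, the padding positions therefore also count as loss targets. A common refinement, not part of this notebook, is to replace those positions with `-100`. That value is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, and Hugging Face causal-LM heads use it for masked labels. Check whether the Petals model's loss honors it in your installed version. The sketch below works on plain lists and assumes the tokenizer also returns an `attention_mask`, with 0 at padding. The token IDs are illustrative.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_padding(input_ids: List[List[int]], attention_mask: List[List[int]]) -> List[List[int]]:
    """Copy input_ids into labels, replacing padded positions with IGNORE_INDEX."""
    return [
        [tok if keep else IGNORE_INDEX for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]


# Toy batch: right-padded to length 5 (illustrative token IDs)
input_ids = [[17, 42, 8, 3, 3], [5, 6, 7, 9, 11]]
attention_mask = [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
print(mask_padding(input_ids, attention_mask))
# [[17, 42, 8, -100, -100], [5, 6, 7, 9, 11]]
```

In `tokenize`, this would mean keeping both `input_ids` and `attention_mask` from a single tokenizer call and building `labels` with a function like this one.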
{
"cell_type": "markdown",
"id": "ef4323fd",
"metadata": {},
"source": [
"Before setting up the optimizer, let's check which model parameters will be trained."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cc0ba34",
"metadata": {},
"outputs": [],
"source": [
"for n, p in model.named_parameters():\n",
"    if p.requires_grad:\n",
"        print(n, p.requires_grad, p.device)"
]
},
{
"cell_type": "markdown",
"id": "59cffce7",
"metadata": {},
"source": [
"The optimizer will only update the **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef9bf344",
"metadata": {},
"outputs": [],
"source": [
"# Only the prompt parameters require gradients, so pass just those to the optimizer\n",
"trainable_params = [p for p in model.parameters() if p.requires_grad]\n",
"optimizer = AdamW(trainable_params, lr=LR, weight_decay=WEIGHT_DECAY)\n",
"\n",
"lr_scheduler = get_scheduler(\n",
"    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
")"
]
},
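With `drop_last=True`, the dataloader yields `NUM_SAMPLES // BATCH_SIZE` batches. Because `num_training_steps=len(train_dataloader)`, the linear schedule decays the learning rate from `LR` to zero over exactly one pass. The helpers below reproduce the linear-with-warmup formula in pure Python so you can check the learning rate at any step. They are our own re-implementation for illustration, not an import from `transformers`.

```python
def num_batches(num_samples: int, batch_size: int, drop_last: bool = True) -> int:
    """Number of batches a DataLoader yields over num_samples items."""
    full, rest = divmod(num_samples, batch_size)
    return full if drop_last or rest == 0 else full + 1


def linear_lr(step: int, base_lr: float, num_warmup_steps: int, num_training_steps: int) -> float:
    """Learning rate after `step` scheduler steps: linear warmup, then linear decay to 0."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return base_lr * max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


total = num_batches(1000, 4)  # NUM_SAMPLES=1000, BATCH_SIZE=4 as in this notebook
print(total)                             # 250
print(linear_lr(0, 1e-2, 0, total))      # 0.01
print(linear_lr(125, 1e-2, 0, total))    # 0.005
print(linear_lr(total, 1e-2, 0, total))  # 0.0
```

With `num_warmup_steps=0`, the warmup branch never runs. The learning rate is halved by the midpoint of training and reaches zero after the last batch.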
{
"cell_type": "markdown",
"id": "423c56d5",
"metadata": {},
"source": [
"Let's initialize wandb for logging and start the training loop!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e46807",
"metadata": {},
"outputs": [],
"source": [
"wandb.init(\n",
"    project=\"bloom-personachat\",\n",
"    config={\n",
"        \"num_samples\": NUM_SAMPLES,\n",
"        \"batch_size\": BATCH_SIZE,\n",
"        \"learning_rate\": LR,\n",
"        \"weight_decay\": WEIGHT_DECAY,\n",
"        \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
"        \"model_name\": MODEL_NAME,\n",
"        \"seed\": SEED,\n",
"    }\n",
")\n",
"\n",
"model.train()\n",
"for batch in tqdm(train_dataloader):\n",
"    batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
"\n",
"    outputs = model(**batch)\n",
"    loss = outputs.loss\n",
"    loss.backward()\n",
"\n",
"    optimizer.step()\n",
"    lr_scheduler.step()\n",
"    optimizer.zero_grad()\n",
"\n",
"    # Log a plain Python float rather than a tensor\n",
"    wandb.log({\"Train Loss\": loss.item()})"
]
},
{
"cell_type": "markdown",
"id": "0f36cb80",
"metadata": {},
"source": [
"Try to talk with the trained model! Submit an empty input to stop the execution.\n",
"\n",
"\n",
"__Note__: In this example, we use the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new reply will be able to reuse inference caches from the previous reply."
]
},
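The chat loop rebuilds the same `\n-----\n`-separated format the model was tuned on. It then keeps only the first line of the continuation as the reply. The pure-Python sketch below isolates those two string steps. `add_turn` and `extract_reply` are hypothetical helper names. `generated_text` stands in for the decoded output of `model.generate` and is invented, and that output is assumed to start with the prompt, as the notebook's slicing implies.

```python
SEPARATOR = "\n-----\n"


def add_turn(dialog: str, phrase: str) -> str:
    """Append one utterance followed by the separator, as the chat loop does."""
    return dialog + phrase + SEPARATOR


def extract_reply(generated_text: str, prompt: str) -> str:
    """Drop the echoed prompt and keep the first line of the continuation."""
    return generated_text[len(prompt):].split("\n")[0]


dialog = add_turn("", "Hi! Do you have any pets?")
# Invented model output: the prompt followed by a continuation
generated_text = dialog + "Yes, I have two cats.\n-----\nThat is nice"
reply = extract_reply(generated_text, dialog)
print(reply)  # Yes, I have two cats.
dialog = add_turn(dialog, reply)
```

After each turn, `dialog` has the same layout as the history part of the training strings built by `chunking`, which joins utterances with the same separator.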
{
"cell_type": "code",
"execution_count": null,
"id": "720181b7",
"metadata": {},
"outputs": [],
"source": [
"MAX_TOKENS = 16\n",
"TOP_K = 100\n",
"TEMPERATURE = 0.6\n",
"dialog = \"\"\n",
"\n",
"while True:\n",
"    user_phrase = input()\n",
"    if len(user_phrase) == 0:\n",
"        break\n",
"    dialog += f\"{user_phrase}\\n-----\\n\"\n",
"    inputs = tokenizer([dialog], return_tensors='pt')['input_ids'].to(DEVICE)\n",
"    outputs = model.generate(\n",
"        inputs,\n",
"        temperature=TEMPERATURE,\n",
"        do_sample=True,\n",
"        top_k=TOP_K,\n",
"        eos_token_id=tokenizer.eos_token_id,\n",
"        max_new_tokens=MAX_TOKENS,\n",
"    )\n",
"    # The output starts with the prompt: cut it off and keep the first generated line\n",
"    bloom_answer = tokenizer.batch_decode(outputs)[0]\n",
"    bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n",
"    print(bloom_answer)\n",
"    dialog += f\"{bloom_answer}\\n-----\\n\""
]
}
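`do_sample=True`, `temperature=TEMPERATURE` and `top_k=TOP_K` control how each new token is picked. Below is a simplified pure-Python sketch of top-k sampling with temperature over a single logits vector. It divides by the temperature, keeps the `k` highest-scoring tokens, renormalizes them with a softmax, and samples one. This illustrates the technique only. It is not the `transformers` implementation, which works on batched tensors and supports further options.

```python
import math
import random
from typing import List, Optional


def sample_top_k(logits: List[float], top_k: int, temperature: float,
                 rng: Optional[random.Random] = None) -> int:
    """Return a token index sampled from the temperature-scaled top-k distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    # Indices of the k largest scaled logits
    candidates = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # Numerically stable softmax weights over the surviving candidates
    max_logit = max(scaled[i] for i in candidates)
    weights = [math.exp(scaled[i] - max_logit) for i in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]


logits = [2.0, 0.5, 1.0, -1.0]
rng = random.Random(0)
print(sample_top_k(logits, top_k=1, temperature=0.6, rng=rng))  # 0 (top-1 is greedy)
draws = [sample_top_k(logits, top_k=2, temperature=0.6, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))  # only indices 0 and 2 can appear
```

A temperature below 1, such as `TEMPERATURE = 0.6`, sharpens the distribution toward the most likely tokens. `TOP_K = 100` rules out the long tail of unlikely tokens at every step.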
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6.9 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}