From 7911c2641d930df4cd89a4d528bdb89a9a97639b Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Thu, 15 Dec 2022 09:28:25 +0400 Subject: [PATCH] Update advanced notebooks (#148) Update examples Co-authored-by: Alexander Borzunov --- examples/prompt-tuning-personachat.ipynb | 50 +++---- examples/prompt-tuning-sst2.ipynb | 180 ++++++++++++++++++++++- 2 files changed, 200 insertions(+), 30 deletions(-) diff --git a/examples/prompt-tuning-personachat.ipynb b/examples/prompt-tuning-personachat.ipynb index 6993a3b..868299b 100644 --- a/examples/prompt-tuning-personachat.ipynb +++ b/examples/prompt-tuning-personachat.ipynb @@ -36,8 +36,8 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install git+https://github.com/bigscience-workshop/petals\n", - "!pip install datasets wandb" + "!pip install -q git+https://github.com/bigscience-workshop/petals\n", + "!pip install -q datasets wandb" ] }, { @@ -269,35 +269,35 @@ "metadata": {}, "outputs": [], "source": [ - "MAX_TOKENS = 16\n", "TOP_K = 100\n", "TEMPERATURE = 0.6\n", - "dialog = \"\"\n", "\n", - "while True:\n", - " user_phrase = input()\n", - " if len(user_phrase) == 0:\n", - " break\n", - " dialog += f\"{user_phrase}\\n-----\\n\"\n", - " inputs = tokenizer([dialog], return_tensors='pt')['input_ids']\n", - " outputs = model.generate(\n", - " inputs,\n", - " temperature=TEMPERATURE,\n", - " do_sample=True,\n", - " top_k=TOP_K,\n", - " eos_token_id=tokenizer.eos_token_id,\n", - " max_new_tokens=MAX_TOKENS,\n", - " )\n", - " bloom_answer = tokenizer.batch_decode(outputs)[0]\n", - " bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n", - " print(bloom_answer)\n", - " dialog += f\"{bloom_answer}\\n-----\\n\"" + "with model.inference_session(max_length=512) as sess:\n", + " while True:\n", + " user_phrase = input()\n", + " if len(user_phrase) == 0:\n", + " break\n", + " inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids']\n", + " while True:\n", + " outputs = 
model.generate(\n", + "                inputs,\n", + "                temperature=TEMPERATURE,\n", + "                do_sample=True,\n", + "                top_k=TOP_K,\n", + "                max_new_tokens=1,\n", + "                session=sess,\n", + "            )\n", + "            bloom_answer_token = tokenizer.decode(outputs[0, -1:])\n", + "            print(bloom_answer_token, end=\"\", flush=True)\n", + "            if bloom_answer_token == \"\\n\":\n", + "                break\n", + "            inputs = None" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.12 ('bloom-demo')", + "display_name": "Python 3.8.9 64-bit", "language": "python", "name": "python3" }, @@ -311,11 +311,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.9" }, "vscode": { "interpreter": { - "hash": "175c31e15dd38a7dfc9eb4117a9e428ffb6063af97d545b6bfba4d874ecc4bb8" + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index 7a534d9..dce7766 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -36,8 +36,8 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install git+https://github.com/bigscience-workshop/petals\n", - "!pip install datasets wandb" + "!pip install -q git+https://github.com/bigscience-workshop/petals\n", + "!pip install -q datasets wandb" ] }, { @@ -52,6 +52,10 @@ "metadata": {}, "outputs": [], "source": [ "import torch\n", "import transformers\n", "import wandb\n", + "\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", "from datasets import load_dataset, load_metric\n", "from tqdm import tqdm\n", "from torch.optim import AdamW\n", @@ -276,11 +280,177 @@ "source": [ "Our model have been trained!" ] + }, + { + "cell_type": "markdown", + "id": "1bbf014f", + "metadata": {}, + "source": [ + "## Beyond soft-prompt tuning\n", + "\n", + "Let's try to tune the model using adapters in the middle of the model."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bea4391", + "metadata": {}, + "outputs": [], + "source": [ + "class BloomBasedClassifier(nn.Module):\n", + " def __init__(\n", + " self,\n", + " model,\n", + " intermediate_size: int = 32,\n", + " num_classes: int = 2,\n", + " adapter_layer_position: int = 6,\n", + " head_layer_position: int = 10\n", + " ):\n", + " super().__init__()\n", + " self.distributed_layers = model.transformer.h\n", + "\n", + " self.hidden_size = model.config.hidden_size\n", + " self.intermediate_size = intermediate_size\n", + " self.num_classes = num_classes\n", + " self.adapter_layer_position = adapter_layer_position\n", + " self.head_layer_position = head_layer_position\n", + " \n", + " self.adapter = nn.Sequential(\n", + " nn.Linear(self.hidden_size, self.intermediate_size),\n", + " nn.Linear(self.intermediate_size, self.hidden_size),\n", + " )\n", + " self.head = nn.Sequential(\n", + " nn.LayerNorm(self.hidden_size),\n", + " nn.Linear(self.hidden_size, self.num_classes),\n", + " )\n", + " \n", + " def forward(self, embeddings):\n", + " before_layers = self.distributed_layers[0:self.adapter_layer_position]\n", + " after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n", + " \n", + " hidden_states = before_layers(embeddings)\n", + " hidden_states = self.adapter(hidden_states)\n", + " hidden_states = after_layers(hidden_states)\n", + " pooled_states = torch.mean(hidden_states, dim=1)\n", + " return self.head(pooled_states)" + ] + }, + { + "cell_type": "markdown", + "id": "15299620", + "metadata": {}, + "source": [ + "Clear model and device memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa27b168", + "metadata": {}, + "outputs": [], + "source": [ + "del model, optimizer, lr_scheduler\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "id": "5406390f", + "metadata": {}, + "source": [ + "Create new model with adapters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a251db80", + "metadata": {}, + "outputs": [], + "source": [ + "INTERMEDIATE_SIZE = 32\n", + "ADAPTER_LAYER_POSITION = 6\n", + "HEAD_LAYER_POSITION = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3578df3a", + "metadata": {}, + "outputs": [], + "source": [ + "model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n", + "\n", + "cls_model = BloomBasedClassifier(\n", + "    model,\n", + "    intermediate_size=INTERMEDIATE_SIZE,\n", + "    adapter_layer_position=ADAPTER_LAYER_POSITION,\n", + "    head_layer_position=HEAD_LAYER_POSITION,\n", + ")\n", + "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", + "\n", + "lr_scheduler = get_scheduler(\n", + "    name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a40468b9", + "metadata": {}, + "source": [ + "And start training our new adapted model."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed051a5d", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.init(\n", + "    project=\"bloom_based_cls-sst-2\",\n", + "    config={\n", + "        \"num_epochs\": NUM_EPOCHS,\n", + "        \"batch_size\": BATCH_SIZE,\n", + "        \"learning_rate\": LR,\n", + "        \"weight_decay\": WEIGHT_DECAY,\n", + "        \"model_name\": MODEL_NAME,\n", + "        \"seed\": SEED,\n", + "        \"intermediate_size\": INTERMEDIATE_SIZE,\n", + "        \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n", + "        \"head_layer_position\": HEAD_LAYER_POSITION,\n", + "    }\n", + ")\n", + "\n", + "for epoch in range(NUM_EPOCHS):\n", + "    for batch in tqdm(train_dataloader):\n", + "        batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", + "\n", + "        cls_model.train()\n", + "        with torch.no_grad():\n", + "            embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n", + "        loss = F.cross_entropy(cls_model(embeddings_output), batch[\"labels\"])\n", + "        loss.backward()\n", + "\n", + "        cls_optimizer.step()\n", + "        lr_scheduler.step()\n", + "        cls_optimizer.zero_grad()\n", + "\n", + "        wandb.log({\"Train Loss\": loss})\n", + "\n", + "    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n", + "    wandb.log({\"Valid Accuracy\": accuracy}, commit=False)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.12 ('bloom-demo')", + "display_name": "Python 3.8.9 64-bit", "language": "python", "name": "python3" }, @@ -294,11 +464,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.9" }, "vscode": { "interpreter": { - "hash": "175c31e15dd38a7dfc9eb4117a9e428ffb6063af97d545b6bfba4d874ecc4bb8" + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } },