Update advanced notebooks (#148)

Update examples

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Artem Chumachenko committed by GitHub
parent 668b736031
commit 7911c2641d

@@ -36,8 +36,8 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/bigscience-workshop/petals\n",
"!pip install datasets wandb"
"!pip install -q git+https://github.com/bigscience-workshop/petals\n",
"!pip install -q datasets wandb"
]
},
{
@@ -269,35 +269,35 @@
"metadata": {},
"outputs": [],
"source": [
"MAX_TOKENS = 16\n",
"TOP_K = 100\n",
"TEMPERATURE = 0.6\n",
"dialog = \"\"\n",
"\n",
"while True:\n",
" user_phrase = input()\n",
" if len(user_phrase) == 0:\n",
" break\n",
" dialog += f\"{user_phrase}\\n-----\\n\"\n",
" inputs = tokenizer([dialog], return_tensors='pt')['input_ids']\n",
" outputs = model.generate(\n",
" inputs,\n",
" temperature=TEMPERATURE,\n",
" do_sample=True,\n",
" top_k=TOP_K,\n",
" eos_token_id=tokenizer.eos_token_id,\n",
" max_new_tokens=MAX_TOKENS,\n",
" )\n",
" bloom_answer = tokenizer.batch_decode(outputs)[0]\n",
" bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n",
" print(bloom_answer)\n",
" dialog += f\"{bloom_answer}\\n-----\\n\""
"with model.inference_session(max_length=512) as sess:\n",
" while True:\n",
" user_phrase = input()\n",
" if len(user_phrase) == 0:\n",
" break\n",
" inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids']\n",
" while True:\n",
" outputs = model.generate(\n",
" inputs,\n",
" temperature=TEMPERATURE,\n",
" do_sample=True,\n",
" top_k=TOP_K,\n",
" max_new_tokens=1,\n",
" session=sess,\n",
" )\n",
" bloom_answer_token = tokenizer.decode(outputs[0, -1:])\n",
" print(bloom_answer_token, end=\"\", flush=True)\n",
" if bloom_answer_token == \"\\n\":\n",
" break\n",
" inputs = None"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.12 ('bloom-demo')",
"display_name": "Python 3.8.9 64-bit",
"language": "python",
"name": "python3"
},
@ -311,11 +311,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.8.9"
},
"vscode": {
"interpreter": {
"hash": "175c31e15dd38a7dfc9eb4117a9e428ffb6063af97d545b6bfba4d874ecc4bb8"
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
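A note on the rewritten dialog cell: generating one token at a time inside model.inference_session() lets the notebook stream the reply as it is produced, while the servers reuse their attention caches instead of re-encoding the whole dialog on every call. A minimal sketch of the same pattern as a reusable helper, assuming a Petals model and tokenizer are already loaded (chat_turn and its defaults are hypothetical, not part of the notebook):

    def chat_turn(model, tokenizer, sess, user_phrase, max_reply_tokens=64):
        # encode only the new user phrase; earlier turns live in the session cache
        inputs = tokenizer([f"{user_phrase}\n-----\n"], return_tensors="pt")["input_ids"]
        reply = ""
        for _ in range(max_reply_tokens):
            outputs = model.generate(inputs, do_sample=True, top_k=100, temperature=0.6,
                                     max_new_tokens=1, session=sess)
            token = tokenizer.decode(outputs[0, -1:])
            if token == "\n":
                break
            reply += token
            inputs = None  # continue from the session cache on later steps
        return reply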

@@ -36,8 +36,8 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/bigscience-workshop/petals\n",
"!pip install datasets wandb"
"!pip install -q git+https://github.com/bigscience-workshop/petals\n",
"!pip install -q datasets wandb"
]
},
{
@@ -52,6 +52,10 @@
"import torch\n",
"import transformers\n",
"import wandb\n",
"\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"from datasets import load_dataset, load_metric\n",
"from tqdm import tqdm\n",
"from torch.optim import AdamW\n",
@@ -276,11 +280,177 @@
"source": [
"Our model have been trained!"
]
},
{
"cell_type": "markdown",
"id": "1bbf014f",
"metadata": {},
"source": [
"## Beyond soft-propmt tuning\n",
"\n",
"Let's try to tune model using adapters in the middle of the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3bea4391",
"metadata": {},
"outputs": [],
"source": [
"class BloomBasedClassifier(nn.Module):\n",
" def __init__(\n",
" self,\n",
" model,\n",
" intermediate_size: int = 32,\n",
" num_classes: int = 2,\n",
" adapter_layer_position: int = 6,\n",
" head_layer_position: int = 10\n",
" ):\n",
" super().__init__()\n",
" self.distributed_layers = model.transformer.h\n",
"\n",
" self.hidden_size = model.config.hidden_size\n",
" self.intermediate_size = intermediate_size\n",
" self.num_classes = num_classes\n",
" self.adapter_layer_position = adapter_layer_position\n",
" self.head_layer_position = head_layer_position\n",
" \n",
" self.adapter = nn.Sequential(\n",
" nn.Linear(self.hidden_size, self.intermediate_size),\n",
" nn.Linear(self.intermediate_size, self.hidden_size),\n",
" )\n",
" self.head = nn.Sequential(\n",
" nn.LayerNorm(self.hidden_size),\n",
" nn.Linear(self.hidden_size, self.num_classes),\n",
" )\n",
" \n",
" def forward(self, embeddings):\n",
" before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
" after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n",
" \n",
" hidden_states = before_layers(embeddings)\n",
" hidden_states = self.adapter(hidden_states)\n",
" hidden_states = after_layers(hidden_states)\n",
" pooled_states = torch.mean(hidden_states, dim=1)\n",
" return self.head(pooled_states)"
]
},
{
"cell_type": "markdown",
"id": "15299620",
"metadata": {},
"source": [
"Clear model and device memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa27b168",
"metadata": {},
"outputs": [],
"source": [
"del model, optimizer, lr_scheduler\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "markdown",
"id": "5406390f",
"metadata": {},
"source": [
"Create new model with adapters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a251db80",
"metadata": {},
"outputs": [],
"source": [
"INTERMEDIATE_SIZE = 32\n",
"ADAPTER_LAYER_POSITION = 6\n",
"HEAD_LAYER_POSITION = 10"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3578df3a",
"metadata": {},
"outputs": [],
"source": [
"model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n",
"\n",
"cls_model = BloomBasedClassifier(\n",
" model,\n",
" intermediate_size=INTERMEDIATE_SIZE,\n",
" adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
" head_layer_position=HEAD_LAYER_POSITION,\n",
")\n",
"cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
"\n",
"lr_scheduler = get_scheduler(\n",
" name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "a40468b9",
"metadata": {},
"source": [
"And start training our new adapted model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed051a5d",
"metadata": {},
"outputs": [],
"source": [
"wandb.init(\n",
" project=\"bloom_based_cls-sst-2\",\n",
" config={\n",
" \"num_epochs\": NUM_EPOCHS,\n",
" \"batch_size\": BATCH_SIZE,\n",
" \"learning_rate\": LR,\n",
" \"weight_decay\": WEIGHT_DECAY,\n",
" \"model_name\": MODEL_NAME,\n",
" \"seed\": SEED,\n",
" \"intermediate_size\": INTERMEDIATE_SIZE,\n",
" \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n",
" \"head_layer_position\": HEAD_LAYER_POSITION,\n",
" }\n",
")\n",
"\n",
"for epoch in range(NUM_EPOCHS):\n",
" for batch in tqdm(train_dataloader):\n",
" batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
"\n",
" cls_model.train()\n",
" with torch.no_grad():\n",
" embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
" outputs = cls_model(embeddings_output)\n",
" loss.backward()\n",
"\n",
" cls_optimizer.step()\n",
" lr_scheduler.step()\n",
" cls_optimizer.zero_grad()\n",
"\n",
" wandb.log({\"Train Loss\": loss})\n",
"\n",
" accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
" wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.12 ('bloom-demo')",
"display_name": "Python 3.8.9 64-bit",
"language": "python",
"name": "python3"
},
@@ -294,11 +464,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.8.9"
},
"vscode": {
"interpreter": {
"hash": "175c31e15dd38a7dfc9eb4117a9e428ffb6063af97d545b6bfba4d874ecc4bb8"
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
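Evaluation note: eval_metrics is defined in an earlier, unchanged cell of this notebook. For the adapter-based classifier, the validation pass needs to follow the same embeddings-then-adapter path as the training loop; a minimal sketch under that assumption (evaluate_cls is a hypothetical helper, not the notebook's eval_metrics):

    @torch.no_grad()
    def evaluate_cls(model, cls_model, dataloader, device):
        # accuracy of the adapter classifier on a labeled dataloader
        cls_model.eval()
        correct = total = 0
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            embeddings = model.transformer.word_embeddings(batch["input_ids"])
            logits = cls_model(embeddings)
            correct += (logits.argmax(dim=-1) == batch["labels"]).sum().item()
            total += batch["labels"].shape[0]
        return correct / total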
