|
|
|
@ -36,8 +36,8 @@
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"!pip install git+https://github.com/bigscience-workshop/petals\n",
|
|
|
|
|
"!pip install datasets wandb"
|
|
|
|
|
"!pip install -q git+https://github.com/bigscience-workshop/petals\n",
|
|
|
|
|
"!pip install -q datasets wandb"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
@ -52,6 +52,10 @@
|
|
|
|
|
"import torch\n",
|
|
|
|
|
"import transformers\n",
|
|
|
|
|
"import wandb\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"import torch.nn as nn\n",
|
|
|
|
|
"import torch.nn.functional as F\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"from datasets import load_dataset, load_metric\n",
|
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
|
"from torch.optim import AdamW\n",
|
|
|
|
@ -276,11 +280,177 @@
|
|
|
|
|
"source": [
|
|
|
|
|
"Our model have been trained!"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "1bbf014f",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Beyond soft-propmt tuning\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"Let's try to tune model using adapters in the middle of the model."
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "3bea4391",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"class BloomBasedClassifier(nn.Module):\n",
|
|
|
|
|
" def __init__(\n",
|
|
|
|
|
" self,\n",
|
|
|
|
|
" model,\n",
|
|
|
|
|
" intermediate_size: int = 32,\n",
|
|
|
|
|
" num_classes: int = 2,\n",
|
|
|
|
|
" adapter_layer_position: int = 6,\n",
|
|
|
|
|
" head_layer_position: int = 10\n",
|
|
|
|
|
" ):\n",
|
|
|
|
|
" super().__init__()\n",
|
|
|
|
|
" self.distributed_layers = model.transformer.h\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" self.hidden_size = model.config.hidden_size\n",
|
|
|
|
|
" self.intermediate_size = intermediate_size\n",
|
|
|
|
|
" self.num_classes = num_classes\n",
|
|
|
|
|
" self.adapter_layer_position = adapter_layer_position\n",
|
|
|
|
|
" self.head_layer_position = head_layer_position\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" self.adapter = nn.Sequential(\n",
|
|
|
|
|
" nn.Linear(self.hidden_size, self.intermediate_size),\n",
|
|
|
|
|
" nn.Linear(self.intermediate_size, self.hidden_size),\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" self.head = nn.Sequential(\n",
|
|
|
|
|
" nn.LayerNorm(self.hidden_size),\n",
|
|
|
|
|
" nn.Linear(self.hidden_size, self.num_classes),\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" def forward(self, embeddings):\n",
|
|
|
|
|
" before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
|
|
|
|
|
" after_layers = self.distributed_layers[self.adapter_layer_position:self.head_layer_position]\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" hidden_states = before_layers(embeddings)\n",
|
|
|
|
|
" hidden_states = self.adapter(hidden_states)\n",
|
|
|
|
|
" hidden_states = after_layers(hidden_states)\n",
|
|
|
|
|
" pooled_states = torch.mean(hidden_states, dim=1)\n",
|
|
|
|
|
" return self.head(pooled_states)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "15299620",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"Clear model and device memory."
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "aa27b168",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"del model, optimizer, lr_scheduler\n",
|
|
|
|
|
"torch.cuda.empty_cache()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "5406390f",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"Create new model with adapters."
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "a251db80",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"INTERMEDIATE_SIZE = 32\n",
|
|
|
|
|
"ADAPTER_LAYER_POSITION = 6\n",
|
|
|
|
|
"HEAD_LAYER_POSITION = 10"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "3578df3a",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"cls_model = BloomBasedClassifier(\n",
|
|
|
|
|
" model,\n",
|
|
|
|
|
" intermediate_size=INTERMEDIATE_SIZE,\n",
|
|
|
|
|
" adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
|
|
|
|
|
" head_layer_position=HEAD_LAYER_POSITION,\n",
|
|
|
|
|
")\n",
|
|
|
|
|
"cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"lr_scheduler = get_scheduler(\n",
|
|
|
|
|
" name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
|
|
|
|
|
")"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"id": "a40468b9",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"And start training our new adapted model."
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "ed051a5d",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"wandb.init(\n",
|
|
|
|
|
" project=\"bloom_based_cls-sst-2\",\n",
|
|
|
|
|
" config={\n",
|
|
|
|
|
" \"num_epochs\": NUM_EPOCHS,\n",
|
|
|
|
|
" \"batch_size\": BATCH_SIZE,\n",
|
|
|
|
|
" \"learning_rate\": LR,\n",
|
|
|
|
|
" \"weight_decay\": WEIGHT_DECAY,\n",
|
|
|
|
|
" \"model_name\": MODEL_NAME,\n",
|
|
|
|
|
" \"seed\": SEED,\n",
|
|
|
|
|
" \"intermediate_size\": INTERMEDIATE_SIZE,\n",
|
|
|
|
|
" \"adapter_layer_position\": ADAPTER_LAYER_POSITION,\n",
|
|
|
|
|
" \"head_layer_position\": HEAD_LAYER_POSITION,\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
")\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"for epoch in range(NUM_EPOCHS):\n",
|
|
|
|
|
" for batch in tqdm(train_dataloader):\n",
|
|
|
|
|
" batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" cls_model.train()\n",
|
|
|
|
|
" with torch.no_grad():\n",
|
|
|
|
|
" embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
|
|
|
|
|
" outputs = cls_model(embeddings_output)\n",
|
|
|
|
|
" loss.backward()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" cls_optimizer.step()\n",
|
|
|
|
|
" lr_scheduler.step()\n",
|
|
|
|
|
" cls_optimizer.zero_grad()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" wandb.log({\"Train Loss\": loss})\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
|
|
|
|
|
" wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "Python 3.8.12 ('bloom-demo')",
|
|
|
|
|
"display_name": "Python 3.8.9 64-bit",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
@ -294,11 +464,11 @@
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
|
"version": "3.8.12"
|
|
|
|
|
"version": "3.8.9"
|
|
|
|
|
},
|
|
|
|
|
"vscode": {
|
|
|
|
|
"interpreter": {
|
|
|
|
|
"hash": "175c31e15dd38a7dfc9eb4117a9e428ffb6063af97d545b6bfba4d874ecc4bb8"
|
|
|
|
|
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|