|
|
|
@ -308,6 +308,7 @@
|
|
|
|
|
" self.distributed_layers = model.transformer.h\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" self.hidden_size = model.config.hidden_size\n",
|
|
|
|
|
" self.dtype = model.config.torch_dtype\n",
|
|
|
|
|
" self.intermediate_size = intermediate_size\n",
|
|
|
|
|
" self.num_classes = num_classes\n",
|
|
|
|
|
" self.adapter_layer_position = adapter_layer_position\n",
|
|
|
|
@ -316,11 +317,11 @@
|
|
|
|
|
" self.adapter = nn.Sequential(\n",
|
|
|
|
|
" nn.Linear(self.hidden_size, self.intermediate_size),\n",
|
|
|
|
|
" nn.Linear(self.intermediate_size, self.hidden_size),\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" ).to(self.dtype)\n",
|
|
|
|
|
" self.head = nn.Sequential(\n",
|
|
|
|
|
" nn.LayerNorm(self.hidden_size),\n",
|
|
|
|
|
" nn.Linear(self.hidden_size, self.num_classes),\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" ).to(self.dtype)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" def forward(self, embeddings):\n",
|
|
|
|
|
" before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
|
|
|
|
@ -388,9 +389,10 @@
|
|
|
|
|
" head_layer_position=HEAD_LAYER_POSITION,\n",
|
|
|
|
|
")\n",
|
|
|
|
|
"cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
|
|
|
|
|
"cls_criterion = nn.CrossEntoryCriterion()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"lr_scheduler = get_scheduler(\n",
|
|
|
|
|
" name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
|
|
|
|
|
" name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
|
|
|
|
|
")"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
@ -432,6 +434,7 @@
|
|
|
|
|
" with torch.no_grad():\n",
|
|
|
|
|
" embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
|
|
|
|
|
" outputs = cls_model(embeddings_output)\n",
|
|
|
|
|
" loss = cls_criterion(outputs, batch[\"labels\"])\n",
|
|
|
|
|
" loss.backward()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" cls_optimizer.step()\n",
|
|
|
|
@ -461,7 +464,7 @@
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
|
"version": "3.8.9 (default, Apr 13 2022, 08:48:07) \n[Clang 13.1.6 (clang-1316.0.21.2.5)]"
|
|
|
|
|
"version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
|
|
|
|
|
},
|
|
|
|
|
"vscode": {
|
|
|
|
|
"interpreter": {
|
|
|
|
|