|
|
@ -288,7 +288,6 @@
|
|
|
|
]
|
|
|
|
]
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
{
|
|
|
|
"attachments": {},
|
|
|
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"id": "1bbf014f",
|
|
|
|
"id": "1bbf014f",
|
|
|
|
"metadata": {},
|
|
|
|
"metadata": {},
|
|
|
@ -324,6 +323,7 @@
|
|
|
|
" self.adapter_layer_position = adapter_layer_position\n",
|
|
|
|
" self.adapter_layer_position = adapter_layer_position\n",
|
|
|
|
" self.head_layer_position = head_layer_position\n",
|
|
|
|
" self.head_layer_position = head_layer_position\n",
|
|
|
|
" \n",
|
|
|
|
" \n",
|
|
|
|
|
|
|
|
" self.word_embeddings = model.transformer.word_embeddings\n",
|
|
|
|
" self.adapter = nn.Sequential(\n",
|
|
|
|
" self.adapter = nn.Sequential(\n",
|
|
|
|
" nn.Linear(self.hidden_size, self.intermediate_size),\n",
|
|
|
|
" nn.Linear(self.hidden_size, self.intermediate_size),\n",
|
|
|
|
" nn.Linear(self.intermediate_size, self.hidden_size),\n",
|
|
|
|
" nn.Linear(self.intermediate_size, self.hidden_size),\n",
|
|
|
@ -440,7 +440,7 @@
|
|
|
|
"\n",
|
|
|
|
"\n",
|
|
|
|
" cls_model.train()\n",
|
|
|
|
" cls_model.train()\n",
|
|
|
|
" with torch.no_grad():\n",
|
|
|
|
" with torch.no_grad():\n",
|
|
|
|
" embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
|
|
|
|
" embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n",
|
|
|
|
" outputs = cls_model(embeddings_output)\n",
|
|
|
|
" outputs = cls_model(embeddings_output)\n",
|
|
|
|
" loss = cls_criterion(outputs, batch[\"labels\"])\n",
|
|
|
|
" loss = cls_criterion(outputs, batch[\"labels\"])\n",
|
|
|
|
" loss.backward()\n",
|
|
|
|
" loss.backward()\n",
|
|
|
@ -458,7 +458,7 @@
|
|
|
|
],
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3.8.9 64-bit",
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"language": "python",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
},
|
|
|
@ -472,7 +472,7 @@
|
|
|
|
"name": "python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
"version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
|
|
|
|
"version": "3.8.8"
|
|
|
|
},
|
|
|
|
},
|
|
|
|
"vscode": {
|
|
|
|
"vscode": {
|
|
|
|
"interpreter": {
|
|
|
|
"interpreter": {
|
|
|
|