Fix examples/sst, add cls_model embeddings (#248)

pull/249/head
justheuristic 1 year ago committed by GitHub
parent 8766a14d28
commit b8a6788490
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -288,7 +288,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "1bbf014f",
"metadata": {},
@ -324,6 +323,7 @@
" self.adapter_layer_position = adapter_layer_position\n",
" self.head_layer_position = head_layer_position\n",
" \n",
" self.word_embeddings = model.transformer.word_embeddings\n",
" self.adapter = nn.Sequential(\n",
" nn.Linear(self.hidden_size, self.intermediate_size),\n",
" nn.Linear(self.intermediate_size, self.hidden_size),\n",
@ -440,7 +440,7 @@
"\n",
" cls_model.train()\n",
" with torch.no_grad():\n",
" embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
" embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n",
" outputs = cls_model(embeddings_output)\n",
" loss = cls_criterion(outputs, batch[\"labels\"])\n",
" loss.backward()\n",
@ -458,7 +458,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.9 64-bit",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -472,7 +472,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
"version": "3.8.8"
},
"vscode": {
"interpreter": {

Loading…
Cancel
Save