Fix examples/sst, add cls_model embeddings (#248)

pull/249/head
justheuristic 1 year ago committed by GitHub
parent 8766a14d28
commit b8a6788490
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -288,7 +288,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"id": "1bbf014f", "id": "1bbf014f",
"metadata": {}, "metadata": {},
@ -324,6 +323,7 @@
" self.adapter_layer_position = adapter_layer_position\n", " self.adapter_layer_position = adapter_layer_position\n",
" self.head_layer_position = head_layer_position\n", " self.head_layer_position = head_layer_position\n",
" \n", " \n",
" self.word_embeddings = model.transformer.word_embeddings\n",
" self.adapter = nn.Sequential(\n", " self.adapter = nn.Sequential(\n",
" nn.Linear(self.hidden_size, self.intermediate_size),\n", " nn.Linear(self.hidden_size, self.intermediate_size),\n",
" nn.Linear(self.intermediate_size, self.hidden_size),\n", " nn.Linear(self.intermediate_size, self.hidden_size),\n",
@ -440,7 +440,7 @@
"\n", "\n",
" cls_model.train()\n", " cls_model.train()\n",
" with torch.no_grad():\n", " with torch.no_grad():\n",
" embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n", " embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n",
" outputs = cls_model(embeddings_output)\n", " outputs = cls_model(embeddings_output)\n",
" loss = cls_criterion(outputs, batch[\"labels\"])\n", " loss = cls_criterion(outputs, batch[\"labels\"])\n",
" loss.backward()\n", " loss.backward()\n",
@ -458,7 +458,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3.8.9 64-bit", "display_name": "Python 3",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@ -472,7 +472,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]" "version": "3.8.8"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {

Loading…
Cancel
Save