diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb index 54840b1..c5dac6a 100644 --- a/examples/prompt-tuning-sst2.ipynb +++ b/examples/prompt-tuning-sst2.ipynb @@ -288,7 +288,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "1bbf014f", "metadata": {}, @@ -324,6 +323,7 @@ " self.adapter_layer_position = adapter_layer_position\n", " self.head_layer_position = head_layer_position\n", " \n", + " self.word_embeddings = model.transformer.word_embeddings\n", " self.adapter = nn.Sequential(\n", " nn.Linear(self.hidden_size, self.intermediate_size),\n", " nn.Linear(self.intermediate_size, self.hidden_size),\n", @@ -440,7 +440,7 @@ "\n", " cls_model.train()\n", " with torch.no_grad():\n", - " embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n", + " embeddings_output = cls_model.word_embeddings(batch[\"input_ids\"])\n", " outputs = cls_model(embeddings_output)\n", " loss = cls_criterion(outputs, batch[\"labels\"])\n", " loss.backward()\n", @@ -458,7 +458,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.9 64-bit", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -472,7 +472,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]" + "version": "3.8.8" }, "vscode": { "interpreter": {