diff --git a/examples/prompt-tuning-sst2.ipynb b/examples/prompt-tuning-sst2.ipynb
index d99a48d..5bcb0c9 100644
--- a/examples/prompt-tuning-sst2.ipynb
+++ b/examples/prompt-tuning-sst2.ipynb
@@ -308,6 +308,7 @@
     "        self.distributed_layers = model.transformer.h\n",
     "\n",
     "        self.hidden_size = model.config.hidden_size\n",
+    "        self.dtype = model.config.torch_dtype\n",
     "        self.intermediate_size = intermediate_size\n",
     "        self.num_classes = num_classes\n",
     "        self.adapter_layer_position = adapter_layer_position\n",
@@ -316,11 +317,11 @@
     "        self.adapter = nn.Sequential(\n",
     "            nn.Linear(self.hidden_size, self.intermediate_size),\n",
     "            nn.Linear(self.intermediate_size, self.hidden_size),\n",
-    "        )\n",
+    "        ).to(self.dtype)\n",
     "        self.head = nn.Sequential(\n",
     "            nn.LayerNorm(self.hidden_size),\n",
     "            nn.Linear(self.hidden_size, self.num_classes),\n",
-    "        )\n",
+    "        ).to(self.dtype)\n",
     "    \n",
     "    def forward(self, embeddings):\n",
     "        before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
@@ -388,9 +389,10 @@
     "    head_layer_position=HEAD_LAYER_POSITION,\n",
     ")\n",
     "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+    "cls_criterion = nn.CrossEntropyLoss()\n",
     "\n",
     "lr_scheduler = get_scheduler(\n",
-    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
+    "    name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
     ")"
    ]
   },
@@ -432,6 +434,7 @@
     "        with torch.no_grad():\n",
     "            embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
     "        outputs = cls_model(embeddings_output)\n",
+    "        loss = cls_criterion(outputs, batch[\"labels\"])\n",
     "        loss.backward()\n",
     "\n",
     "        cls_optimizer.step()\n",
@@ -461,7 +464,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.9 (default, Apr 13 2022, 08:48:07) \n[Clang 13.1.6 (clang-1316.0.21.2.5)]"
+   "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
   },
   "vscode": {
    "interpreter": {
diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py
index 5d22bfd..af8a20c 100644
--- a/src/petals/client/remote_model.py
+++ b/src/petals/client/remote_model.py
@@ -265,7 +265,7 @@ class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequ
         self.num_labels = config.num_labels
 
         self.transformer = DistributedBloomModel(config)
-        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
 
         # Initialize weights and apply final processing
         self.post_init()
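
For context, a minimal, self-contained sketch (not part of the diff; the sizes, dtype, and variable names below are made up for illustration) of the two failure modes these changes address: trainable modules created in the default float32 cannot consume the distributed model's bfloat16 hidden states until they are cast with `.to(torch_dtype)`, and the training loop previously called `loss.backward()` on a `loss` that was never computed.

```python
import torch
import torch.nn as nn

torch_dtype = torch.bfloat16               # stand-in for model.config.torch_dtype
hidden_size, num_classes, batch = 16, 2, 4

# Hidden states arrive from the remote layers in the model's native dtype.
hidden_states = torch.randn(batch, hidden_size, dtype=torch_dtype)

head = nn.Sequential(                      # created in float32 by default
    nn.Linear(hidden_size, hidden_size),
    nn.Linear(hidden_size, num_classes),
)
try:
    head(hidden_states)                    # float32 weights vs. bfloat16 inputs
except RuntimeError as err:
    print(f"without the cast: {err}")

head = head.to(torch_dtype)                # the cast the diff applies
logits = head(hidden_states)

# The notebook change also computes the loss explicitly before backward();
# `loss` was previously undefined in the training loop.
criterion = nn.CrossEntropyLoss()
labels = torch.randint(0, num_classes, (batch,))
loss = criterion(logits.float(), labels)
loss.backward()
```

The same reasoning applies to the `score` head in `DistributedBloomForSequenceClassification`: casting it once at construction time keeps every locally trained module in the dtype the distributed backbone produces.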