Fix dtype error in fine-tuning notebooks (#231)
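The root cause: when the backbone model is loaded in half precision (its config's `torch_dtype` is `torch.float16` or `torch.bfloat16`), the freshly constructed adapter and head layers default to `torch.float32`, so the first matmul between half-precision hidden states and float32 weights fails with a dtype error. The commit stores the model's `torch_dtype` and casts the new modules to it. A minimal sketch of the failure mode and the fix, independent of the notebooks (shapes and dtype are illustrative):

```python
import torch
import torch.nn as nn

hidden_states = torch.randn(1, 4, 8, dtype=torch.bfloat16)  # outputs of a bf16 backbone
adapter = nn.Linear(8, 8)                                   # new layers default to float32

try:
    adapter(hidden_states)
except RuntimeError as err:
    print(err)  # dtype mismatch between input and weight

adapter = adapter.to(torch.bfloat16)  # cast to the backbone's dtype, as the commit does
print(adapter(hidden_states).dtype)   # torch.bfloat16
```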

Artem Chumachenko 1 year ago committed by GitHub
parent 0ebf6de117
commit d4c687daca

@@ -308,6 +308,7 @@
 "        self.distributed_layers = model.transformer.h\n",
 "\n",
 "        self.hidden_size = model.config.hidden_size\n",
+"        self.dtype = model.config.torch_dtype\n",
 "        self.intermediate_size = intermediate_size\n",
 "        self.num_classes = num_classes\n",
 "        self.adapter_layer_position = adapter_layer_position\n",
@@ -316,11 +317,11 @@
 "        self.adapter = nn.Sequential(\n",
 "            nn.Linear(self.hidden_size, self.intermediate_size),\n",
 "            nn.Linear(self.intermediate_size, self.hidden_size),\n",
-"        )\n",
+"        ).to(self.dtype)\n",
 "        self.head = nn.Sequential(\n",
 "            nn.LayerNorm(self.hidden_size),\n",
 "            nn.Linear(self.hidden_size, self.num_classes),\n",
-"        )\n",
+"        ).to(self.dtype)\n",
 "        \n",
 "    def forward(self, embeddings):\n",
 "        before_layers = self.distributed_layers[0:self.adapter_layer_position]\n",
@@ -388,9 +389,10 @@
 "    head_layer_position=HEAD_LAYER_POSITION,\n",
 ")\n",
 "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+"cls_criterion = nn.CrossEntropyLoss()\n",
 "\n",
 "lr_scheduler = get_scheduler(\n",
-"    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
+"    name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
 ")"
 ]
 },
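Besides adding the missing loss function, this hunk points the scheduler at the classifier's own optimizer (`cls_optimizer`) instead of an undefined `optimizer`. A self-contained sketch of the corrected setup (the classifier, learning rate, and step count are placeholders):

```python
import torch.nn as nn
from torch.optim import AdamW
from transformers import get_scheduler

LR, WEIGHT_DECAY, NUM_TRAINING_STEPS = 1e-4, 0.0, 100  # placeholder hyperparameters
cls_model = nn.Linear(8, 2)                            # stand-in for the adapter-based classifier

cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
cls_criterion = nn.CrossEntropyLoss()
lr_scheduler = get_scheduler(
    name="linear", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=NUM_TRAINING_STEPS
)
```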
@@ -432,6 +434,7 @@
 "    with torch.no_grad():\n",
 "        embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
 "    outputs = cls_model(embeddings_output)\n",
+"    loss = cls_criterion(outputs, batch[\"labels\"])\n",
 "    loss.backward()\n",
 "\n",
 "    cls_optimizer.step()\n",
@@ -461,7 +464,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.8.9 (default, Apr 13 2022, 08:48:07) \n[Clang 13.1.6 (clang-1316.0.21.2.5)]"
+"version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
 },
 "vscode": {
 "interpreter": {

@@ -265,7 +265,7 @@ class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
         self.num_labels = config.num_labels
         self.transformer = DistributedBloomModel(config)
-        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
         # Initialize weights and apply final processing
         self.post_init()
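The same cast appears on the library side: `DistributedBloomForSequenceClassification` builds its `score` head locally in float32, so it is moved to the dtype the backbone was loaded with. A minimal illustration with a toy config object (the config class here is hypothetical):

```python
import torch
import torch.nn as nn

class ToyConfig:  # hypothetical stand-in for the transformers config above
    hidden_size = 8
    num_labels = 2
    torch_dtype = torch.bfloat16

config = ToyConfig()
score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
print(score.weight.dtype)  # torch.bfloat16
```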
