Minor changes to examples/prompt-tuning notebooks (#247)

Minor code changes required to run the notebooks in a clean Python environment.
pull/248/head
justheuristic 1 year ago committed by GitHub
parent 5367523df8
commit 8766a14d28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -36,7 +36,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install -q petals datasets wandb"
+"%pip install -q petals datasets wandb scikit-learn"
 ]
 },
 {
@@ -285,7 +285,7 @@
 " user_phrase = input()\n",
 " if len(user_phrase) == 0:\n",
 " break\n",
-" inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids']\n",
+" inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids'].to(DEVICE)\n",
 " while True:\n",
 " outputs = model.generate(\n",
 " inputs,\n",

@@ -36,7 +36,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install -q petals datasets wandb"
+"%pip install -q petals datasets wandb scikit-learn"
 ]
 },
 {
@@ -390,16 +390,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n",
-"\n",
 "cls_model = BloomBasedClassifier(\n",
-" model,\n",
+" DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),\n",
 " intermediate_size=INTERMEDIATE_SIZE,\n",
 " adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
 " head_layer_position=HEAD_LAYER_POSITION,\n",
-")\n",
+").to(DEVICE)\n",
 "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
-"cls_criterion = nn.CrossEntropyCriterion()\n",
+"cls_criterion = nn.CrossEntropyLoss()\n",
 "\n",
 "lr_scheduler = get_scheduler(\n",
 " name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
@@ -442,7 +440,7 @@
 "\n",
 " cls_model.train()\n",
 " with torch.no_grad():\n",
-" embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
+" embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
 " outputs = cls_model(embeddings_output)\n",
 " loss = cls_criterion(outputs, batch[\"labels\"])\n",
 " loss.backward()\n",
@@ -453,7 +451,7 @@
 "\n",
 " wandb.log({\"Train Loss\": loss})\n",
 "\n",
-" accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
+" accuracy = eval_metrics(cls_model, valid_dataloader, device=DEVICE)\n",
 " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
 ]
 }

Loading…
Cancel
Save