Minor changes to examples/prompt-tuning notebooks (#247)

Minor code changes required to run the notebooks in a clean Python environment.
justheuristic 2023-02-01 14:10:45 +03:00 committed by GitHub
parent 5367523df8
commit 8766a14d28
2 changed files with 8 additions and 10 deletions

Changed file 1 of 2:

@@ -36,7 +36,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install -q petals datasets wandb"
+"%pip install -q petals datasets wandb scikit-learn"
 ]
 },
 {
@@ -285,7 +285,7 @@
 " user_phrase = input()\n",
 " if len(user_phrase) == 0:\n",
 " break\n",
-" inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids']\n",
+" inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids'].to(DEVICE)\n",
 " while True:\n",
 " outputs = model.generate(\n",
 " inputs,\n",

Changed file 2 of 2:

@@ -36,7 +36,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install -q petals datasets wandb"
+"%pip install -q petals datasets wandb scikit-learn"
 ]
 },
 {
@@ -390,16 +390,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n",
-"\n",
 "cls_model = BloomBasedClassifier(\n",
-" model,\n",
+" DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),\n",
 " intermediate_size=INTERMEDIATE_SIZE,\n",
 " adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
 " head_layer_position=HEAD_LAYER_POSITION,\n",
-")\n",
+").to(DEVICE)\n",
 "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
-"cls_criterion = nn.CrossEntropyCriterion()\n",
+"cls_criterion = nn.CrossEntropyLoss()\n",
 "\n",
 "lr_scheduler = get_scheduler(\n",
 " name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
@@ -442,7 +440,7 @@
 "\n",
 " cls_model.train()\n",
 " with torch.no_grad():\n",
-" embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
+" embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
 " outputs = cls_model(embeddings_output)\n",
 " loss = cls_criterion(outputs, batch[\"labels\"])\n",
 " loss.backward()\n",
@@ -453,7 +451,7 @@
 "\n",
 " wandb.log({\"Train Loss\": loss})\n",
 "\n",
-" accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
+" accuracy = eval_metrics(cls_model, valid_dataloader, device=DEVICE)\n",
 " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
 ]
 }
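
The cls_criterion change swaps out nn.CrossEntropyCriterion(), which does not exist in PyTorch, for nn.CrossEntropyLoss(). A self-contained sketch of how that criterion is used; shapes and values below are illustrative and not taken from the notebook:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# CrossEntropyLoss expects raw logits of shape (batch, num_classes)
# and integer class labels of shape (batch,).
logits = torch.randn(4, 2, requires_grad=True)
labels = torch.tensor([0, 1, 1, 0])

loss = criterion(logits, labels)
loss.backward()
print(loss.item())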