Mirror of https://github.com/bigscience-workshop/petals, synced 2024-10-31 09:20:41 +00:00
Minor changes to examples/prompt-tuning notebooks (#247)
Minor code changes required to run the notebooks in a clean Python environment
This commit is contained in:
parent 5367523df8
commit 8766a14d28
First notebook:

@@ -36,7 +36,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install -q petals datasets wandb"
+"%pip install -q petals datasets wandb scikit-learn"
 ]
 },
 {
@@ -285,7 +285,7 @@
 " user_phrase = input()\n",
 " if len(user_phrase) == 0:\n",
 " break\n",
-" inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids']\n",
+" inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids'].to(DEVICE)\n",
 " while True:\n",
 " outputs = model.generate(\n",
 " inputs,\n",
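The .to(DEVICE) change above puts the tokenized prompt on the same device as the model's local layers before calling generate(). A minimal sketch of the pattern, not the notebook's exact code; it assumes Petals' DistributedBloomForCausalLM (only the sequence-classification class appears in this diff), and the model name and prompt are placeholders:

import torch
from transformers import BloomTokenizerFast
from petals import DistributedBloomForCausalLM

MODEL_NAME = "bigscience/bloom-petals"   # placeholder; the notebook defines its own MODEL_NAME
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)

user_phrase = "Hello!"                   # placeholder for the notebook's input() call
# Tokenize on CPU, then move the input ids to the model's device before generation.
inputs = tokenizer([f"{user_phrase}\n-----\n"], return_tensors="pt")["input_ids"].to(DEVICE)
outputs = model.generate(inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))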
Second notebook:

@@ -36,7 +36,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install -q petals datasets wandb"
+"%pip install -q petals datasets wandb scikit-learn"
 ]
 },
 {
@@ -390,16 +390,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)\n",
-"\n",
 "cls_model = BloomBasedClassifier(\n",
-" model,\n",
+" DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME),\n",
 " intermediate_size=INTERMEDIATE_SIZE,\n",
 " adapter_layer_position=ADAPTER_LAYER_POSITION,\n",
 " head_layer_position=HEAD_LAYER_POSITION,\n",
-")\n",
+").to(DEVICE)\n",
 "cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
-"cls_criterion = nn.CrossEntropyCriterion()\n",
+"cls_criterion = nn.CrossEntropyLoss()\n",
 "\n",
 "lr_scheduler = get_scheduler(\n",
 " name=\"linear\", optimizer=cls_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
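The restructured cell above passes the backbone directly into the classifier wrapper, moves the whole module to DEVICE in one call, and swaps the non-existent nn.CrossEntropyCriterion for nn.CrossEntropyLoss. A self-contained sketch of that construction order, with a hypothetical TinyClassifier standing in for the notebook's BloomBasedClassifier and illustrative hyperparameters:

import torch
import torch.nn as nn
from torch.optim import AdamW

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LR, WEIGHT_DECAY = 1e-2, 0.0             # illustrative values, not the notebook's

class TinyClassifier(nn.Module):
    # Stand-in for BloomBasedClassifier: a backbone plus a small trainable head.
    def __init__(self, backbone: nn.Module, hidden_size: int, num_classes: int = 2):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(embeddings)       # (batch, seq, hidden)
        return self.head(hidden.mean(dim=1))     # pool over the sequence dimension

# Build the wrapper first, then move backbone and head together with a single .to(DEVICE).
cls_model = TinyClassifier(nn.Identity(), hidden_size=64).to(DEVICE)
cls_optimizer = AdamW(cls_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
cls_criterion = nn.CrossEntropyLoss()            # PyTorch has no nn.CrossEntropyCriterion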
@@ -442,7 +440,7 @@
 "\n",
 " cls_model.train()\n",
 " with torch.no_grad():\n",
-" embeddings_output = model.transformers.word_embeddings(batch[\"input_ids\"])\n",
+" embeddings_output = model.transformer.word_embeddings(batch[\"input_ids\"])\n",
 " outputs = cls_model(embeddings_output)\n",
 " loss = cls_criterion(outputs, batch[\"labels\"])\n",
 " loss.backward()\n",
@@ -453,7 +451,7 @@
 "\n",
 " wandb.log({\"Train Loss\": loss})\n",
 "\n",
-" accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
+" accuracy = eval_metrics(cls_model, valid_dataloader, device=DEVICE)\n",
 " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
 ]
 }
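The last two hunks correct the attribute name (model.transformer rather than model.transformers) when looking up frozen word embeddings, and evaluate the trained cls_model instead of the backbone. A rough, self-contained sketch of the training step that cell performs, with dummy modules in place of the Petals backbone and BloomBasedClassifier:

import torch
import torch.nn as nn

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dummy stand-ins so the step runs outside the notebook.
word_embeddings = nn.Embedding(1000, 64).to(DEVICE)   # plays the role of model.transformer.word_embeddings
cls_model = nn.Sequential(nn.Flatten(), nn.Linear(8 * 64, 2)).to(DEVICE)
cls_criterion = nn.CrossEntropyLoss()
batch = {
    "input_ids": torch.randint(0, 1000, (4, 8), device=DEVICE),
    "labels": torch.randint(0, 2, (4,), device=DEVICE),
}

cls_model.train()
with torch.no_grad():
    # The embedding lookup stays frozen; only the classifier receives gradients.
    embeddings_output = word_embeddings(batch["input_ids"])
outputs = cls_model(embeddings_output)
loss = cls_criterion(outputs, batch["labels"])
loss.backward()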