|
|
|
@ -55,7 +55,6 @@
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"import os\n",
|
|
|
|
|
"import sys\n",
|
|
|
|
|
" \n",
|
|
|
|
|
"import torch\n",
|
|
|
|
|
"import transformers\n",
|
|
|
|
@ -64,7 +63,7 @@
|
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
|
"from torch.optim import AdamW\n",
|
|
|
|
|
"from torch.utils.data import DataLoader\n",
|
|
|
|
|
"from transformers import get_scheduler\n",
|
|
|
|
|
"from transformers import BloomTokenizerFast, get_scheduler\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Import a Petals model\n",
|
|
|
|
|
"from petals.client.remote_model import DistributedBloomForSequenceClassification"
|
|
|
|
@ -114,7 +113,7 @@
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
|
|
|
|
|
"tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
|
|
|
|
|
"tokenizer.padding_side = 'right'\n",
|
|
|
|
|
"tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
|
|
|
|
|
"model = DistributedBloomForSequenceClassification.from_pretrained(\n",
|
|
|
|
|