From 3218534745397dac42823a57bac0bb573e6cacf4 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 19 Jul 2023 15:25:34 +0400 Subject: [PATCH] Fix --token arg (#378) --- src/petals/cli/run_server.py | 6 +++++- src/petals/server/from_pretrained.py | 6 +++--- src/petals/server/server.py | 4 ++-- src/petals/utils/peft.py | 6 +++--- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index c7264b4..8820dd2 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -25,6 +25,11 @@ def main(): help="path or name of a pretrained model, converted with cli/convert_model.py") group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path") + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()") + group.add_argument("--use_auth_token", action="store_true", dest="token", + help="Read token saved by `huggingface-cli login") + parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve") parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve") parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix") @@ -132,7 +137,6 @@ def main(): parser.add_argument("--mean_balance_check_period", type=float, default=60, help="Check the swarm's balance every N seconds (and rebalance it if necessary)") - parser.add_argument("--token", action='store_true', help="Hugging Face hub auth token for .from_pretrained()") parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType], help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or " "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. " diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 9898759..950746e 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -34,7 +34,7 @@ def load_pretrained_block( config: Optional[PretrainedConfig] = None, torch_dtype: Union[torch.dtype, str] = "auto", revision: Optional[str] = None, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, ) -> nn.Module: @@ -82,7 +82,7 @@ def _load_state_dict_from_repo( block_prefix: str, *, revision: Optional[str] = None, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, cache_dir: str, max_disk_space: Optional[int] = None, ) -> StateDict: @@ -125,7 +125,7 @@ def _load_state_dict_from_file( filename: str, *, revision: Optional[str] = None, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, cache_dir: str, max_disk_space: Optional[int] = None, delay: float = 30, diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 72db9ce..ccc5292 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -77,7 +77,7 @@ class Server: balance_quality: float = 0.75, mean_balance_check_period: float = 120, mean_block_selection_delay: float = 2.5, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, quant_type: Optional[QuantType] = None, tensor_parallel_devices: Optional[Sequence[torch.device]] = None, skip_reachability_check: bool = False, @@ -409,7 +409,7 @@ class ModuleContainer(threading.Thread): update_period: float, expiration: Optional[float], revision: Optional[str], - token: Optional[str], + token: Optional[Union[str, bool]], quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], should_validate_reachability: bool, diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index 23661ae..de48cd2 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -1,7 +1,7 @@ import contextlib import re import time -from typing import Optional, Sequence +from typing import Optional, Sequence, Union import bitsandbytes as bnb import torch @@ -50,7 +50,7 @@ def get_adapter_from_repo( block_idx: Optional[int] = None, device: Optional[int] = None, *, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, **kwargs, ): config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs) @@ -72,7 +72,7 @@ def load_peft( device: Optional[int] = None, *, revision: Optional[str] = None, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, cache_dir: str, max_disk_space: Optional[int] = None, delay: float = 30,