Fix --token arg (#378)

Branch: pull/380/head
Authored by Alexander Borzunov 10 months ago, committed via GitHub
Parent: 398a384075
Commit: 3218534745

@@ -25,6 +25,11 @@ def main():
                        help="path or name of a pretrained model, converted with cli/convert_model.py")
     group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
 
+    group = parser.add_mutually_exclusive_group(required=False)
+    group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
+    group.add_argument("--use_auth_token", action="store_true", dest="token",
+                       help="Read token saved by `huggingface-cli login`")
+
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
     parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix")
@@ -132,7 +137,6 @@ def main():
     parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
-    parser.add_argument("--token", action='store_true', help="Hugging Face hub auth token for .from_pretrained()")
     parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
                         help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
                              "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "

@@ -34,7 +34,7 @@ def load_pretrained_block(
     config: Optional[PretrainedConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     revision: Optional[str] = None,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     cache_dir: Optional[str] = None,
     max_disk_space: Optional[int] = None,
 ) -> nn.Module:
@@ -82,7 +82,7 @@ def _load_state_dict_from_repo(
     block_prefix: str,
     *,
     revision: Optional[str] = None,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
 ) -> StateDict:
@@ -125,7 +125,7 @@ def _load_state_dict_from_file(
     filename: str,
     *,
     revision: Optional[str] = None,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
     delay: float = 30,
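
These three hunks (the block-loading helpers) widen `token` from `Optional[str]` to `Optional[Union[str, bool]]`, matching Hugging Face Hub semantics: a string is an explicit token, `True` means "use the token saved by `huggingface-cli login`", and `None` (or `False`) means anonymous access. A hedged sketch of what the widened annotation permits downstream; `fetch_block_shard` is a made-up stand-in, while `hf_hub_download` is real huggingface_hub API:

    from typing import Optional, Union

    from huggingface_hub import hf_hub_download

    def fetch_block_shard(repo_id: str, filename: str,
                          token: Optional[Union[str, bool]] = None) -> str:
        # token=None -> anonymous; token=True -> token saved by `huggingface-cli login`;
        # token="hf_..." -> that explicit token.
        return hf_hub_download(repo_id=repo_id, filename=filename, token=token)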

@@ -77,7 +77,7 @@ class Server:
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 120,
         mean_block_selection_delay: float = 2.5,
-        token: Optional[str] = None,
+        token: Optional[Union[str, bool]] = None,
         quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
@@ -409,7 +409,7 @@ class ModuleContainer(threading.Thread):
         update_period: float,
         expiration: Optional[float],
         revision: Optional[str],
-        token: Optional[str],
+        token: Optional[Union[str, bool]],
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
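
The same widening is threaded through `Server.__init__` and `ModuleContainer.create`, so the CLI value (now `True` whenever `--use_auth_token` is set) can flow through unchanged. Under a static type checker the old annotation would reject exactly that value; a small illustration with simplified stand-in functions (not the Petals classes):

    from typing import Optional, Union

    def old_loader(token: Optional[str] = None): ...               # pre-fix annotation
    def new_loader(token: Optional[Union[str, bool]] = None): ...  # post-fix annotation

    token_from_cli: Union[str, bool, None] = True  # what --use_auth_token produces
    old_loader(token=token_from_cli)  # mypy: incompatible type "str | bool | None"
    new_loader(token=token_from_cli)  # OK under the widened annotation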

@@ -1,7 +1,7 @@
 import contextlib
 import re
 import time
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union
 
 import bitsandbytes as bnb
 import torch
@@ -50,7 +50,7 @@ def get_adapter_from_repo(
     block_idx: Optional[int] = None,
     device: Optional[int] = None,
     *,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     **kwargs,
 ):
     config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
@@ -72,7 +72,7 @@ def load_peft(
     device: Optional[int] = None,
     *,
     revision: Optional[str] = None,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
     delay: float = 30,
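
In `get_adapter_from_repo`, the value ends up in transformers' `get_file_from_repo(..., use_auth_token=token)`, which accepts either form. A hedged usage sketch; the repo id is invented, and `CONFIG_NAME` is assumed to be PEFT's `adapter_config.json`:

    from transformers.utils import get_file_from_repo

    CONFIG_NAME = "adapter_config.json"  # assumption: PEFT's adapter config filename

    # Explicit token string:
    cfg = get_file_from_repo("someuser/private-adapter", CONFIG_NAME, use_auth_token="hf_xxx")
    # Reuse the token saved by `huggingface-cli login`:
    cfg = get_file_from_repo("someuser/private-adapter", CONFIG_NAME, use_auth_token=True)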
