Mirror of https://github.com/bigscience-workshop/petals (synced 2024-10-31 09:20:41 +00:00)
Commit 3ce3fc296d (parent 67f96d49cf): Move cli => src/petals/cli

This commit relocates the command-line tools; the diffs below show the files removed from their old cli/ location.
@@ -1,20 +0,0 @@
{
  "apply_residual_connection_post_layernorm": false,
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_embed": 14336,
  "n_layer": 70,
  "num_attention_heads": 112,
  "pretraining_tp": 4,
  "slow_but_exact": false,
  "transformers_version": "4.20.0.dev0",
  "use_cache": true,
  "vocab_size": 250880
}
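These values match the full 176B BLOOM model (70 layers, hidden size 14336, 112 heads). For a quick sanity check, the file can be inspected with Python's json module; the config.json filename here is only an assumption about where the file is saved:

    python -c "import json; print(json.load(open('config.json'))['n_layer'])"  # prints 70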
@@ -1,93 +0,0 @@
import argparse
import os

import psutil
import torch
import torch.backends.quantized
import torch.nn as nn
import transformers
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from huggingface_hub import Repository
from tqdm.auto import tqdm

from petals import BloomModel
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
from petals.client import DistributedBloomConfig

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
    parser.add_argument(
        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
    )
    parser.add_argument(
        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
    )
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
    args = parser.parse_args()

    free_ram_gb = psutil.virtual_memory().available / 2**30
    if args.model == "bigscience/bloom" and free_ram_gb < 400:
        logger.warning(f"ACHTUNG! Converting bloom-176b will use 350-400 GB RAM; you have {free_ram_gb:.3f} GB free")

    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
    if os.path.exists(args.output_path) and (
        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
    ):
        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
    config = DistributedBloomConfig.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    config.dht_prefix = args.output_repo

    model = BloomModel.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    if args.resize_token_embeddings:
        logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
        model.resize_token_embeddings(args.resize_token_embeddings)
        config.vocab_size = args.resize_token_embeddings

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    os.makedirs(args.output_path, exist_ok=True)

    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
    repo.git_pull()

    transformer_blocks = model.h
    logger.info(
        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
    )
    # Each block is committed to its own branch (block_branch_prefix + index),
    # so servers can later download only the layers they serve
    for i, block in enumerate(tqdm(transformer_blocks)):
        repo.git_checkout(args.client_branch, create_branch_ok=True)
        with repo.commit(
            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
        ):
            torch.save(block.state_dict(), "./pytorch_model.bin")

    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
    repo.git_checkout(args.client_branch, create_branch_ok=True)
    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
        # Strip the transformer blocks before saving: the client branch only needs embeddings and config
        model.h = nn.ModuleList()
        model.save_pretrained(".")
        tokenizer.save_pretrained(".")
        config.save_pretrained(".")

    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
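For reference, a hypothetical invocation of this converter; the model and repo names are the script's own defaults, and the auth token is a placeholder (pushing requires write access to the output repo):

    python cli/convert_model.py --model bigscience/bloom-6b3 --torch_dtype bfloat16 \
        --output_path ./converted_model --output_repo bigscience/test-bloomd \
        --use_auth_token hf_XXXXXXXX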
@@ -1,79 +0,0 @@
#!/usr/bin/env bash

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-m] [-i] [-d] [-p] [-b] [-a] [-t]" >&2
  echo " -m: model name" >&2
  echo " -i: initial peer" >&2
  echo " -d: device" >&2
  echo " -p: server identity path" >&2
  echo " -b: block_ids" >&2
  echo " -a: host maddrs" >&2
  echo " -t: whether to run local tests" >&2
  exit 1
}

if [ $# -lt 8 ]; then
  instructions
fi

while getopts ":m:i:d:p:b:a:t:" option; do
  case $option in
    m) MODEL_NAME=${OPTARG} ;;
    i) INITIAL_PEER=${OPTARG} ;;
    d) DEVICE=${OPTARG} ;;
    p) SERVER_ID_PATH=${OPTARG} ;;
    b) BLOCK_IDS=${OPTARG} ;;
    a) HOST_MADDR=${OPTARG} ;;  # TODO: allow several maddrs
    t) RUN_LOCAL_TESTS=true ;;
    \?) instructions ;;
  esac
done


echo "=========="
echo "= Config ="
echo "=========="
echo "Model name: ${MODEL_NAME}"
echo "Initial peer: ${INITIAL_PEER}"
echo "Device: ${DEVICE}"
echo "Server name: ${SERVER_ID_PATH}"
echo "Server address: ${HOST_MADDR}"
echo "Bloom blocks: ${BLOCK_IDS}"


###########################
# Install or activate env #
###########################

# TODO fix bug with self calling
source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple .
  pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi

##############
# Run server #
##############

python -m cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
  --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log
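A hypothetical launch; the initial peer's multiaddr (including its p2p ID) and the identity path are placeholders:

    bash cli/deploy_server.sh -m bigscience/test-bloomd -i /ip4/10.0.0.1/tcp/30000/p2p/QmPeerID \
        -d cpu -p ./server.id -b 0:8 -a /ip4/0.0.0.0/tcp/31337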
@@ -1,53 +0,0 @@
import argparse

import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tqdm.auto import trange

from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig
from petals.bloom.ops import build_alibi_tensor

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

logger.warning("inference_one_block will soon be deprecated in favour of tests!")


def print_device_info(device=None):
    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    logger.info(f"Using device: {device}")

    # Additional info when using cuda
    if device.type == "cuda":
        logger.info(torch.cuda.get_device_name(0))
        logger.info("Memory Usage:")
        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
        logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
    parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
    parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
    parser.add_argument("--layer_index", default=0, type=int, help="Index of the transformer block to run")
    parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
    parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
    args = parser.parse_args()

    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    config = BloomConfig.from_json_file(args.config)
    block = BloomBlock(config, args.layer_index).to(args.device)

    cache = None

    # Feed one dummy token per step, growing the attention cache as in autoregressive inference
    for i in trange(args.num_steps):
        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
        alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
        with torch.no_grad():
            outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)

    print_device_info(args.device)
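A minimal run, assuming the BLOOM config shown at the top of this diff is saved as config.json (note a single 176B-scale block is still heavy on CPU):

    python cli/inference_one_block.py --config config.json --num_steps 100 --device cpu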
@@ -1,5 +0,0 @@
device=cpu
block_ids=2:3
id_path=./server.id
maddr=/ip4/127.0.0.1/tcp/30000
#
@@ -1,6 +0,0 @@
name=bloom-peer-0.bloom.net
device=cpu
block_ids=1:3
id_path=./server.id
maddr=/ip4/0.0.0.0/tcp/30000
#
@@ -1,109 +0,0 @@
#!/usr/bin/env bash

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-n] [-c]" >&2
  echo " -n: number of servers to run" >&2
  echo " -c: path to the server configs" >&2
  exit 1
}

if [ $# != 4 ]; then
  instructions
fi

while getopts ":n:c:" option; do
  case $option in
    n) NUM_SERVERS=${OPTARG} ;;
    c) CONFIG_PATH=${OPTARG} ;;
    \?) instructions ;;
  esac
done


###########################
# Install or activate env #
###########################

source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple .
  pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi


#######################
# Create Initial peer #
#######################

hivemind-dht &> tmp.out &
sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
echo "Initial peer: ${INITIAL_PEER}"


##############################
# Initialize the config file #
##############################

typeset -A cfg
cfg=( # set default values in config array
  [device]="cpu"
  [block_ids]="1:2"
  [id_path]="server.id"
  [maddr]="/ip4/127.0.0.1/tcp/30000"
)

###############
# Run servers #
###############

for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
  ###############
  # Read config #
  ###############

  # Override the defaults with key=value pairs from ${CONFIG_PATH}/server_${SERVER_ID}.cfg
  while read line
  do
    if echo $line | grep -F = &>/dev/null
    then
      varname=$(echo "$line" | cut -d '=' -f 1)
      cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
    fi
  done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

  echo "=== Server #${SERVER_ID} ==="
  echo "Server ID: ${cfg[id_path]}"
  echo "Device: ${cfg[device]}"
  echo "Bloom block ids: ${cfg[block_ids]}"
  echo "Host maddr: ${cfg[maddr]}"
  echo ""

  ##############
  # Run server #
  ##############

  tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
done

#####################
# Kill initial peer #
#####################

sleep 10
pkill -f hivemind-dht  # TODO: kill only particular pids of hivemind-dht
rm tmp.out
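A hypothetical end-to-end run: put one key=value file per server (named server_0.cfg, server_1.cfg, ...) in a config directory, mirroring the example configs above, then launch:

    mkdir -p configs
    printf 'device=cpu\nblock_ids=0:2\nid_path=./server0.id\nmaddr=/ip4/127.0.0.1/tcp/30000\n' > configs/server_0.cfg
    printf 'device=cpu\nblock_ids=2:4\nid_path=./server1.id\nmaddr=/ip4/127.0.0.1/tcp/30001\n' > configs/server_1.cfg
    bash cli/run_local_servers.sh -n 2 -c ./configs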
@@ -1,110 +0,0 @@
#!/usr/bin/env bash

SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"

#################
# Parse options #
#################

instructions() {
  echo "Usage: $0 [-u] [-n] [-c]" >&2
  echo " -u: username" >&2
  echo " -n: number of servers to run" >&2
  echo " -c: path to the server configs" >&2
  exit 1
}

if [ $# != 6 ]; then
  instructions
fi

while getopts ":u:n:c:" option; do
  case $option in
    u) USERNAME=${OPTARG} ;;
    n) NUM_SERVERS=${OPTARG} ;;
    c) CONFIG_PATH=${OPTARG} ;;
    \?) instructions ;;
  esac
done


###########################
# Install or activate env #
###########################

source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple .
fi


#######################
# Create Initial peer #
#######################

hivemind-dht &> tmp.out &

sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
rm tmp.out
echo "Initial peer: ${INITIAL_PEER}"


##############################
# Initialize the config file #
##############################

typeset -A cfg
cfg=( # set default values in config array
  [name]=""
  [device]="cpu"
  [block_ids]="1:2"
  [id_path]="server.id"
  [maddr]="/ip4/0.0.0.0/tcp/30000"
)

###############
# Run servers #
###############

for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
  ###############
  # Read config #
  ###############

  while read line
  do
    if echo $line | grep -F = &>/dev/null
    then
      varname=$(echo "$line" | cut -d '=' -f 1)
      cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
    fi
  done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

  SERVER_NAME="${USERNAME}@${cfg[name]}"
  echo "=== Server #${SERVER_ID} ==="
  echo "Server name ${SERVER_NAME}"
  echo "Server ID: ${cfg[id_path]}"
  echo "Device: ${cfg[device]}"
  echo "Bloom block ids: ${cfg[block_ids]}"
  echo "Host maddr: ${cfg[maddr]}"
  echo "================="

  ##############
  # Run server #
  ##############

  # Start each server in a detached tmux session on the remote host
  ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
done
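A hypothetical launch, assuming SSH_KEY_PATH at the top of the script has been edited and each config file also sets the name= key to the remote host (the username is a placeholder):

    bash cli/run_remote_servers.sh -u ubuntu -n 2 -c ./configs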
@@ -1,146 +0,0 @@
import argparse

import configargparse
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from humanfriendly import parse_size

from petals.constants import PUBLIC_INITIAL_PEERS
from petals.server.server import Server

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


def main():
    # fmt:off
    parser = configargparse.ArgParser(default_config_files=["config.yml"],
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--converted_model_name_or_path', type=str, default=None,
                       help="path or name of a pretrained model, converted with cli/convert_model.py")
    group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")

    parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
    parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default, "
                                                                 "use the same name as in the converted model.")
    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0', '/ip6/::/tcp/0'], required=False,
                        help='Multiaddrs to listen for external connections from other peers. Default: all IPv4/IPv6 interfaces, a random free TCP port')
    parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
                        help='Visible multiaddrs the host announces for external connections from other peers')

    parser.add_argument('--compression', type=str, default='NONE', required=False,
                        help='Tensor compression for communication')

    parser.add_argument('--num_handlers', type=int, default=8, required=False,
                        help='Server will use this many processes to handle incoming requests')
    parser.add_argument('--min_batch_size', type=int, default=1,
                        help='Minimum required batch size for all operations (in total tokens)')
    parser.add_argument('--max_batch_size', type=int, default=2048,
                        help='The total number of tokens in the same batch will not exceed this value')
    parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
                        help='Pre-form this many subsequent batches while GPU is processing the current one')
    parser.add_argument('--sender_threads', type=int, default=1, required=False,
                        help='Use this many threads to pass results/exceptions from Runtime to Pools')
    parser.add_argument('--inference_max_length', type=int, default=2048,
                        help='Maximum total sequence length permitted per inference session (in tokens)')
    parser.add_argument('--cache_dir', type=str, default=None,
                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
    parser.add_argument('--device', type=str, default=None, required=False,
                        help='All blocks will use this device in torch notation; default: cuda if available else cpu')
    parser.add_argument("--torch_dtype", type=str, default="auto",
                        help="Use this dtype to store block weights and do computations. "
                             "By default, respect the dtypes in the pre-trained state dict.")
    parser.add_argument('--attn_cache_size', type=str, default=None,
                        help='The size of GPU memory allocated for storing past attention keys/values between inference'
                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
    parser.add_argument('--alloc_timeout', type=float, default=60,
                        help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                             'before rejecting the request')
    parser.add_argument('--revision', type=str, default='main',
                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models "
                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")

    parser.add_argument('--throughput',
                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
                        default='auto',
                        help='Expected server throughput (a float measured in RPS). '
                             'If set to "auto" (default), the script evaluates network and compute throughput '
                             'on the first run and uses these estimates for future runs. '
                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
    parser.add_argument('--update_period', type=float, required=False, default=30,
                        help='Server will report blocks to DHT once in this many seconds')
    parser.add_argument('--expiration', type=float, required=False, default=None,
                        help='DHT entries will expire after this many seconds')
    parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
                        help='Timeout for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
    parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
                        help='Timeout for the whole inference session')
    parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60,
                        help="Timeout for waiting the next step's inputs inside an inference session")

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
                       help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')
    group.add_argument('--new_swarm', action='store_true',
                       help='Start a new private swarm (i.e., do not connect to any initial peers)')

    parser.add_argument('--increase_file_limit', action='store_true',
                        help='On *nix, this will increase the max number of processes '
                             'a server can spawn before hitting "Too many open files"; use at your own risk.')
    parser.add_argument('--stats_report_interval', type=int, required=False,
                        help='Interval between two reports of batch processing performance statistics')

    parser.add_argument('--custom_module_path', type=str, required=False,
                        help='Path of a file with custom nn.modules, wrapped into special decorator')
    parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')

    parser.add_argument("--balance_quality", type=float, default=0.75,
                        help="Rebalance the swarm if its throughput is worse than this share of the optimal "
                             "throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing "
                             "on each check for debugging purposes.")
    parser.add_argument("--mean_balance_check_period", type=float, default=60,
                        help="Check the swarm's balance every N seconds (and rebalance it if necessary)")

    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    parser.add_argument('--load_in_8bit', action='store_true',
                        help='Convert the loaded model into a mixed-8bit quantized model.')

    # fmt:on
    args = vars(parser.parse_args())
    args.pop("config", None)

    args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"]

    if args.pop("increase_file_limit"):
        increase_file_limit()

    compression_type = args.pop("compression").upper()
    compression = getattr(CompressionType, compression_type)

    attn_cache_size = args.pop("attn_cache_size")
    if attn_cache_size is not None:
        attn_cache_size = parse_size(attn_cache_size)  # e.g. "1.2GB" -> bytes
    assert isinstance(
        attn_cache_size, (int, type(None))
    ), "unrecognized value for attn_cache_size, examples: 1.5GB or 1500MB or 1572864000 (bytes)"

    if args.pop("new_swarm"):
        args["initial_peers"] = []

    use_auth_token = args.pop("use_auth_token")
    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token

    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
    try:
        server.run()
    except KeyboardInterrupt:
        logger.info("Caught KeyboardInterrupt, shutting down")
    finally:
        server.shutdown()


if __name__ == "__main__":
    main()
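A hypothetical server launch serving a slice of a converted model (the model name is convert_model.py's default output repo; block range and port are placeholders):

    python -m cli.run_server bigscience/test-bloomd --block_indices 3:5 \
        --identity_path ./server.id --host_maddrs /ip4/0.0.0.0/tcp/31337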
cli/speed_test.py (1941 changed lines): file diff suppressed because it is too large.