Add `src/petals/cli`

pip-installable
Aleksandr Borzunov 1 year ago
parent 43e13e1a12
commit dc813c426e

@@ -0,0 +1,20 @@
{
  "apply_residual_connection_post_layernorm": false,
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_embed": 14336,
  "n_layer": 70,
  "num_attention_heads": 112,
  "pretraining_tp": 4,
  "slow_but_exact": false,
  "transformers_version": "4.20.0.dev0",
  "use_cache": true,
  "vocab_size": 250880
}
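This is the BLOOM-176B model config shipped with the converted repo; note the BLOOM-specific key names (`n_embed` is the hidden size, `n_layer` the number of transformer blocks). As a quick sanity check, the file loads with the stock transformers config class; a minimal sketch, assuming the file is saved locally as config.json:

# Sketch: sanity-check the config above with transformers (local path is assumed).
from transformers import BloomConfig

config = BloomConfig.from_json_file("config.json")
assert config.model_type == "bloom"
assert config.n_layer == 70 and config.num_attention_heads == 112
print(config.hidden_size)  # 14336, populated from the legacy "n_embed" key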

@@ -0,0 +1,93 @@
import argparse
import os

import psutil
import torch
import torch.backends.quantized
import torch.nn as nn
import transformers
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from huggingface_hub import Repository
from tqdm.auto import tqdm

from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
from petals.client import BloomModel, DistributedBloomConfig

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
    parser.add_argument(
        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
    )
    parser.add_argument(
        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
    )
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
    args = parser.parse_args()

    free_ram_gb = psutil.virtual_memory().available / 2**30
    if args.model == "bigscience/bloom" and free_ram_gb < 400:
        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
    if os.path.exists(args.output_path) and (
        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
    ):
        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
    config = DistributedBloomConfig.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    config.dht_prefix = args.output_repo

    model = BloomModel.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    if args.resize_token_embeddings:
        logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
        model.resize_token_embeddings(args.resize_token_embeddings)
        config.vocab_size = args.resize_token_embeddings

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    os.makedirs(args.output_path, exist_ok=True)

    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
    repo.git_pull()

    transformer_blocks = model.h
    logger.info(
        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
    )
    for i, block in enumerate(tqdm(transformer_blocks)):
        repo.git_checkout(args.client_branch, create_branch_ok=True)
        with repo.commit(
            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
        ):
            torch.save(block.state_dict(), "./pytorch_model.bin")

    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
    repo.git_checkout(args.client_branch, create_branch_ok=True)
    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
        model.h = nn.ModuleList()
        model.save_pretrained(".")
        tokenizer.save_pretrained(".")
        config.save_pretrained(".")

    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")

@@ -0,0 +1,79 @@
#!/usr/bin/env bash
#################
# Parse options #
#################
instructions() {
  echo "Usage: $0 [-m] [-i] [-d] [-p] [-b] [-a] [-t]" >&2
  echo "  -m: model name" >&2
  echo "  -i: initial peer" >&2
  echo "  -d: device" >&2
  echo "  -p: server identity path" >&2
  echo "  -b: block_ids" >&2
  echo "  -a: host maddrs" >&2
  echo "  -t: whether to run local tests" >&2
  exit 1
}

if [ $# -lt 8 ]; then
  instructions
fi

while getopts ":m:i:d:p:b:a:t:" option; do
  case $option in
    m) MODEL_NAME=${OPTARG} ;;
    i) INITIAL_PEER=${OPTARG} ;;
    d) DEVICE=${OPTARG} ;;
    p) SERVER_ID_PATH=${OPTARG} ;;
    b) BLOCK_IDS=${OPTARG} ;;  # TODO: allow several maddrs
    a) HOST_MADDR=${OPTARG} ;;
    t) RUN_LOCAL_TESTS=true ;;
    \?) instructions ;;
  esac
done
echo "=========="
echo "= Config ="
echo "=========="
echo "Model name: ${MODEL_NAME}"
echo "Initial peer: ${INITIAL_PEER}"
echo "Device: ${DEVICE}"
echo "Server name: ${SERVER_ID_PATH}"
echo "Server address: ${HOST_MADDR}"
echo "Bloom blocks: ${BLOCK_IDS}"
###########################
# Install or activate env #
###########################
# TODO fix bug with self calling
source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple .
  pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi
##############
# Run server #
##############
python -m petals.cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
--block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log
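The -a option takes a libp2p multiaddr such as /ip4/127.0.0.1/tcp/30000. For readers unfamiliar with the format, here is a hedged sketch of splitting one into protocol/value pairs (plain string handling for illustration, not the real libp2p parser):

# Sketch: decompose a multiaddr like the ones passed via -a (string level only).
def parse_maddr(maddr: str) -> dict:
    parts = maddr.strip("/").split("/")        # ["ip4", "127.0.0.1", "tcp", "30000"]
    return dict(zip(parts[::2], parts[1::2]))  # {"ip4": "127.0.0.1", "tcp": "30000"}

addr = parse_maddr("/ip4/127.0.0.1/tcp/30000")
print(addr["ip4"], int(addr["tcp"]))  # 127.0.0.1 30000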

@@ -0,0 +1,53 @@
import argparse

import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tqdm.auto import trange

from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig
from petals.bloom.ops import build_alibi_tensor

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

logger.warning("inference_one_block will soon be deprecated in favour of tests!")


def print_device_info(device=None):
    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    logger.info(f"Using device: {device}")

    # Additional info when using cuda
    if device.type == "cuda":
        logger.info(torch.cuda.get_device_name(0))
        logger.info("Memory usage:")
        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
        logger.info(f"Reserved: {round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)} GB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
    parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
    parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
    parser.add_argument("--layer_index", default=0, type=int, help="Index of the block within the model")
    parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
    parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
    args = parser.parse_args()

    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    config = BloomConfig.from_json_file(args.config)
    block = BloomBlock(config, args.layer_index).to(args.device)
    if args.state_dict is not None:
        # optionally load saved weights for this block
        block.load_state_dict(torch.load(args.state_dict, map_location=args.device))

    cache = None
    for i in trange(args.num_steps):
        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
        alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
        with torch.no_grad():
            outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)

    print_device_info(args.device)
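Each step feeds a single new token (a 1 x 1 x hidden_size tensor) and reuses layer_past, so only the alibi bias must be rebuilt for the longer sequence (i + 1 positions). If you want throughput numbers rather than just a smoke test, here is a rough sketch of a tokens-per-second wrapper around the same loop; the helper name is ours, and block/config come from the script above:

# Sketch: measure tokens/sec for the incremental loop above (helper name is hypothetical).
import time

import torch
from petals.bloom.ops import build_alibi_tensor

def benchmark_block(block, config, device, num_steps=100):
    cache = None
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # make sure timing starts from an idle GPU
    start = time.perf_counter()
    for i in range(num_steps):
        dummy_input = torch.randn(1, 1, config.hidden_size, device=device)
        alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(device)
        with torch.no_grad():
            _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return num_steps / (time.perf_counter() - start)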

@@ -0,0 +1,5 @@
device=cpu
block_ids=2:3
id_path=./server.id
maddr=/ip4/127.0.0.1/tcp/30000
#

@@ -0,0 +1,6 @@
name=bloom-peer-0.bloom.net
device=cpu
block_ids=1:3
id_path=./server.id
maddr=/ip4/0.0.0.0/tcp/30000
#
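These key=value files are consumed by the while-read loops in the launcher scripts below: lines without an = (such as the trailing #) are ignored, and that trailing line also protects the last real entry, since bash's read skips an unterminated final line. An equivalent Python reader, for reference:

# Sketch: a Python equivalent of the bash config reader used in the launchers below.
def read_cfg(path: str) -> dict:
    cfg = {"device": "cpu", "block_ids": "1:2", "id_path": "server.id",
           "maddr": "/ip4/127.0.0.1/tcp/30000"}  # same defaults as the scripts
    with open(path) as f:
        for line in f:
            if "=" in line:
                key, _, value = line.partition("=")
                cfg[key.strip()] = value.strip()
    return cfg

print(read_cfg("server_0.cfg"))  # {'device': 'cpu', 'block_ids': '2:3', ...}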

@@ -0,0 +1,109 @@
#!/usr/bin/env bash
#################
# Parse options #
#################
instructions() {
  echo "Usage: $0 [-n] [-c]" >&2
  echo "  -n: number of servers to run" >&2
  echo "  -c: path to the server configs" >&2
  exit 1
}

if [ $# -ne 4 ]; then
  instructions
fi

while getopts ":n:c:" option; do
  case $option in
    n) NUM_SERVERS=${OPTARG} ;;
    c) CONFIG_PATH=${OPTARG} ;;
    \?) instructions ;;
  esac
done
###########################
# Install or activate env #
###########################
source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple .
  pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi
#######################
# Create Initial peer #
#######################
hivemind-dht &> tmp.out &
sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
echo "Initial peer: ${INITIAL_PEER}"
##############################
# Initialize the config file #
##############################
typeset -A cfg
cfg=( # set default values in config array
  [device]="cpu"
  [block_ids]="1:2"
  [id_path]="server.id"
  [maddr]="/ip4/127.0.0.1/tcp/30000"
)
###############
# Run servers #
###############
for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
  ###############
  # Read config #
  ###############
  while read -r line
  do
    if echo "$line" | grep -F = &>/dev/null
    then
      varname=$(echo "$line" | cut -d '=' -f 1)
      cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
    fi
  done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

  echo "=== Server #${SERVER_ID} ==="
  echo "Server ID: ${cfg[id_path]}"
  echo "Device: ${cfg[device]}"
  echo "Bloom block ids: ${cfg[block_ids]}"
  echo "Host maddr: ${cfg[maddr]}"
  echo ""

  ##############
  # Run server #
  ##############
  tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
done
#####################
# Kill initial peer #
#####################
sleep 10
pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
rm tmp.out
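The inline python -c one-liner grabs the peer multiaddr from a fixed line and column of hivemind-dht's startup log, which is brittle (note that this script and the remote variant below even disagree: split()[-1] vs split()[-2]). A slightly more defensive sketch that searches the log for a /p2p/ component instead, assuming such an address appears somewhere in the output:

# Sketch: pull the first full multiaddr containing /p2p/ out of tmp.out.
import re

with open("tmp.out") as f:
    log = f.read()
match = re.search(r"/ip4/\S+/p2p/\S+", log)  # assumes an ip4 TCP maddr is logged
if match is None:
    raise RuntimeError("no initial peer address found in tmp.out")
print(match.group(0))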

@@ -0,0 +1,110 @@
#!/usr/bin/env bash
SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"
#################
# Parse options #
#################
instructions() {
  echo "Usage: $0 [-u] [-n] [-c]" >&2
  echo "  -u: username" >&2
  echo "  -n: number of servers to run" >&2
  echo "  -c: path to the server configs" >&2
  exit 1
}

if [ $# -ne 6 ]; then
  instructions
fi

while getopts ":u:n:c:" option; do
  case $option in
    u) USERNAME=${OPTARG} ;;
    n) NUM_SERVERS=${OPTARG} ;;
    c) CONFIG_PATH=${OPTARG} ;;
    \?) instructions ;;
  esac
done
###########################
# Install or activate env #
###########################
source ~/miniconda3/etc/profile.d/conda.sh
if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
  conda activate bloom-demo
else
  conda create -y --name bloom-demo python=3.8.12 pip
  conda activate bloom-demo

  conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
  pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
  pip install -i https://pypi.org/simple .
fi
#######################
# Create Initial peer #
#######################
hivemind-dht &> tmp.out &
sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
rm tmp.out
echo "Initial peer: ${INITIAL_PEER}"
##############################
# Initialize the config file #
##############################
typeset -A cfg
cfg=( # set default values in config array
  [name]=""
  [device]="cpu"
  [block_ids]="1:2"
  [id_path]="server.id"
  [maddr]="/ip4/0.0.0.0/tcp/30000"
)
###############
# Run servers #
###############
for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
do
  ###############
  # Read config #
  ###############
  while read -r line
  do
    if echo "$line" | grep -F = &>/dev/null
    then
      varname=$(echo "$line" | cut -d '=' -f 1)
      cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
    fi
  done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg

  SERVER_NAME="${USERNAME}@${cfg[name]}"
  echo "=== Server #${SERVER_ID} ==="
  echo "Server name: ${SERVER_NAME}"
  echo "Server ID: ${cfg[id_path]}"
  echo "Device: ${cfg[device]}"
  echo "Bloom block ids: ${cfg[block_ids]}"
  echo "Host maddr: ${cfg[maddr]}"
  echo "================="

  ##############
  # Run server #
  ##############
  # deploy_server.sh requires a model name; pass the same converted repo as the local launcher
  ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -m bigscience/test-bloomd -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
done
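Both launchers expect one server_<i>.cfg per server under the -c directory. A short sketch that writes such files programmatically, in the format shown above (hostnames, block ranges, and ports here are placeholders):

# Sketch: generate server_<i>.cfg files (all concrete values are placeholders).
import os

os.makedirs("configs", exist_ok=True)
for i, (host, blocks) in enumerate([("peer-0.example.net", "0:4"), ("peer-1.example.net", "4:8")]):
    with open(f"configs/server_{i}.cfg", "w") as f:
        f.write(f"name={host}\n")
        f.write("device=cpu\n")
        f.write(f"block_ids={blocks}\n")
        f.write(f"id_path=./server_{i}.id\n")
        f.write(f"maddr=/ip4/0.0.0.0/tcp/{30000 + i}\n")
        f.write("#\n")  # trailing non-key line so the bash reader sees the last entry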

@@ -0,0 +1,146 @@
import argparse

import configargparse
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from humanfriendly import parse_size

from petals.constants import PUBLIC_INITIAL_PEERS
from petals.server.server import Server

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

def main():
    # fmt:off
    parser = configargparse.ArgParser(default_config_files=["config.yml"],
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--converted_model_name_or_path', type=str, default=None,
                       help="path or name of a pretrained model, converted with cli/convert_model.py")
    group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")

    parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
    parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default, "
                                                                 "use the same name as in the converted model.")
    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0', '/ip6/::/tcp/0'], required=False,
                        help='Multiaddrs to listen for external connections from other peers. Default: all IPv4/IPv6 interfaces, a random free TCP port')
    parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
                        help='Visible multiaddrs the host announces for external connections from other peers')
    parser.add_argument('--compression', type=str, default='NONE', required=False,
                        help='Tensor compression to use for communication')
    parser.add_argument('--num_handlers', type=int, default=8, required=False,
                        help='server will use this many processes to handle incoming requests')
    parser.add_argument('--min_batch_size', type=int, default=1,
                        help='Minimum required batch size for all operations (in total tokens)')
    parser.add_argument('--max_batch_size', type=int, default=2048,
                        help='The total number of tokens in the same batch will not exceed this value')
    parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
                        help='Pre-form this many subsequent batches while GPU is processing the current one')
    parser.add_argument('--sender_threads', type=int, default=1, required=False,
                        help='Use this many threads to pass results/exceptions from Runtime to Pools')
    parser.add_argument('--inference_max_length', type=int, default=2048,
                        help='Maximum total sequence length permitted per inference session, in tokens')
    parser.add_argument('--cache_dir', type=str, default=None,
                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
    parser.add_argument('--device', type=str, default=None, required=False,
                        help='all blocks will use this device in torch notation; default: cuda if available else cpu')
    parser.add_argument("--torch_dtype", type=str, default="auto",
                        help="Use this dtype to store block weights and do computations. "
                             "By default, respect the dtypes in the pre-trained state dict.")
    parser.add_argument('--attn_cache_size', type=str, default=None,
                        help='The size of GPU memory allocated for storing past attention keys/values between inference'
                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
    parser.add_argument('--alloc_timeout', type=float, default=60,
                        help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                             'before rejecting the request')
    parser.add_argument('--revision', type=str, default='main',
                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models "
                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")

    parser.add_argument('--throughput',
                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
                        default='auto',
                        help='Expected server throughput (a float measured in RPS). '
                             'If set to "auto" (default), the script evaluates network and compute throughput '
                             'on the first run and uses these estimates for future runs. '
                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
    parser.add_argument('--update_period', type=float, required=False, default=30,
                        help='Server will report blocks to DHT once in this many seconds')
    parser.add_argument('--expiration', type=float, required=False, default=None,
                        help='DHT entries will expire after this many seconds')
    parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
                        help='Timeout for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
    parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
                        help='Timeout for the whole inference session')
    parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60,
                        help="Timeout for waiting the next step's inputs inside an inference session")

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
                       help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')
    group.add_argument('--new_swarm', action='store_true',
                       help='Start a new private swarm (i.e., do not connect to any initial peers)')

    parser.add_argument('--increase_file_limit', action='store_true',
                        help='On *nix, increase the maximum number of files a server can open '
                             'before hitting "Too many open files"; use at your own risk.')
    parser.add_argument('--stats_report_interval', type=int, required=False,
                        help='Interval between two reports of batch processing performance statistics')
    parser.add_argument('--custom_module_path', type=str, required=False,
                        help='Path of a file with custom nn.modules, wrapped into special decorator')
    parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
    parser.add_argument("--balance_quality", type=float, default=0.75,
                        help="Rebalance the swarm if its throughput is worse than this share of the optimal "
                             "throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing "
                             "on each check for debugging purposes.")
    parser.add_argument("--mean_balance_check_period", type=float, default=60,
                        help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
    # fmt:on
    args = vars(parser.parse_args())
    args.pop("config", None)
    args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"]

    if args.pop("increase_file_limit"):
        increase_file_limit()

    compression_type = args.pop("compression").upper()
    compression = getattr(CompressionType, compression_type)

    attn_cache_size = args.pop("attn_cache_size")
    if attn_cache_size is not None:
        attn_cache_size = parse_size(attn_cache_size)
    assert isinstance(
        attn_cache_size, (int, type(None))
    ), "Unrecognized value for --attn_cache_size, examples: 1.5GB or 1500MB or 1572864000 (bytes)"

    if args.pop("new_swarm"):
        args["initial_peers"] = []

    use_auth_token = args.pop("use_auth_token")
    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token

    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)

    try:
        server.run()
    except KeyboardInterrupt:
        logger.info("Caught KeyboardInterrupt, shutting down")
    finally:
        server.shutdown()


if __name__ == "__main__":
    main()
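Note that --attn_cache_size goes through humanfriendly.parse_size, so decimal and binary suffixes give different byte counts (1KB != 1KiB, as the help string warns). A quick illustration:

# Sketch: how the --attn_cache_size string is converted to bytes.
from humanfriendly import parse_size

print(parse_size("500MB"))  # 500000000 (decimal megabytes)
print(parse_size("1.2GB"))  # 1200000000
print(parse_size("1KiB"))   # 1024 (binary suffix)
print(parse_size("1KB"))    # 1000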

File diff suppressed because it is too large