integrate mixed-8bit model (#39)

* integrate mixed-8bit model
* Fix bug with model duplication in RAM
* set throughput=1.0 to fix zero throughput problem
* add revision support
* update hivemind and bitsandbytes
* update deploy scripts
* update installation instructions
8bit_backward
Dmitry Baranchuk committed via GitHub
parent 7d39d46966
commit 11a424837f

@@ -62,6 +62,13 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Build bitsandbytes cpuonly
run: |
git clone https://github.com/TimDettmers/bitsandbytes.git
cd bitsandbytes
make cpuonly
pip install .
cd -
- name: Test
run: |
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")

@@ -13,6 +13,7 @@ Roadmap: [__Issue #12__](https://github.com/learning-at-home/bloom-demo/issues/1
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
```

@@ -5,7 +5,8 @@
#################
instructions() {
echo "Usage: $0 [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
echo " -m: model name"
echo " -i: initial peer"
echo " -d: device" >&2
echo " -p: server identity path" >&2
@@ -19,8 +20,10 @@ if [ ! $# -ge 8 ]; then
instructions
fi
while getopts ":i:d:p:b:a:t:" option; do
while getopts ":m:i:d:p:b:a:t:" option; do
case $option in
m) MODEL_NAME=${OPTARG}
;;
i) INITIAL_PEER=${OPTARG}
;;
d) DEVICE=${OPTARG}
@@ -42,6 +45,7 @@ done
echo "=========="
echo "= Config ="
echo "=========="
echo "Model name: ${MODEL_NAME}"
echo "Initial peer: ${INITIAL_PEER}"
echo "Device: ${DEVICE}"
echo "Server name: ${SERVER_ID_PATH}"
@@ -64,11 +68,12 @@ else
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -i https://pypi.org/simple -r requirements.txt
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi
##############
# Run server #
##############
python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
--block_indices ${BLOCK_IDS} --torch_dtype float32 --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} &> ${SERVER_ID_PATH}.log
python -m cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
--block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log

@@ -41,6 +41,7 @@ else
conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -i https://pypi.org/simple -r requirements.txt
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
fi
@@ -49,7 +50,7 @@ fi
#######################
hivemind-dht &> tmp.out &
sleep 3
sleep 5
INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
echo "Initial peer: ${INITIAL_PEER}"
@@ -96,10 +97,9 @@ do
# Run server #
##############
tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
done
#####################
# Kill initial peer #
#####################

@@ -27,12 +27,14 @@ def main():
parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
parser.add_argument('--num_handlers', type=int, default=16, required=False,
parser.add_argument('--num_handlers', type=int, default=8, required=False,
help='server will use this many processes to handle incoming requests')
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all expert operations')
parser.add_argument('--max_batch_size', type=int, default=16384,
help='The total number of examples in the same batch will not exceed this value')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
parser.add_argument('--cache_size_bytes', type=int, default=None,
help='The size of memory cache for storing past attention keys/values between inference steps')
parser.add_argument('--device', type=str, default=None, required=False,
@@ -40,6 +42,9 @@ def main():
parser.add_argument("--torch_dtype", type=str, default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--revision', type=str, default='main',
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
parser.add_argument('--throughput',
type=lambda value: value if value in ['auto', 'eval'] else float(value),
@@ -64,6 +69,7 @@ def main():
help='Path of a file with custom nn.modules, wrapped into special decorator')
parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
# fmt:on
args = vars(parser.parse_args())
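For context, here is a minimal sketch (not the repo's exact `main()`, which lies outside this hunk) of how the parsed flags, including the new `--revision`, `--cache_dir`, and `--load_in_8bit`, could be forwarded to `Server.create`, whose matching keyword parameters appear in the `src/server/server.py` hunks further down. The import path is an assumption based on those hunks.

```python
from src.server.server import Server  # assumed module path, per the server.py hunks below


def serve(args: dict) -> None:
    # Every parsed CLI flag (revision, cache_dir, load_in_8bit, throughput, ...) maps onto a
    # keyword parameter of Server.create; any extra keys are absorbed by its **kwargs.
    server = Server.create(**args, start=True)
    server.join()  # Server subclasses threading.Thread, so join() blocks until the server stops
```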

@@ -1,6 +1,5 @@
torch==1.12.0
accelerate==0.10.0
huggingface-hub==0.7.0
bitsandbytes-cuda113==0.26.0
https://github.com/learning-at-home/hivemind/archive/28261470e44f2ae4157d08b563b4d2771f3a9549.zip
https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
https://github.com/learning-at-home/hivemind/archive/20b3b3d5f225ed525515a5383a008a8f9fad8173.zip
https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

@@ -34,12 +34,15 @@ def load_pretrained_block(
config: Optional[BloomConfig] = None,
torch_dtype: Union[torch.dtype, str] = "auto",
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
) -> BloomBlock:
"""Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
if config is None:
config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
block = BloomBlock(config, layer_number=block_index)
state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
state_dict = _load_state_dict(
converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
)
block.load_state_dict(state_dict)
if torch_dtype == "auto":
@@ -57,7 +60,10 @@ def load_pretrained_block(
def _load_state_dict(
pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
pretrained_model_name_or_path: str,
block_index: Optional[int] = None,
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
) -> OrderedDict[str, torch.Tensor]:
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
@@ -65,7 +71,7 @@ def _load_state_dict(
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=None,
cache_dir=cache_dir,
force_download=FORCE_DOWNLOAD,
proxies=None,
resume_download=RESUME_DOWNLOAD,
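Below is a hedged usage sketch of the new `cache_dir` parameter: it is simply threaded from `load_pretrained_block` through `_load_state_dict` into `cached_path`, so per-block weights can be cached outside the default Hugging Face cache. The module path and model name are assumptions taken from this diff and the deploy scripts above.

```python
import torch
from src.bloom.from_pretrained import load_pretrained_block  # assumed path of the file shown above

# Download one converted BloomBlock, caching its shard under a custom directory
# (cache_dir is forwarded to cached_path via _load_state_dict).
block = load_pretrained_block(
    "bigscience/test-bloomd-6b3",  # converted per-block repo used in the deploy script above
    block_index=3,
    torch_dtype=torch.float16,
    cache_dir="/tmp/bloom_block_cache",
)
```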

@@ -156,9 +156,7 @@ class BloomModel(BloomPreTrainedModel):
self.n_head = config.n_head
# Embedding + LN Embedding
# TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) # dtype=config.torch_dtype
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
@@ -229,7 +227,8 @@ class BloomModel(BloomPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
# Note: it supports only float32 or bfloat16 inputs
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
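A small illustration (not from the repo) of what dropping the `.float()` cast changes: the embedding layernorm now runs in the embeddings' own dtype, per the note that only float32 or bfloat16 inputs are supported, instead of always promoting activations to float32. Sizes below are arbitrary.

```python
import torch
from torch import nn

# Illustrative sizes only: with bfloat16 embeddings, the layernorm output now stays bfloat16,
# whereas the old inputs_embeds.float() cast forced it up to float32.
word_embeddings = nn.Embedding(16, 8, dtype=torch.bfloat16)
word_embeddings_layernorm = nn.LayerNorm(8, dtype=torch.bfloat16)

input_ids = torch.tensor([[1, 2, 3]])
hidden_states = word_embeddings_layernorm(word_embeddings(input_ids))
print(hidden_states.dtype)  # torch.bfloat16
```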

@@ -70,7 +70,7 @@ class RemoteTransformerBlockInferenceSession:
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor, proto.compression)
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
],
)

@@ -90,7 +90,8 @@ class DistributedBloomModel(BloomModel):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
# Note: it supports only float32 or bfloat16 inputs
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
hidden_states = self.h(hidden_states)

@@ -48,6 +48,9 @@ class TransformerConnectionHandler(ConnectionHandler):
while request.tensors: # iterate while user is willing to supply tensors
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
# Cast inputs to backend dtype
hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
# run request tensors through all requested modules, update caches
for backend, cache_handle in zip(requested_backends, cache_handles):
cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
@@ -62,7 +65,7 @@ class TransformerConnectionHandler(ConnectionHandler):
# serialize and send last layer outputs
yield runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result, proto.compression, allow_inplace=True)
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
)
@@ -242,7 +245,7 @@ class TransformerConnectionHandler(ConnectionHandler):
head_dim = backend.module.self_attention.head_dim
cache_descriptor = TensorDescriptor(
size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
)
# [key_or_value, batch_size, max_length, num_heads, head_dim]
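For intuition, here is a hedged sketch of the dtype round trip these hunks introduce: the client casts activations to the schema dtype before serializing, and the server casts deserialized tensors to the backend dtype (fp16 for a `--load_in_8bit` block) before running them. The import paths assume hivemind's public `serialize_torch_tensor`/`deserialize_torch_tensor` helpers; shapes and dtypes are illustrative.

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor  # assumed import path
from hivemind.proto import runtime_pb2

backend_dtype = torch.float16            # activation dtype of an 8-bit server block
hidden_states = torch.randn(1, 8, 4096)  # client-side float32 activations

# Client side: cast to the schema dtype, then serialize with the chosen compression
# (UNIFORM_8BIT is what the deploy script above passes to --compression).
msg = serialize_torch_tensor(hidden_states.to(backend_dtype), runtime_pb2.CompressionType.UNIFORM_8BIT)

# Server side: deserialize, then cast to the backend dtype before running the block.
restored = deserialize_torch_tensor(msg).to(backend_dtype)
assert restored.shape == hidden_states.shape and restored.dtype == backend_dtype
```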

@@ -22,6 +22,7 @@ from src.server.block_selection import choose_best_blocks
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
from src.server.throughput import get_host_throughput
from src.utils.convert_8bit import replace_8bit_linear
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@@ -35,7 +36,6 @@ class Server(threading.Thread):
dht: DHT,
module_backends: Dict[str, TransformerBackend],
*,
device: torch.device,
num_connection_handlers: int = 8,
throughput: float,
update_period: float = 30,
@@ -49,7 +49,7 @@
self.conn_handlers = [
TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
self.runtime = Runtime(self.module_backends, **kwargs)
self.dht_handler_thread = ModuleAnnouncerThread(
self.module_backends,
dht,
@@ -101,10 +101,12 @@ class Server(threading.Thread):
throughput: Union[float, str],
num_blocks: Optional[int] = None,
block_indices: Optional[str] = None,
num_handlers: Optional[int] = None,
num_handlers: int = 8,
min_batch_size: int = 1,
max_batch_size: int = 4096,
torch_dtype: str = "auto",
revision: str = "main",
cache_dir: Optional[str] = None,
cache_size_bytes: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
initial_peers: Sequence[str] = (),
@@ -115,6 +117,7 @@ class Server(threading.Thread):
expiration: Optional[float] = None,
max_block_selection_delay: float = 1,
use_auth_token: Optional[str] = None,
load_in_8bit: bool = False,
*,
start: bool,
**kwargs,
@@ -148,7 +151,9 @@
torch_dtype = DTYPE_MAP[torch_dtype]
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
block_config = BloomConfig.from_pretrained(
converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
)
if block_indices is not None:
try:
@@ -186,7 +191,15 @@
block_config,
torch_dtype=torch_dtype,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
)
if load_in_8bit:
dtype = block.input_layernorm.weight.dtype
assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
block = replace_8bit_linear(block)
block = block.to(device)
for param in block.parameters():
param.requires_grad = False

@@ -0,0 +1,34 @@
import bitsandbytes as bnb
import torch
def replace_8bit_linear(model, threshold=6.0):
    """
    A helper function that replaces all `torch.nn.Linear` modules with `bnb.nn.Linear8bitLt` modules from the
    `bitsandbytes` library, enabling mixed int8 inference as described in the paper `LLM.int8(): 8-bit Matrix
    Multiplication for Transformers at Scale`. Make sure a `bitsandbytes` build compiled for your hardware's CUDA
    version is installed before running this function: `pip install -i https://test.pypi.org/simple/
    bitsandbytes-cudaXXX`, where `XXX` is your CUDA version (e.g., 116 for CUDA 11.6).
    The function runs recursively and replaces every `torch.nn.Linear` module except `lm_head`, which must remain a
    regular `torch.nn.Linear`. The pretrained fp16 weights are reused as-is; they are quantized to int8 once the
    module is moved to a CUDA device.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module`, as the function is run recursively.
        threshold (`float`, *optional*):
            `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set
            to `6.0` as described by the paper.
    """
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_8bit_linear(module, threshold)

        if isinstance(module, torch.nn.Linear) and n != "lm_head":
            new_linear = bnb.nn.Linear8bitLt(
                module.in_features,
                module.out_features,
                module.bias is not None,
                has_fp16_weights=False,
                threshold=threshold,
            ).to(module.weight.device)
            # Reuse the pretrained weights and bias instead of the freshly initialized ones;
            # with has_fp16_weights=False they are quantized to int8 on the next move to CUDA.
            new_linear.weight = bnb.nn.Int8Params(module.weight.data, requires_grad=False, has_fp16_weights=False)
            if module.bias is not None:
                new_linear.bias = module.bias
            model._modules[n] = new_linear
    return model
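A short usage sketch mirroring what `Server.create` does for each block when `--load_in_8bit` is passed (see the `src/server/server.py` hunk above); the module paths and model name are taken from this diff and may differ in your checkout.

```python
import torch
from src.bloom.from_pretrained import load_pretrained_block  # assumed module paths, per this diff
from src.utils.convert_8bit import replace_8bit_linear

block = load_pretrained_block(
    "bigscience/test-bloomd-6b3",  # converted model used in the deploy script above
    block_index=0,
    torch_dtype=torch.float16,     # load_in_8bit currently requires fp16 weights (see the assert in server.py)
)
block = replace_8bit_linear(block)  # swap every nn.Linear except lm_head for bnb.nn.Linear8bitLt
block = block.to("cuda:0")          # moving to CUDA triggers the actual int8 quantization
for param in block.parameters():
    param.requires_grad = False     # the server keeps block weights frozen
```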