From 11a424837ff349dd2db7b8d2b8f0e7630fb10bc1 Mon Sep 17 00:00:00 2001 From: Dmitry Baranchuk Date: Thu, 4 Aug 2022 09:57:37 +0300 Subject: [PATCH] integrate mixed-8bit model (#39) * integrate mixed-8bit model * Fix bug with model duplication in RAM * set throughput=1.0 to fix zero throughput problem * add revision support * update hivemind and bitsandbytes * update deploy scripts * update installation instructions --- .github/workflows/run-tests.yaml | 7 +++++++ README.md | 1 + cli/deploy_server.sh | 13 ++++++++---- cli/run_local_servers.sh | 6 +++--- cli/run_server.py | 8 +++++++- requirements.txt | 5 ++--- src/bloom/from_pretrained.py | 12 ++++++++--- src/bloom/model.py | 7 +++---- src/client/inference_session.py | 2 +- src/client/remote_model.py | 3 ++- src/server/handler.py | 7 +++++-- src/server/server.py | 21 ++++++++++++++++---- src/utils/convert_8bit.py | 34 ++++++++++++++++++++++++++++++++ 13 files changed, 100 insertions(+), 26 deletions(-) create mode 100644 src/utils/convert_8bit.py diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index f4173f1..edd9460 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -62,6 +62,13 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install -r requirements-dev.txt + - name: Build bitsandbytes cpuonly + run: | + git clone https://github.com/TimDettmers/bitsandbytes.git + cd bitsandbytes + make cpuonly + pip install . + cd - - name: Test run: | export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))") diff --git a/README.md b/README.md index b7a6889..aca399f 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ Roadmap: [__Issue #12__](https://github.com/learning-at-home/bloom-demo/issues/1 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32 pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html pip install -r requirements.txt +pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113 ``` diff --git a/cli/deploy_server.sh b/cli/deploy_server.sh index 4c0ac23..1b0cc24 100644 --- a/cli/deploy_server.sh +++ b/cli/deploy_server.sh @@ -5,7 +5,8 @@ ################# instructions() { - echo "Usage: $0 [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2 + echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2 + echo " -m: model name" echo " -i: initial peer" echo " -d: device" >&2 echo " -p: server identity path" >&2 @@ -19,8 +20,10 @@ if [ ! 
$# -ge 8 ]; then instructions fi -while getopts ":i:d:p:b:a:t:" option; do +while getopts ":m:i:d:p:b:a:t:" option; do case $option in + m) MODEL_NAME=${OPTARG} + ;; i) INITIAL_PEER=${OPTARG} ;; d) DEVICE=${OPTARG} @@ -42,6 +45,7 @@ done echo "==========" echo "= Config =" echo "==========" +echo "Model name: ${MODEL_NAME}" echo "Initial peer: ${INITIAL_PEER}" echo "Device: ${DEVICE}" echo "Server name: ${SERVER_ID_PATH}" @@ -64,11 +68,12 @@ else conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32 pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html pip install -i https://pypi.org/simple -r requirements.txt + pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113 fi ############## # Run server # ############## -python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 --device ${DEVICE} --initial_peer ${INITIAL_PEER} \ - --block_indices ${BLOCK_IDS} --torch_dtype float32 --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} &> ${SERVER_ID_PATH}.log +python -m cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \ + --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log diff --git a/cli/run_local_servers.sh b/cli/run_local_servers.sh index 697fa1a..51a802a 100644 --- a/cli/run_local_servers.sh +++ b/cli/run_local_servers.sh @@ -41,6 +41,7 @@ else conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32 pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html pip install -i https://pypi.org/simple -r requirements.txt + pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113 fi @@ -49,7 +50,7 @@ fi ####################### hivemind-dht &> tmp.out & -sleep 3 +sleep 5 INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" ) echo "Initial peer: ${INITIAL_PEER}" @@ -96,10 +97,9 @@ do # Run server # ############## - tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]} + tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]} done - ##################### # Kill initial peer # ##################### diff --git a/cli/run_server.py b/cli/run_server.py index 7ae7bfc..1cdae73 100644 --- a/cli/run_server.py +++ b/cli/run_server.py @@ -27,12 +27,14 @@ def main(): parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication') - parser.add_argument('--num_handlers', type=int, default=16, required=False, + parser.add_argument('--num_handlers', type=int, default=8, required=False, help='server will use this many processes to handle incoming requests') parser.add_argument('--min_batch_size', type=int, default=1, help='Minimum required batch size for all expert operations') parser.add_argument('--max_batch_size', type=int, default=16384, help='The total number of examples in the same batch will not exceed this value') + parser.add_argument('--cache_dir', type=str, default=None, + help='Path to a directory in which a downloaded pretrained model configuration should be cached if the 
standard cache should not be used.')
     parser.add_argument('--cache_size_bytes', type=int, default=None,
                         help='The size of memory cache for storing past attention keys/values between inference steps')
     parser.add_argument('--device', type=str, default=None, required=False,
@@ -40,6 +42,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--revision', type=str, default='main',
+                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models "
+                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 
     parser.add_argument('--throughput',
                         type=lambda value: value if value in ['auto', 'eval'] else float(value),
@@ -64,6 +69,7 @@ def main():
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
     parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
+    parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into a mixed-8bit quantized model.')
 
     # fmt:on
     args = vars(parser.parse_args())
diff --git a/requirements.txt b/requirements.txt
index 4d16e55..feccf05 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
 torch==1.12.0
 accelerate==0.10.0
 huggingface-hub==0.7.0
-bitsandbytes-cuda113==0.26.0
-https://github.com/learning-at-home/hivemind/archive/28261470e44f2ae4157d08b563b4d2771f3a9549.zip
-https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
+https://github.com/learning-at-home/hivemind/archive/20b3b3d5f225ed525515a5383a008a8f9fad8173.zip
+https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
\ No newline at end of file
diff --git a/src/bloom/from_pretrained.py b/src/bloom/from_pretrained.py
index 63cf7bc..b8bd398 100644
--- a/src/bloom/from_pretrained.py
+++ b/src/bloom/from_pretrained.py
@@ -34,12 +34,15 @@ def load_pretrained_block(
     config: Optional[BloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model.
See convert_model.py (or README.md) on how to convert it.""" if config is None: config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token) block = BloomBlock(config, layer_number=block_index) - state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token) + state_dict = _load_state_dict( + converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir + ) block.load_state_dict(state_dict) if torch_dtype == "auto": @@ -57,7 +60,10 @@ def load_pretrained_block( def _load_state_dict( - pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None + pretrained_model_name_or_path: str, + block_index: Optional[int] = None, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, ) -> OrderedDict[str, torch.Tensor]: revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None) @@ -65,7 +71,7 @@ def _load_state_dict( # Load from URL or cache if already cached resolved_archive_file = cached_path( archive_file, - cache_dir=None, + cache_dir=cache_dir, force_download=FORCE_DOWNLOAD, proxies=None, resume_download=RESUME_DOWNLOAD, diff --git a/src/bloom/model.py b/src/bloom/model.py index 38cef08..5d6afdb 100644 --- a/src/bloom/model.py +++ b/src/bloom/model.py @@ -156,9 +156,7 @@ class BloomModel(BloomPreTrainedModel): self.n_head = config.n_head # Embedding + LN Embedding - - # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!) - self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) # dtype=config.torch_dtype + self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) # Transformer blocks @@ -229,7 +227,8 @@ class BloomModel(BloomPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(inputs_embeds.float()) + # Note: it supports only float32 or bfloat16 inputs + hidden_states = self.word_embeddings_layernorm(inputs_embeds) output_shape = input_shape + (hidden_states.size(-1),) diff --git a/src/client/inference_session.py b/src/client/inference_session.py index 824a583..94f6ffa 100644 --- a/src/client/inference_session.py +++ b/src/client/inference_session.py @@ -70,7 +70,7 @@ class RemoteTransformerBlockInferenceSession: runtime_pb2.ExpertRequest( uid=self.uid, tensors=[ - serialize_torch_tensor(tensor, proto.compression) + serialize_torch_tensor(tensor.to(proto.dtype), proto.compression) for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"])) ], ) diff --git a/src/client/remote_model.py b/src/client/remote_model.py index d78cb10..749e96e 100644 --- a/src/client/remote_model.py +++ b/src/client/remote_model.py @@ -90,7 +90,8 @@ class DistributedBloomModel(BloomModel): if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(inputs_embeds.float()) + # Note: it supports only float32 or bfloat16 inputs + hidden_states = self.word_embeddings_layernorm(inputs_embeds) output_shape = input_shape + (hidden_states.size(-1),) hidden_states = self.h(hidden_states) diff --git a/src/server/handler.py b/src/server/handler.py index 1a80553..46dfef3 100644 --- 
a/src/server/handler.py +++ b/src/server/handler.py @@ -48,6 +48,9 @@ class TransformerConnectionHandler(ConnectionHandler): while request.tensors: # iterate while user is willing to supply tensors hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors] + # Cast inputs to backend dtype + hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states] + # run request tensors through all requested modules, update caches for backend, cache_handle in zip(requested_backends, cache_handles): cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length @@ -62,7 +65,7 @@ class TransformerConnectionHandler(ConnectionHandler): # serialize and send last layer outputs yield runtime_pb2.ExpertResponse( tensors=[ - serialize_torch_tensor(result, proto.compression, allow_inplace=True) + serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) for result, proto in zip( hidden_states, nested_flatten(requested_backends[-1].outputs_schema) ) @@ -242,7 +245,7 @@ class TransformerConnectionHandler(ConnectionHandler): head_dim = backend.module.self_attention.head_dim cache_descriptor = TensorDescriptor( - size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32 + size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype ) # [key_or_value, batch_size, max_length, num_heads, head_dim] diff --git a/src/server/server.py b/src/server/server.py index b5c6aae..057daa5 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -22,6 +22,7 @@ from src.server.block_selection import choose_best_blocks from src.server.cache import MemoryCache from src.server.handler import TransformerConnectionHandler from src.server.throughput import get_host_throughput +from src.utils.convert_8bit import replace_8bit_linear use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) @@ -35,7 +36,6 @@ class Server(threading.Thread): dht: DHT, module_backends: Dict[str, TransformerBackend], *, - device: torch.device, num_connection_handlers: int = 8, throughput: float, update_period: float = 30, @@ -49,7 +49,7 @@ class Server(threading.Thread): self.conn_handlers = [ TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers) ] - self.runtime = Runtime(self.module_backends, device=device, **kwargs) + self.runtime = Runtime(self.module_backends, **kwargs) self.dht_handler_thread = ModuleAnnouncerThread( self.module_backends, dht, @@ -101,10 +101,12 @@ class Server(threading.Thread): throughput: Union[float, str], num_blocks: Optional[int] = None, block_indices: Optional[str] = None, - num_handlers: Optional[int] = None, + num_handlers: int = 8, min_batch_size: int = 1, max_batch_size: int = 4096, torch_dtype: str = "auto", + revision: str = "main", + cache_dir: Optional[str] = None, cache_size_bytes: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, initial_peers: Sequence[str] = (), @@ -115,6 +117,7 @@ class Server(threading.Thread): expiration: Optional[float] = None, max_block_selection_delay: float = 1, use_auth_token: Optional[str] = None, + load_in_8bit: bool = False, *, start: bool, **kwargs, @@ -148,7 +151,9 @@ class Server(threading.Thread): torch_dtype = DTYPE_MAP[torch_dtype] assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" - block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token) + block_config = BloomConfig.from_pretrained( + 
+            converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
+        )
 
         if block_indices is not None:
             try:
@@ -186,7 +191,15 @@ class Server(threading.Thread):
                 block_config,
                 torch_dtype=torch_dtype,
                 use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
             )
+
+            if load_in_8bit:
+                dtype = block.input_layernorm.weight.dtype
+                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
+                block = replace_8bit_linear(block)
+
+            block = block.to(device)
             for param in block.parameters():
                 param.requires_grad = False
 
diff --git a/src/utils/convert_8bit.py b/src/utils/convert_8bit.py
new file mode 100644
index 0000000..f5654b7
--- /dev/null
+++ b/src/utils/convert_8bit.py
@@ -0,0 +1,34 @@
+import bitsandbytes as bnb
+import torch
+
+
+def replace_8bit_linear(model, threshold=6.0):
+    """
+    A helper function that replaces all `torch.nn.Linear` modules with `bnb.nn.Linear8bitLt` modules from the
+    `bitsandbytes` library. This enables running models in mixed int8 precision as described in the paper `GPT3.int8():
+    8-bit Matrix Multiplication for Transformers at Scale`. Make sure that a `bitsandbytes` build compiled for your
+    hardware's CUDA version is installed before running this function: `pip install -i https://test.pypi.org/simple/
+    bitsandbytes-cudaXXX`, where `XXX` is your CUDA version (e.g., 116 for CUDA 11.6).
+    The function runs recursively and replaces all `torch.nn.Linear` modules except for `lm_head`, which should be
+    kept as a regular `torch.nn.Linear` module. Each replacement layer is created on the same device as the module it
+    replaces.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        threshold (`float`, *optional*):
+            `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
+            `6.0` as recommended by the paper.
+    """
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_8bit_linear(module, threshold)
+
+        if isinstance(module, torch.nn.Linear) and n != "lm_head":
+            model._modules[n] = bnb.nn.Linear8bitLt(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+                has_fp16_weights=False,
+                threshold=threshold,
+            ).to(model._modules[n].weight.device)
+    return model
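
A minimal usage sketch for the new helper, for anyone who wants to try it outside of Server.create(). It assumes a `bitsandbytes` build with CUDA support is installed and a GPU is available; `TinyBlock` is a hypothetical stand-in for the `BloomBlock` that the server converts when started with `--load_in_8bit`:

    import torch
    import torch.nn as nn

    from src.utils.convert_8bit import replace_8bit_linear


    class TinyBlock(nn.Module):
        """Hypothetical stand-in for a BloomBlock: one layernorm followed by one linear layer."""

        def __init__(self, hidden_size: int = 64):
            super().__init__()
            self.input_layernorm = nn.LayerNorm(hidden_size)
            self.dense = nn.Linear(hidden_size, hidden_size)

        def forward(self, x):
            return self.dense(self.input_layernorm(x))


    block = TinyBlock().half()          # load_in_8bit expects float16 weights (see the assert in server.py)
    block = replace_8bit_linear(block)  # nn.Linear -> bnb.nn.Linear8bitLt; layernorms are left untouched
    if torch.cuda.is_available():
        block = block.to("cuda")        # Linear8bitLt weights are quantized to int8 once moved to the GPU

    print(type(block.dense).__name__)            # Linear8bitLt
    print(type(block.input_layernorm).__name__)  # LayerNorm

Note that `--load_in_8bit` (8-bit block weights) and `--compression UNIFORM_8BIT` (8-bit tensors on the wire) are independent settings; the updated cli/deploy_server.sh enables both.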